Skip to content

Commit b2b8c70

Browse files
committed
[Integration] Expose length-aware batching in all ModelHandler subclasses
Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor. - adds batch_length_fn / batch_bucket_boundaries to 16 handler classes - wires Gemini and Vertex AI batching params into _batching_kwargs - adds end-to-end RunInference coverage for length-aware batching - adds per-handler forwarding regression tests and fixes them to be hermetic
1 parent 90a7db1 commit b2b8c70

File tree

11 files changed

+397
-4
lines changed

11 files changed

+397
-4
lines changed

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717

1818
"""Tests for apache_beam.ml.base."""
19+
import contextlib
20+
import importlib
1921
import math
2022
import multiprocessing
2123
import os
@@ -25,6 +27,7 @@
2527
import tempfile
2628
import time
2729
import unittest
30+
import unittest.mock
2831
from collections.abc import Iterable
2932
from collections.abc import Mapping
3033
from collections.abc import Sequence
@@ -2319,6 +2322,264 @@ def test_batching_kwargs_none_values_omitted(self):
23192322
self.assertEqual(kwargs['min_batch_size'], 5)
23202323

23212324

2325+
class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
2326+
FakeModel]):
2327+
"""Reports each element with the max length of the batch it ran in."""
2328+
def load_model(self):
2329+
return FakeModel()
2330+
2331+
def run_inference(self, batch, model, inference_args=None):
2332+
max_len = max(len(s) for s in batch)
2333+
return [f'{s}:{max_len}' for s in batch]
2334+
2335+
2336+
class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
2337+
"""End-to-end tests for PR2 length-aware batching in RunInference."""
2338+
def test_run_inference_with_length_aware_batch_elements(self):
2339+
handler = PaddingReportingStringModelHandler(
2340+
min_batch_size=2,
2341+
max_batch_size=2,
2342+
max_batch_duration_secs=60,
2343+
batch_length_fn=len,
2344+
batch_bucket_boundaries=[5])
2345+
2346+
examples = ['a', 'cccccc', 'bb', 'ddddddd']
2347+
with TestPipeline('FnApiRunner') as p:
2348+
results = (
2349+
p
2350+
| beam.Create(examples, reshuffle=False)
2351+
| base.RunInference(handler))
2352+
assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))
2353+
2354+
2355+
class HandlerBucketingKwargsForwardingTest(unittest.TestCase):
2356+
"""Verify each concrete ModelHandler forwards batch_length_fn and
2357+
batch_bucket_boundaries through to batch_elements_kwargs()."""
2358+
_BUCKETING_KWARGS = {
2359+
'batch_length_fn': len,
2360+
'batch_bucket_boundaries': [32],
2361+
}
2362+
2363+
def _assert_bucketing_kwargs_forwarded(self, handler):
2364+
kwargs = handler.batch_elements_kwargs()
2365+
self.assertIs(kwargs['length_fn'], len)
2366+
self.assertEqual(kwargs['bucket_boundaries'], [32])
2367+
2368+
def _load_handler_class(self, case):
2369+
try:
2370+
module = importlib.import_module(case['module_name'])
2371+
except ImportError:
2372+
raise unittest.SkipTest(case['skip_message'])
2373+
return getattr(module, case['class_name'])
2374+
2375+
@contextlib.contextmanager
2376+
def _handler_setup(self, case):
2377+
if not case.get('mock_aiplatform'):
2378+
yield
2379+
return
2380+
2381+
with unittest.mock.patch(
2382+
'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip:
2383+
mock_aip.init.return_value = None
2384+
mock_endpoint = unittest.mock.MagicMock()
2385+
mock_endpoint.list_models.return_value = ['fake-model']
2386+
mock_aip.Endpoint.return_value = mock_endpoint
2387+
yield
2388+
2389+
def _assert_handler_cases(self, cases):
2390+
for case in cases:
2391+
with self.subTest(handler=case['name']):
2392+
handler_cls = self._load_handler_class(case)
2393+
init_kwargs = dict(case['init_kwargs'])
2394+
init_kwargs.update(self._BUCKETING_KWARGS)
2395+
2396+
with self._handler_setup(case):
2397+
handler = handler_cls(**init_kwargs)
2398+
2399+
self._assert_bucketing_kwargs_forwarded(handler)
2400+
2401+
def test_pytorch_handlers(self):
2402+
self._assert_handler_cases((
2403+
{
2404+
'name': 'pytorch_tensor',
2405+
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2406+
'class_name': 'PytorchModelHandlerTensor',
2407+
'skip_message': 'PyTorch not available',
2408+
'init_kwargs': {},
2409+
},
2410+
{
2411+
'name': 'pytorch_keyed_tensor',
2412+
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2413+
'class_name': 'PytorchModelHandlerKeyedTensor',
2414+
'skip_message': 'PyTorch not available',
2415+
'init_kwargs': {},
2416+
},
2417+
))
2418+
2419+
def test_huggingface_handlers(self):
2420+
self._assert_handler_cases((
2421+
{
2422+
'name': 'huggingface_keyed_tensor',
2423+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2424+
'class_name': 'HuggingFaceModelHandlerKeyedTensor',
2425+
'skip_message': 'HuggingFace transformers not available',
2426+
'init_kwargs': {
2427+
'model_uri': 'unused',
2428+
'model_class': object,
2429+
'framework': 'pt',
2430+
},
2431+
},
2432+
{
2433+
'name': 'huggingface_tensor',
2434+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2435+
'class_name': 'HuggingFaceModelHandlerTensor',
2436+
'skip_message': 'HuggingFace transformers not available',
2437+
'init_kwargs': {
2438+
'model_uri': 'unused',
2439+
'model_class': object,
2440+
},
2441+
},
2442+
{
2443+
'name': 'huggingface_pipeline',
2444+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2445+
'class_name': 'HuggingFacePipelineModelHandler',
2446+
'skip_message': 'HuggingFace transformers not available',
2447+
'init_kwargs': {
2448+
'task': 'text-classification',
2449+
},
2450+
},
2451+
))
2452+
2453+
def test_sklearn_handlers(self):
2454+
self._assert_handler_cases((
2455+
{
2456+
'name': 'sklearn_numpy',
2457+
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2458+
'class_name': 'SklearnModelHandlerNumpy',
2459+
'skip_message': 'scikit-learn not available',
2460+
'init_kwargs': {
2461+
'model_uri': 'unused',
2462+
},
2463+
},
2464+
{
2465+
'name': 'sklearn_pandas',
2466+
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2467+
'class_name': 'SklearnModelHandlerPandas',
2468+
'skip_message': 'scikit-learn not available',
2469+
'init_kwargs': {
2470+
'model_uri': 'unused',
2471+
},
2472+
},
2473+
))
2474+
2475+
def test_tensorflow_handlers(self):
2476+
self._assert_handler_cases((
2477+
{
2478+
'name': 'tensorflow_numpy',
2479+
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2480+
'class_name': 'TFModelHandlerNumpy',
2481+
'skip_message': 'TensorFlow not available',
2482+
'init_kwargs': {
2483+
'model_uri': 'unused',
2484+
},
2485+
},
2486+
{
2487+
'name': 'tensorflow_tensor',
2488+
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2489+
'class_name': 'TFModelHandlerTensor',
2490+
'skip_message': 'TensorFlow not available',
2491+
'init_kwargs': {
2492+
'model_uri': 'unused',
2493+
},
2494+
},
2495+
))
2496+
2497+
def test_onnx_handler(self):
2498+
self._assert_handler_cases(({
2499+
'name': 'onnx_numpy',
2500+
'module_name': 'apache_beam.ml.inference.onnx_inference',
2501+
'class_name': 'OnnxModelHandlerNumpy',
2502+
'skip_message': 'ONNX Runtime not available',
2503+
'init_kwargs': {
2504+
'model_uri': 'unused',
2505+
},
2506+
}, ))
2507+
2508+
def test_xgboost_handler(self):
2509+
self._assert_handler_cases(({
2510+
'name': 'xgboost_numpy',
2511+
'module_name': 'apache_beam.ml.inference.xgboost_inference',
2512+
'class_name': 'XGBoostModelHandlerNumpy',
2513+
'skip_message': 'XGBoost not available',
2514+
'init_kwargs': {
2515+
'model_class': object,
2516+
'model_state': 'unused',
2517+
},
2518+
}, ))
2519+
2520+
def test_tensorrt_handler(self):
2521+
self._assert_handler_cases(({
2522+
'name': 'tensorrt_numpy',
2523+
'module_name': 'apache_beam.ml.inference.tensorrt_inference',
2524+
'class_name': 'TensorRTEngineHandlerNumPy',
2525+
'skip_message': 'TensorRT not available',
2526+
'init_kwargs': {
2527+
'min_batch_size': 1,
2528+
'max_batch_size': 8,
2529+
},
2530+
}, ))
2531+
2532+
def test_vllm_handlers(self):
2533+
self._assert_handler_cases((
2534+
{
2535+
'name': 'vllm_completions',
2536+
'module_name': 'apache_beam.ml.inference.vllm_inference',
2537+
'class_name': 'VLLMCompletionsModelHandler',
2538+
'skip_message': 'vLLM not available',
2539+
'init_kwargs': {
2540+
'model_name': 'unused',
2541+
},
2542+
},
2543+
{
2544+
'name': 'vllm_chat',
2545+
'module_name': 'apache_beam.ml.inference.vllm_inference',
2546+
'class_name': 'VLLMChatModelHandler',
2547+
'skip_message': 'vLLM not available',
2548+
'init_kwargs': {
2549+
'model_name': 'unused',
2550+
},
2551+
},
2552+
))
2553+
2554+
def test_vertex_ai_handler(self):
2555+
self._assert_handler_cases(({
2556+
'name': 'vertex_ai',
2557+
'module_name': 'apache_beam.ml.inference.vertex_ai_inference',
2558+
'class_name': 'VertexAIModelHandlerJSON',
2559+
'skip_message': 'Vertex AI SDK not available',
2560+
'init_kwargs': {
2561+
'endpoint_id': 'unused',
2562+
'project': 'unused',
2563+
'location': 'unused',
2564+
},
2565+
'mock_aiplatform': True,
2566+
}, ))
2567+
2568+
def test_gemini_handler(self):
2569+
self._assert_handler_cases(({
2570+
'name': 'gemini',
2571+
'module_name': 'apache_beam.ml.inference.gemini_inference',
2572+
'class_name': 'GeminiModelHandler',
2573+
'skip_message': 'Google GenAI SDK not available',
2574+
'init_kwargs': {
2575+
'model_name': 'unused',
2576+
'request_fn': lambda *args: None,
2577+
'project': 'unused',
2578+
'location': 'unused',
2579+
},
2580+
}, ))
2581+
2582+
23222583
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
23232584
def load_model(self):
23242585
return FakeModel()

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(
117117
max_batch_duration_secs: Optional[int] = None,
118118
max_batch_weight: Optional[int] = None,
119119
element_size_fn: Optional[Callable[[Any], int]] = None,
120+
batch_length_fn: Optional[Callable[[Any], int]] = None,
121+
batch_bucket_boundaries: Optional[list[int]] = None,
120122
**kwargs):
121123
"""Implementation of the ModelHandler interface for Google Gemini.
122124
**NOTE:** This API and its implementation are under development and
@@ -158,6 +160,10 @@ def __init__(
158160
max_batch_weight: optional. the maximum total weight of a batch.
159161
element_size_fn: optional. a function that returns the size (weight)
160162
of an element.
163+
batch_length_fn: optional. a callable that returns the length of an
164+
element for length-aware batching.
165+
batch_bucket_boundaries: optional. a sorted list of positive boundary
166+
values for length-aware batching buckets.
161167
"""
162168
self._batching_kwargs = {}
163169
self._env_vars = kwargs.get('env_vars', {})
@@ -171,6 +177,10 @@ def __init__(
171177
self._batching_kwargs["max_batch_weight"] = max_batch_weight
172178
if element_size_fn is not None:
173179
self._batching_kwargs['element_size_fn'] = element_size_fn
180+
if batch_length_fn is not None:
181+
self._batching_kwargs['length_fn'] = batch_length_fn
182+
if batch_bucket_boundaries is not None:
183+
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
174184

175185
self.model_name = model_name
176186
self.request_fn = request_fn

sdks/python/apache_beam/ml/inference/huggingface_inference.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ def __init__(
229229
model_copies: Optional[int] = None,
230230
max_batch_weight: Optional[int] = None,
231231
element_size_fn: Optional[Callable[[Any], int]] = None,
232+
batch_length_fn: Optional[Callable[[Any], int]] = None,
233+
batch_bucket_boundaries: Optional[list[int]] = None,
232234
**kwargs):
233235
"""
234236
Implementation of the ModelHandler interface for HuggingFace with
@@ -266,6 +268,10 @@ def __init__(
266268
GPU capacity and want to maximize resource utilization.
267269
max_batch_weight: the maximum total weight of a batch.
268270
element_size_fn: a function that returns the size (weight) of an element.
271+
batch_length_fn: a callable that returns the length of an element for
272+
length-aware batching.
273+
batch_bucket_boundaries: a sorted list of positive boundary values for
274+
length-aware batching buckets.
269275
kwargs: 'env_vars' can be used to set environment variables
270276
before loading the model.
271277
@@ -278,6 +284,8 @@ def __init__(
278284
max_batch_duration_secs=max_batch_duration_secs,
279285
max_batch_weight=max_batch_weight,
280286
element_size_fn=element_size_fn,
287+
batch_length_fn=batch_length_fn,
288+
batch_bucket_boundaries=batch_bucket_boundaries,
281289
large_model=large_model,
282290
model_copies=model_copies,
283291
**kwargs)
@@ -411,6 +419,8 @@ def __init__(
411419
model_copies: Optional[int] = None,
412420
max_batch_weight: Optional[int] = None,
413421
element_size_fn: Optional[Callable[[Any], int]] = None,
422+
batch_length_fn: Optional[Callable[[Any], int]] = None,
423+
batch_bucket_boundaries: Optional[list[int]] = None,
414424
**kwargs):
415425
"""
416426
Implementation of the ModelHandler interface for HuggingFace with
@@ -448,6 +458,10 @@ def __init__(
448458
GPU capacity and want to maximize resource utilization.
449459
max_batch_weight: the maximum total weight of a batch.
450460
element_size_fn: a function that returns the size (weight) of an element.
461+
batch_length_fn: a callable that returns the length of an element for
462+
length-aware batching.
463+
batch_bucket_boundaries: a sorted list of positive boundary values for
464+
length-aware batching buckets.
451465
kwargs: 'env_vars' can be used to set environment variables
452466
before loading the model.
453467
@@ -460,6 +474,8 @@ def __init__(
460474
max_batch_duration_secs=max_batch_duration_secs,
461475
max_batch_weight=max_batch_weight,
462476
element_size_fn=element_size_fn,
477+
batch_length_fn=batch_length_fn,
478+
batch_bucket_boundaries=batch_bucket_boundaries,
463479
large_model=large_model,
464480
model_copies=model_copies,
465481
**kwargs)
@@ -576,6 +592,8 @@ def __init__(
576592
model_copies: Optional[int] = None,
577593
max_batch_weight: Optional[int] = None,
578594
element_size_fn: Optional[Callable[[Any], int]] = None,
595+
batch_length_fn: Optional[Callable[[Any], int]] = None,
596+
batch_bucket_boundaries: Optional[list[int]] = None,
579597
**kwargs):
580598
"""
581599
Implementation of the ModelHandler interface for Hugging Face Pipelines.
@@ -621,6 +639,10 @@ def __init__(
621639
GPU capacity and want to maximize resource utilization.
622640
max_batch_weight: the maximum total weight of a batch.
623641
element_size_fn: a function that returns the size (weight) of an element.
642+
batch_length_fn: a callable that returns the length of an element for
643+
length-aware batching.
644+
batch_bucket_boundaries: a sorted list of positive boundary values for
645+
length-aware batching buckets.
624646
kwargs: 'env_vars' can be used to set environment variables
625647
before loading the model.
626648
@@ -633,6 +655,8 @@ def __init__(
633655
max_batch_duration_secs=max_batch_duration_secs,
634656
max_batch_weight=max_batch_weight,
635657
element_size_fn=element_size_fn,
658+
batch_length_fn=batch_length_fn,
659+
batch_bucket_boundaries=batch_bucket_boundaries,
636660
large_model=large_model,
637661
model_copies=model_copies,
638662
**kwargs)

0 commit comments

Comments
 (0)