Skip to content

Commit 5267548

Browse files
committed
[Integration] Expose length-aware batching in all ModelHandler subclasses
Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor. - adds batch_length_fn / batch_bucket_boundaries to 16 handler classes - wires Gemini and Vertex AI batching params into _batching_kwargs - adds end-to-end RunInference coverage for length-aware batching - adds per-handler forwarding regression tests and fixes them to be hermetic
1 parent 90a7db1 commit 5267548

File tree

11 files changed

+408
-6
lines changed

11 files changed

+408
-6
lines changed

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 272 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717

1818
"""Tests for apache_beam.ml.base."""
19+
import contextlib
20+
import importlib
1921
import math
2022
import multiprocessing
2123
import os
@@ -25,16 +27,16 @@
2527
import tempfile
2628
import time
2729
import unittest
30+
import unittest.mock
2831
from collections.abc import Iterable
2932
from collections.abc import Mapping
3033
from collections.abc import Sequence
3134
from typing import Any
3235
from typing import Optional
3336
from typing import Union
3437

35-
import pytest
36-
3738
import apache_beam as beam
39+
import pytest
3840
from apache_beam.examples.inference import run_inference_side_inputs
3941
from apache_beam.metrics.metric import MetricsFilter
4042
from apache_beam.ml.inference import base
@@ -2319,6 +2321,274 @@ def test_batching_kwargs_none_values_omitted(self):
23192321
self.assertEqual(kwargs['min_batch_size'], 5)
23202322

23212323

2324+
class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
2325+
FakeModel]):
2326+
"""Reports each element with the max length of the batch it ran in."""
2327+
def load_model(self):
2328+
return FakeModel()
2329+
2330+
def run_inference(self, batch, model, inference_args=None):
2331+
max_len = max(len(s) for s in batch)
2332+
return [f'{s}:{max_len}' for s in batch]
2333+
2334+
2335+
class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
2336+
"""End-to-end tests for PR2 length-aware batching in RunInference."""
2337+
def test_run_inference_with_length_aware_batch_elements(self):
2338+
handler = PaddingReportingStringModelHandler(
2339+
min_batch_size=2,
2340+
max_batch_size=2,
2341+
max_batch_duration_secs=60,
2342+
batch_length_fn=len,
2343+
batch_bucket_boundaries=[5])
2344+
2345+
examples = ['a', 'cccccc', 'bb', 'ddddddd']
2346+
with TestPipeline('FnApiRunner') as p:
2347+
results = (
2348+
p
2349+
| beam.Create(examples, reshuffle=False)
2350+
| base.RunInference(handler))
2351+
assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))
2352+
2353+
2354+
class HandlerBucketingKwargsForwardingTest(unittest.TestCase):
2355+
"""Verify each concrete ModelHandler forwards batch_length_fn and
2356+
batch_bucket_boundaries through to batch_elements_kwargs()."""
2357+
_BUCKETING_KWARGS = {
2358+
'batch_length_fn': len,
2359+
'batch_bucket_boundaries': [32],
2360+
}
2361+
2362+
def _assert_bucketing_kwargs_forwarded(self, handler):
2363+
kwargs = handler.batch_elements_kwargs()
2364+
self.assertIs(kwargs['length_fn'], len)
2365+
self.assertEqual(kwargs['bucket_boundaries'], [32])
2366+
2367+
def _load_handler_class(self, case):
2368+
try:
2369+
module = importlib.import_module(case['module_name'])
2370+
except ImportError:
2371+
raise unittest.SkipTest(case['skip_message'])
2372+
return getattr(module, case['class_name'])
2373+
2374+
@contextlib.contextmanager
2375+
def _handler_setup(self, case):
2376+
if not case.get('mock_aiplatform'):
2377+
yield
2378+
return
2379+
2380+
with unittest.mock.patch(
2381+
'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip:
2382+
mock_aip.init.return_value = None
2383+
mock_endpoint = unittest.mock.MagicMock()
2384+
mock_endpoint.list_models.return_value = ['fake-model']
2385+
mock_aip.Endpoint.return_value = mock_endpoint
2386+
yield
2387+
2388+
def _assert_handler_cases(self, cases):
2389+
for case in cases:
2390+
with self.subTest(handler=case['name']):
2391+
handler_cls = self._load_handler_class(case)
2392+
init_kwargs = dict(case['init_kwargs'])
2393+
init_kwargs.update(self._BUCKETING_KWARGS)
2394+
2395+
with self._handler_setup(case):
2396+
handler = handler_cls(**init_kwargs)
2397+
2398+
self._assert_bucketing_kwargs_forwarded(handler)
2399+
2400+
def test_pytorch_handlers(self):
2401+
self._assert_handler_cases((
2402+
{
2403+
'name': 'pytorch_tensor',
2404+
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2405+
'class_name': 'PytorchModelHandlerTensor',
2406+
'skip_message': 'PyTorch not available',
2407+
'init_kwargs': {},
2408+
},
2409+
{
2410+
'name': 'pytorch_keyed_tensor',
2411+
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2412+
'class_name': 'PytorchModelHandlerKeyedTensor',
2413+
'skip_message': 'PyTorch not available',
2414+
'init_kwargs': {},
2415+
},
2416+
))
2417+
2418+
def test_huggingface_handlers(self):
2419+
self._assert_handler_cases((
2420+
{
2421+
'name': 'huggingface_keyed_tensor',
2422+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2423+
'class_name': 'HuggingFaceModelHandlerKeyedTensor',
2424+
'skip_message': 'HuggingFace transformers not available',
2425+
'init_kwargs': {
2426+
'model_uri': 'unused',
2427+
'model_class': object,
2428+
'framework': 'pt',
2429+
},
2430+
},
2431+
{
2432+
'name': 'huggingface_tensor',
2433+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2434+
'class_name': 'HuggingFaceModelHandlerTensor',
2435+
'skip_message': 'HuggingFace transformers not available',
2436+
'init_kwargs': {
2437+
'model_uri': 'unused',
2438+
'model_class': object,
2439+
},
2440+
},
2441+
{
2442+
'name': 'huggingface_pipeline',
2443+
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2444+
'class_name': 'HuggingFacePipelineModelHandler',
2445+
'skip_message': 'HuggingFace transformers not available',
2446+
'init_kwargs': {
2447+
'task': 'text-classification',
2448+
},
2449+
},
2450+
))
2451+
2452+
def test_sklearn_handlers(self):
2453+
self._assert_handler_cases((
2454+
{
2455+
'name': 'sklearn_numpy',
2456+
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2457+
'class_name': 'SklearnModelHandlerNumpy',
2458+
'skip_message': 'scikit-learn not available',
2459+
'init_kwargs': {
2460+
'model_uri': 'unused',
2461+
},
2462+
},
2463+
{
2464+
'name': 'sklearn_pandas',
2465+
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2466+
'class_name': 'SklearnModelHandlerPandas',
2467+
'skip_message': 'scikit-learn not available',
2468+
'init_kwargs': {
2469+
'model_uri': 'unused',
2470+
},
2471+
},
2472+
))
2473+
2474+
def test_tensorflow_handlers(self):
2475+
self._assert_handler_cases((
2476+
{
2477+
'name': 'tensorflow_numpy',
2478+
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2479+
'class_name': 'TFModelHandlerNumpy',
2480+
'skip_message': 'TensorFlow not available',
2481+
'init_kwargs': {
2482+
'model_uri': 'unused',
2483+
},
2484+
},
2485+
{
2486+
'name': 'tensorflow_tensor',
2487+
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2488+
'class_name': 'TFModelHandlerTensor',
2489+
'skip_message': 'TensorFlow not available',
2490+
'init_kwargs': {
2491+
'model_uri': 'unused',
2492+
},
2493+
},
2494+
))
2495+
2496+
def test_onnx_handler(self):
2497+
self._assert_handler_cases((
2498+
{
2499+
'name': 'onnx_numpy',
2500+
'module_name': 'apache_beam.ml.inference.onnx_inference',
2501+
'class_name': 'OnnxModelHandlerNumpy',
2502+
'skip_message': 'ONNX Runtime not available',
2503+
'init_kwargs': {
2504+
'model_uri': 'unused',
2505+
},
2506+
},
2507+
))
2508+
2509+
def test_xgboost_handler(self):
2510+
self._assert_handler_cases((
2511+
{
2512+
'name': 'xgboost_numpy',
2513+
'module_name': 'apache_beam.ml.inference.xgboost_inference',
2514+
'class_name': 'XGBoostModelHandlerNumpy',
2515+
'skip_message': 'XGBoost not available',
2516+
'init_kwargs': {
2517+
'model_class': object,
2518+
'model_state': 'unused',
2519+
},
2520+
},
2521+
))
2522+
2523+
def test_tensorrt_handler(self):
2524+
self._assert_handler_cases((
2525+
{
2526+
'name': 'tensorrt_numpy',
2527+
'module_name': 'apache_beam.ml.inference.tensorrt_inference',
2528+
'class_name': 'TensorRTEngineHandlerNumPy',
2529+
'skip_message': 'TensorRT not available',
2530+
'init_kwargs': {
2531+
'min_batch_size': 1,
2532+
'max_batch_size': 8,
2533+
},
2534+
},
2535+
))
2536+
2537+
def test_vllm_handlers(self):
2538+
self._assert_handler_cases((
2539+
{
2540+
'name': 'vllm_completions',
2541+
'module_name': 'apache_beam.ml.inference.vllm_inference',
2542+
'class_name': 'VLLMCompletionsModelHandler',
2543+
'skip_message': 'vLLM not available',
2544+
'init_kwargs': {
2545+
'model_name': 'unused',
2546+
},
2547+
},
2548+
{
2549+
'name': 'vllm_chat',
2550+
'module_name': 'apache_beam.ml.inference.vllm_inference',
2551+
'class_name': 'VLLMChatModelHandler',
2552+
'skip_message': 'vLLM not available',
2553+
'init_kwargs': {
2554+
'model_name': 'unused',
2555+
},
2556+
},
2557+
))
2558+
2559+
def test_vertex_ai_handler(self):
2560+
self._assert_handler_cases((
2561+
{
2562+
'name': 'vertex_ai',
2563+
'module_name': 'apache_beam.ml.inference.vertex_ai_inference',
2564+
'class_name': 'VertexAIModelHandlerJSON',
2565+
'skip_message': 'Vertex AI SDK not available',
2566+
'init_kwargs': {
2567+
'endpoint_id': 'unused',
2568+
'project': 'unused',
2569+
'location': 'unused',
2570+
},
2571+
'mock_aiplatform': True,
2572+
},
2573+
))
2574+
2575+
def test_gemini_handler(self):
2576+
self._assert_handler_cases((
2577+
{
2578+
'name': 'gemini',
2579+
'module_name': 'apache_beam.ml.inference.gemini_inference',
2580+
'class_name': 'GeminiModelHandler',
2581+
'skip_message': 'Google GenAI SDK not available',
2582+
'init_kwargs': {
2583+
'model_name': 'unused',
2584+
'request_fn': lambda *args: None,
2585+
'project': 'unused',
2586+
'location': 'unused',
2587+
},
2588+
},
2589+
))
2590+
2591+
23222592
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
23232593
def load_model(self):
23242594
return FakeModel()

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(
117117
max_batch_duration_secs: Optional[int] = None,
118118
max_batch_weight: Optional[int] = None,
119119
element_size_fn: Optional[Callable[[Any], int]] = None,
120+
batch_length_fn: Optional[Callable[[Any], int]] = None,
121+
batch_bucket_boundaries: Optional[list[int]] = None,
120122
**kwargs):
121123
"""Implementation of the ModelHandler interface for Google Gemini.
122124
**NOTE:** This API and its implementation are under development and
@@ -158,6 +160,10 @@ def __init__(
158160
max_batch_weight: optional. the maximum total weight of a batch.
159161
element_size_fn: optional. a function that returns the size (weight)
160162
of an element.
163+
batch_length_fn: optional. a callable that returns the length of an
164+
element for length-aware batching.
165+
batch_bucket_boundaries: optional. a sorted list of positive boundary
166+
values for length-aware batching buckets.
161167
"""
162168
self._batching_kwargs = {}
163169
self._env_vars = kwargs.get('env_vars', {})
@@ -171,6 +177,10 @@ def __init__(
171177
self._batching_kwargs["max_batch_weight"] = max_batch_weight
172178
if element_size_fn is not None:
173179
self._batching_kwargs['element_size_fn'] = element_size_fn
180+
if batch_length_fn is not None:
181+
self._batching_kwargs['length_fn'] = batch_length_fn
182+
if batch_bucket_boundaries is not None:
183+
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
174184

175185
self.model_name = model_name
176186
self.request_fn = request_fn

0 commit comments

Comments
 (0)