Skip to content

Commit f4f422c

Browse files
authored
[Integration] Expose length-aware batching in all ModelHandler subclasses (#37945)
* [Integration] Expose length-aware batching in all ModelHandler subclasses

Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor.

- adds batch_length_fn / batch_bucket_boundaries to 16 handler classes
- wires Gemini and Vertex AI batching params into _batching_kwargs
- adds end-to-end RunInference coverage for length-aware batching
- adds per-handler forwarding regression tests and fixes them to be hermetic

* Remove redundant bucketing forwarding tests
1 parent 0b6a93f commit f4f422c

File tree

11 files changed

+167
-4
lines changed

11 files changed

+167
-4
lines changed

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import tempfile
2626
import time
2727
import unittest
28+
import unittest.mock
2829
from collections.abc import Iterable
2930
from collections.abc import Mapping
3031
from collections.abc import Sequence
@@ -2317,6 +2318,36 @@ def test_batching_kwargs_none_values_omitted(self):
23172318
self.assertEqual(kwargs['min_batch_size'], 5)
23182319

23192320

2321+
class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
2322+
FakeModel]):
2323+
"""Reports each element with the max length of the batch it ran in."""
2324+
def load_model(self):
2325+
return FakeModel()
2326+
2327+
def run_inference(self, batch, model, inference_args=None):
2328+
max_len = max(len(s) for s in batch)
2329+
return [f'{s}:{max_len}' for s in batch]
2330+
2331+
2332+
class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
2333+
"""End-to-end tests for PR2 length-aware batching in RunInference."""
2334+
def test_run_inference_with_length_aware_batch_elements(self):
2335+
handler = PaddingReportingStringModelHandler(
2336+
min_batch_size=2,
2337+
max_batch_size=2,
2338+
max_batch_duration_secs=60,
2339+
batch_length_fn=len,
2340+
batch_bucket_boundaries=[5])
2341+
2342+
examples = ['a', 'cccccc', 'bb', 'ddddddd']
2343+
with TestPipeline('FnApiRunner') as p:
2344+
results = (
2345+
p
2346+
| beam.Create(examples, reshuffle=False)
2347+
| base.RunInference(handler))
2348+
assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))
2349+
2350+
23202351
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
23212352
def load_model(self):
23222353
return FakeModel()

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(
117117
max_batch_duration_secs: Optional[int] = None,
118118
max_batch_weight: Optional[int] = None,
119119
element_size_fn: Optional[Callable[[Any], int]] = None,
120+
batch_length_fn: Optional[Callable[[Any], int]] = None,
121+
batch_bucket_boundaries: Optional[list[int]] = None,
120122
**kwargs):
121123
"""Implementation of the ModelHandler interface for Google Gemini.
122124
**NOTE:** This API and its implementation are under development and
@@ -158,6 +160,10 @@ def __init__(
158160
max_batch_weight: optional. the maximum total weight of a batch.
159161
element_size_fn: optional. a function that returns the size (weight)
160162
of an element.
163+
batch_length_fn: optional. a callable that returns the length of an
164+
element for length-aware batching.
165+
batch_bucket_boundaries: optional. a sorted list of positive boundary
166+
values for length-aware batching buckets.
161167
"""
162168
self._batching_kwargs = {}
163169
self._env_vars = kwargs.get('env_vars', {})
@@ -171,6 +177,10 @@ def __init__(
171177
self._batching_kwargs["max_batch_weight"] = max_batch_weight
172178
if element_size_fn is not None:
173179
self._batching_kwargs['element_size_fn'] = element_size_fn
180+
if batch_length_fn is not None:
181+
self._batching_kwargs['length_fn'] = batch_length_fn
182+
if batch_bucket_boundaries is not None:
183+
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
174184

175185
self.model_name = model_name
176186
self.request_fn = request_fn

sdks/python/apache_beam/ml/inference/huggingface_inference.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ def __init__(
229229
model_copies: Optional[int] = None,
230230
max_batch_weight: Optional[int] = None,
231231
element_size_fn: Optional[Callable[[Any], int]] = None,
232+
batch_length_fn: Optional[Callable[[Any], int]] = None,
233+
batch_bucket_boundaries: Optional[list[int]] = None,
232234
**kwargs):
233235
"""
234236
Implementation of the ModelHandler interface for HuggingFace with
@@ -266,6 +268,10 @@ def __init__(
266268
GPU capacity and want to maximize resource utilization.
267269
max_batch_weight: the maximum total weight of a batch.
268270
element_size_fn: a function that returns the size (weight) of an element.
271+
batch_length_fn: a callable that returns the length of an element for
272+
length-aware batching.
273+
batch_bucket_boundaries: a sorted list of positive boundary values for
274+
length-aware batching buckets.
269275
kwargs: 'env_vars' can be used to set environment variables
270276
before loading the model.
271277
@@ -278,6 +284,8 @@ def __init__(
278284
max_batch_duration_secs=max_batch_duration_secs,
279285
max_batch_weight=max_batch_weight,
280286
element_size_fn=element_size_fn,
287+
batch_length_fn=batch_length_fn,
288+
batch_bucket_boundaries=batch_bucket_boundaries,
281289
large_model=large_model,
282290
model_copies=model_copies,
283291
**kwargs)
@@ -411,6 +419,8 @@ def __init__(
411419
model_copies: Optional[int] = None,
412420
max_batch_weight: Optional[int] = None,
413421
element_size_fn: Optional[Callable[[Any], int]] = None,
422+
batch_length_fn: Optional[Callable[[Any], int]] = None,
423+
batch_bucket_boundaries: Optional[list[int]] = None,
414424
**kwargs):
415425
"""
416426
Implementation of the ModelHandler interface for HuggingFace with
@@ -448,6 +458,10 @@ def __init__(
448458
GPU capacity and want to maximize resource utilization.
449459
max_batch_weight: the maximum total weight of a batch.
450460
element_size_fn: a function that returns the size (weight) of an element.
461+
batch_length_fn: a callable that returns the length of an element for
462+
length-aware batching.
463+
batch_bucket_boundaries: a sorted list of positive boundary values for
464+
length-aware batching buckets.
451465
kwargs: 'env_vars' can be used to set environment variables
452466
before loading the model.
453467
@@ -460,6 +474,8 @@ def __init__(
460474
max_batch_duration_secs=max_batch_duration_secs,
461475
max_batch_weight=max_batch_weight,
462476
element_size_fn=element_size_fn,
477+
batch_length_fn=batch_length_fn,
478+
batch_bucket_boundaries=batch_bucket_boundaries,
463479
large_model=large_model,
464480
model_copies=model_copies,
465481
**kwargs)
@@ -576,6 +592,8 @@ def __init__(
576592
model_copies: Optional[int] = None,
577593
max_batch_weight: Optional[int] = None,
578594
element_size_fn: Optional[Callable[[Any], int]] = None,
595+
batch_length_fn: Optional[Callable[[Any], int]] = None,
596+
batch_bucket_boundaries: Optional[list[int]] = None,
579597
**kwargs):
580598
"""
581599
Implementation of the ModelHandler interface for Hugging Face Pipelines.
@@ -621,6 +639,10 @@ def __init__(
621639
GPU capacity and want to maximize resource utilization.
622640
max_batch_weight: the maximum total weight of a batch.
623641
element_size_fn: a function that returns the size (weight) of an element.
642+
batch_length_fn: a callable that returns the length of an element for
643+
length-aware batching.
644+
batch_bucket_boundaries: a sorted list of positive boundary values for
645+
length-aware batching buckets.
624646
kwargs: 'env_vars' can be used to set environment variables
625647
before loading the model.
626648
@@ -633,6 +655,8 @@ def __init__(
633655
max_batch_duration_secs=max_batch_duration_secs,
634656
max_batch_weight=max_batch_weight,
635657
element_size_fn=element_size_fn,
658+
batch_length_fn=batch_length_fn,
659+
batch_bucket_boundaries=batch_bucket_boundaries,
636660
large_model=large_model,
637661
model_copies=model_copies,
638662
**kwargs)

sdks/python/apache_beam/ml/inference/onnx_inference.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__( #pylint: disable=dangerous-default-value
6868
max_batch_duration_secs: Optional[int] = None,
6969
max_batch_weight: Optional[int] = None,
7070
element_size_fn: Optional[Callable[[Any], int]] = None,
71+
batch_length_fn: Optional[Callable[[Any], int]] = None,
72+
batch_bucket_boundaries: Optional[list[int]] = None,
7173
**kwargs):
7274
""" Implementation of the ModelHandler interface for onnx
7375
using numpy arrays as input.
@@ -94,6 +96,10 @@ def __init__( #pylint: disable=dangerous-default-value
9496
before emitting; used in streaming contexts.
9597
max_batch_weight: the maximum total weight of a batch.
9698
element_size_fn: a function that returns the size (weight) of an element.
99+
batch_length_fn: a callable that returns the length of an element for
100+
length-aware batching.
101+
batch_bucket_boundaries: a sorted list of positive boundary values for
102+
length-aware batching buckets.
97103
kwargs: 'env_vars' can be used to set environment variables
98104
before loading the model.
99105
"""
@@ -103,6 +109,8 @@ def __init__( #pylint: disable=dangerous-default-value
103109
max_batch_duration_secs=max_batch_duration_secs,
104110
max_batch_weight=max_batch_weight,
105111
element_size_fn=element_size_fn,
112+
batch_length_fn=batch_length_fn,
113+
batch_bucket_boundaries=batch_bucket_boundaries,
106114
large_model=large_model,
107115
model_copies=model_copies,
108116
**kwargs)

sdks/python/apache_beam/ml/inference/pytorch_inference.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def __init__(
199199
load_model_args: Optional[dict[str, Any]] = None,
200200
max_batch_weight: Optional[int] = None,
201201
element_size_fn: Optional[Callable[[Any], int]] = None,
202+
batch_length_fn: Optional[Callable[[Any], int]] = None,
203+
batch_bucket_boundaries: Optional[list[int]] = None,
202204
**kwargs):
203205
"""Implementation of the ModelHandler interface for PyTorch.
204206
@@ -244,6 +246,10 @@ def __init__(
244246
function to specify custom config for loading models.
245247
max_batch_weight: the maximum total weight of a batch.
246248
element_size_fn: a function that returns the size (weight) of an element.
249+
batch_length_fn: a callable that returns the length of an element for
250+
length-aware batching.
251+
batch_bucket_boundaries: a sorted list of positive boundary values for
252+
length-aware batching buckets.
247253
kwargs: 'env_vars' can be used to set environment variables
248254
before loading the model.
249255
@@ -256,6 +262,8 @@ def __init__(
256262
max_batch_duration_secs=max_batch_duration_secs,
257263
max_batch_weight=max_batch_weight,
258264
element_size_fn=element_size_fn,
265+
batch_length_fn=batch_length_fn,
266+
batch_bucket_boundaries=batch_bucket_boundaries,
259267
large_model=large_model,
260268
model_copies=model_copies,
261269
**kwargs)
@@ -431,6 +439,8 @@ def __init__(
431439
load_model_args: Optional[dict[str, Any]] = None,
432440
max_batch_weight: Optional[int] = None,
433441
element_size_fn: Optional[Callable[[Any], int]] = None,
442+
batch_length_fn: Optional[Callable[[Any], int]] = None,
443+
batch_bucket_boundaries: Optional[list[int]] = None,
434444
**kwargs):
435445
"""Implementation of the ModelHandler interface for PyTorch.
436446
@@ -481,6 +491,10 @@ def __init__(
481491
function to specify custom config for loading models.
482492
max_batch_weight: the maximum total weight of a batch.
483493
element_size_fn: a function that returns the size (weight) of an element.
494+
batch_length_fn: a callable that returns the length of an element for
495+
length-aware batching.
496+
batch_bucket_boundaries: a sorted list of positive boundary values for
497+
length-aware batching buckets.
484498
kwargs: 'env_vars' can be used to set environment variables
485499
before loading the model.
486500
@@ -493,6 +507,8 @@ def __init__(
493507
max_batch_duration_secs=max_batch_duration_secs,
494508
max_batch_weight=max_batch_weight,
495509
element_size_fn=element_size_fn,
510+
batch_length_fn=batch_length_fn,
511+
batch_bucket_boundaries=batch_bucket_boundaries,
496512
large_model=large_model,
497513
model_copies=model_copies,
498514
**kwargs)

sdks/python/apache_beam/ml/inference/sklearn_inference.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(
9595
model_copies: Optional[int] = None,
9696
max_batch_weight: Optional[int] = None,
9797
element_size_fn: Optional[Callable[[Any], int]] = None,
98+
batch_length_fn: Optional[Callable[[Any], int]] = None,
99+
batch_bucket_boundaries: Optional[list[int]] = None,
98100
**kwargs):
99101
""" Implementation of the ModelHandler interface for scikit-learn
100102
using numpy arrays as input.
@@ -126,6 +128,10 @@ def __init__(
126128
GPU capacity and want to maximize resource utilization.
127129
max_batch_weight: the maximum total weight of a batch.
128130
element_size_fn: a function that returns the size (weight) of an element.
131+
batch_length_fn: a callable that returns the length of an element for
132+
length-aware batching.
133+
batch_bucket_boundaries: a sorted list of positive boundary values for
134+
length-aware batching buckets.
129135
kwargs: 'env_vars' can be used to set environment variables
130136
before loading the model.
131137
"""
@@ -135,6 +141,8 @@ def __init__(
135141
max_batch_duration_secs=max_batch_duration_secs,
136142
max_batch_weight=max_batch_weight,
137143
element_size_fn=element_size_fn,
144+
batch_length_fn=batch_length_fn,
145+
batch_bucket_boundaries=batch_bucket_boundaries,
138146
large_model=large_model,
139147
model_copies=model_copies,
140148
**kwargs)
@@ -224,6 +232,8 @@ def __init__(
224232
model_copies: Optional[int] = None,
225233
max_batch_weight: Optional[int] = None,
226234
element_size_fn: Optional[Callable[[Any], int]] = None,
235+
batch_length_fn: Optional[Callable[[Any], int]] = None,
236+
batch_bucket_boundaries: Optional[list[int]] = None,
227237
**kwargs):
228238
"""Implementation of the ModelHandler interface for scikit-learn that
229239
supports pandas dataframes.
@@ -258,6 +268,10 @@ def __init__(
258268
GPU capacity and want to maximize resource utilization.
259269
max_batch_weight: the maximum total weight of a batch.
260270
element_size_fn: a function that returns the size (weight) of an element.
271+
batch_length_fn: a callable that returns the length of an element for
272+
length-aware batching.
273+
batch_bucket_boundaries: a sorted list of positive boundary values for
274+
length-aware batching buckets.
261275
kwargs: 'env_vars' can be used to set environment variables
262276
before loading the model.
263277
"""
@@ -267,6 +281,8 @@ def __init__(
267281
max_batch_duration_secs=max_batch_duration_secs,
268282
max_batch_weight=max_batch_weight,
269283
element_size_fn=element_size_fn,
284+
batch_length_fn=batch_length_fn,
285+
batch_bucket_boundaries=batch_bucket_boundaries,
270286
large_model=large_model,
271287
model_copies=model_copies,
272288
**kwargs)

sdks/python/apache_beam/ml/inference/tensorflow_inference.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __init__(
114114
model_copies: Optional[int] = None,
115115
max_batch_weight: Optional[int] = None,
116116
element_size_fn: Optional[Callable[[Any], int]] = None,
117+
batch_length_fn: Optional[Callable[[Any], int]] = None,
118+
batch_bucket_boundaries: Optional[list[int]] = None,
117119
**kwargs):
118120
"""Implementation of the ModelHandler interface for Tensorflow.
119121
@@ -145,6 +147,10 @@ def __init__(
145147
max_batch_weight: the maximum total weight of a batch.
146148
element_size_fn: a function that returns the size (weight) of an
147149
element.
150+
batch_length_fn: a callable that returns the length of an element for
151+
length-aware batching.
152+
batch_bucket_boundaries: a sorted list of positive boundary values for
153+
length-aware batching buckets.
148154
kwargs: 'env_vars' can be used to set environment variables
149155
before loading the model.
150156
@@ -157,6 +163,8 @@ def __init__(
157163
max_batch_duration_secs=max_batch_duration_secs,
158164
max_batch_weight=max_batch_weight,
159165
element_size_fn=element_size_fn,
166+
batch_length_fn=batch_length_fn,
167+
batch_bucket_boundaries=batch_bucket_boundaries,
160168
large_model=large_model,
161169
model_copies=model_copies,
162170
**kwargs)
@@ -242,6 +250,8 @@ def __init__(
242250
model_copies: Optional[int] = None,
243251
max_batch_weight: Optional[int] = None,
244252
element_size_fn: Optional[Callable[[Any], int]] = None,
253+
batch_length_fn: Optional[Callable[[Any], int]] = None,
254+
batch_bucket_boundaries: Optional[list[int]] = None,
245255
**kwargs):
246256
"""Implementation of the ModelHandler interface for Tensorflow.
247257
@@ -278,6 +288,10 @@ def __init__(
278288
max_batch_weight: the maximum total weight of a batch.
279289
element_size_fn: a function that returns the size (weight) of an
280290
element.
291+
batch_length_fn: a callable that returns the length of an element for
292+
length-aware batching.
293+
batch_bucket_boundaries: a sorted list of positive boundary values for
294+
length-aware batching buckets.
281295
kwargs: 'env_vars' can be used to set environment variables
282296
before loading the model.
283297
@@ -290,6 +304,8 @@ def __init__(
290304
max_batch_duration_secs=max_batch_duration_secs,
291305
max_batch_weight=max_batch_weight,
292306
element_size_fn=element_size_fn,
307+
batch_length_fn=batch_length_fn,
308+
batch_bucket_boundaries=batch_bucket_boundaries,
293309
large_model=large_model,
294310
model_copies=model_copies,
295311
**kwargs)

sdks/python/apache_beam/ml/inference/tensorrt_inference.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def __init__(
232232
max_batch_duration_secs: Optional[int] = None,
233233
max_batch_weight: Optional[int] = None,
234234
element_size_fn: Optional[Callable[[Any], int]] = None,
235+
batch_length_fn: Optional[Callable[[Any], int]] = None,
236+
batch_bucket_boundaries: Optional[list[int]] = None,
235237
**kwargs):
236238
"""Implementation of the ModelHandler interface for TensorRT.
237239
@@ -262,6 +264,10 @@ def __init__(
262264
a batch before emitting; used in streaming contexts.
263265
max_batch_weight: the maximum total weight of a batch.
264266
element_size_fn: a function that returns the size (weight) of an element.
267+
batch_length_fn: a callable that returns the length of an element for
268+
length-aware batching.
269+
batch_bucket_boundaries: a sorted list of positive boundary values for
270+
length-aware batching buckets.
265271
kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
266272
currently supported. 'env_vars' can be used to set environment variables
267273
before loading the model.
@@ -275,6 +281,8 @@ def __init__(
275281
max_batch_duration_secs=max_batch_duration_secs,
276282
max_batch_weight=max_batch_weight,
277283
element_size_fn=element_size_fn,
284+
batch_length_fn=batch_length_fn,
285+
batch_bucket_boundaries=batch_bucket_boundaries,
278286
large_model=large_model,
279287
model_copies=model_copies,
280288
**kwargs)

0 commit comments

Comments
 (0)