Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import tempfile
import time
import unittest
import unittest.mock
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
Expand Down Expand Up @@ -2319,6 +2320,36 @@ def test_batching_kwargs_none_values_omitted(self):
self.assertEqual(kwargs['min_batch_size'], 5)


class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
                                                           FakeModel]):
  """Reports each element with the max length of the batch it ran in.

  Used to observe length-aware batching end-to-end: because every element
  in a batch is tagged with the batch's maximum string length, a test can
  infer which elements were grouped together by the batching logic.
  """
  def load_model(self):
    """Returns a fresh FakeModel instance."""
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    """Tags each input string with the longest length in its batch.

    Args:
      batch: a non-empty sequence of input strings.
      model: the loaded model (unused by this fake handler).
      inference_args: ignored; present to satisfy the ModelHandler API.

    Returns:
      A list of '<element>:<max_len>' strings, one per input, in order.
    """
    # Fix: interpolate the computed max_len (the original returned a
    # garbled hard-coded literal, leaving max_len unused and breaking the
    # batching assertions that depend on it).
    max_len = max(len(s) for s in batch)
    return [f'{s}:{max_len}' for s in batch]


class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
  """End-to-end coverage for length-aware batching in RunInference."""
  def test_run_inference_with_length_aware_batch_elements(self):
    # Boundary 5 splits inputs into two buckets (len < 5 and len >= 5);
    # with min/max batch size 2 each bucket emits one batch of two.
    model_handler = PaddingReportingStringModelHandler(
        min_batch_size=2,
        max_batch_size=2,
        max_batch_duration_secs=60,
        batch_length_fn=len,
        batch_bucket_boundaries=[5])
    inputs = ['a', 'cccccc', 'bb', 'ddddddd']
    with TestPipeline('FnApiRunner') as pipeline:
      source = pipeline | beam.Create(inputs, reshuffle=False)
      predictions = source | base.RunInference(model_handler)
      # Short strings batch together (max len 2); long ones too (max len 7).
      assert_that(
          predictions, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))


class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def load_model(self):
return FakeModel()
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Google Gemini.
**NOTE:** This API and its implementation are under development and
Expand Down Expand Up @@ -158,6 +160,10 @@ def __init__(
max_batch_weight: optional. the maximum total weight of a batch.
element_size_fn: optional. a function that returns the size (weight)
of an element.
batch_length_fn: optional. a callable that returns the length of an
element for length-aware batching.
batch_bucket_boundaries: optional. a sorted list of positive boundary
values for length-aware batching buckets.
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
Expand All @@ -171,6 +177,10 @@ def __init__(
self._batching_kwargs["max_batch_weight"] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
if batch_length_fn is not None:
self._batching_kwargs['length_fn'] = batch_length_fn
if batch_bucket_boundaries is not None:
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries

self.model_name = model_name
self.request_fn = request_fn
Expand Down
24 changes: 24 additions & 0 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
Expand Down Expand Up @@ -266,6 +268,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -278,6 +284,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -411,6 +419,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
Expand Down Expand Up @@ -448,6 +458,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -460,6 +474,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -576,6 +592,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for Hugging Face Pipelines.
Expand Down Expand Up @@ -621,6 +639,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -633,6 +655,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/ml/inference/onnx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__( #pylint: disable=dangerous-default-value
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
""" Implementation of the ModelHandler interface for onnx
using numpy arrays as input.
Expand All @@ -94,6 +96,10 @@ def __init__( #pylint: disable=dangerous-default-value
before emitting; used in streaming contexts.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
"""
Expand All @@ -103,6 +109,8 @@ def __init__( #pylint: disable=dangerous-default-value
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ def __init__(
load_model_args: Optional[dict[str, Any]] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for PyTorch.

Expand Down Expand Up @@ -244,6 +246,10 @@ def __init__(
function to specify custom config for loading models.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -256,6 +262,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -431,6 +439,8 @@ def __init__(
load_model_args: Optional[dict[str, Any]] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for PyTorch.

Expand Down Expand Up @@ -481,6 +491,10 @@ def __init__(
function to specify custom config for loading models.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -493,6 +507,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
""" Implementation of the ModelHandler interface for scikit-learn
using numpy arrays as input.
Expand Down Expand Up @@ -126,6 +128,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
"""
Expand All @@ -135,6 +141,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -224,6 +232,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for scikit-learn that
supports pandas dataframes.
Expand Down Expand Up @@ -258,6 +268,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
"""
Expand All @@ -267,6 +281,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/ml/inference/tensorflow_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Tensorflow.

Expand Down Expand Up @@ -145,6 +147,10 @@ def __init__(
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an
element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -157,6 +163,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -242,6 +250,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Tensorflow.

Expand Down Expand Up @@ -278,6 +288,10 @@ def __init__(
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an
element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -290,6 +304,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/ml/inference/tensorrt_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for TensorRT.

Expand Down Expand Up @@ -262,6 +264,10 @@ def __init__(
a batch before emitting; used in streaming contexts.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
currently supported. 'env_vars' can be used to set environment variables
before loading the model.
Expand All @@ -275,6 +281,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
Loading
Loading