Skip to content

Commit 722dc6d

Browse files
authored
Allow inference args to be passed in for most cases (#37094)
* Allow inference args to be passed in for most cases
* CHANGES
* tests
* yapf
1 parent 468f7d5 commit 722dc6d

File tree

8 files changed

+24
-24
lines changed

8 files changed

+24
-24
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@
7373
## New Features / Improvements
7474

7575
* Support configuring Firestore database on ReadFn transforms (Java) ([#36904](https://github.com/apache/beam/issues/36904)).
76+
* (Python) Inference args are now allowed in most model handlers, except where they are explicitly/intentionally disallowed ([#37093](https://github.com/apache/beam/issues/37093)).
7677

7778
## Breaking Changes
7879

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -213,15 +213,12 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
213213
return {}
214214

215215
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
216-
"""Validates inference_args passed in the inference call.
217-
218-
Because most frameworks do not need extra arguments in their predict() call,
219-
the default behavior is to error out if inference_args are present.
220216
"""
221-
if inference_args:
222-
raise ValueError(
223-
'inference_args were provided, but should be None because this '
224-
'framework does not expect extra arguments on inferences.')
217+
Allows model handlers to provide some validation to make sure passed in
218+
inference args are valid. Some ModelHandlers throw here to disallow
219+
inference args altogether.
220+
"""
221+
pass
225222

226223
def update_model_path(self, model_path: Optional[str] = None):
227224
"""

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -293,6 +293,12 @@ def run_inference(self, batch, unused_model, inference_args=None):
293293
'run_inference should not be called because error should already be '
294294
'thrown from the validate_inference_args check.')
295295

296+
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
297+
if inference_args:
298+
raise ValueError(
299+
'inference_args were provided, but should be None because this '
300+
'framework does not expect extra arguments on inferences.')
301+
296302

297303
class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
298304
def run_inference(self, batch, unused_model, inference_args=None):

sdks/python/apache_beam/ml/inference/pytorch_inference.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -342,9 +342,6 @@ def get_metrics_namespace(self) -> str:
342342
"""
343343
return 'BeamML_PyTorch'
344344

345-
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
346-
pass
347-
348345
def batch_elements_kwargs(self):
349346
return self._batching_kwargs
350347

@@ -590,9 +587,6 @@ def get_metrics_namespace(self) -> str:
590587
"""
591588
return 'BeamML_PyTorch'
592589

593-
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
594-
pass
595-
596590
def batch_elements_kwargs(self):
597591
return self._batching_kwargs
598592

sdks/python/apache_beam/ml/inference/sklearn_inference.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,9 +73,10 @@ def _default_numpy_inference_fn(
7373
model: BaseEstimator,
7474
batch: Sequence[numpy.ndarray],
7575
inference_args: Optional[dict[str, Any]] = None) -> Any:
76+
inference_args = {} if not inference_args else inference_args
7677
# vectorize data for better performance
7778
vectorized_batch = numpy.stack(batch, axis=0)
78-
return model.predict(vectorized_batch)
79+
return model.predict(vectorized_batch, **inference_args)
7980

8081

8182
class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,

sdks/python/apache_beam/ml/inference/tensorflow_inference.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -219,9 +219,6 @@ def get_metrics_namespace(self) -> str:
219219
"""
220220
return 'BeamML_TF_Numpy'
221221

222-
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
223-
pass
224-
225222
def batch_elements_kwargs(self):
226223
return self._batching_kwargs
227224

@@ -360,9 +357,6 @@ def get_metrics_namespace(self) -> str:
360357
"""
361358
return 'BeamML_TF_Tensor'
362359

363-
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
364-
pass
365-
366360
def batch_elements_kwargs(self):
367361
return self._batching_kwargs
368362

sdks/python/apache_beam/ml/inference/tensorrt_inference.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,3 +341,13 @@ def share_model_across_processes(self) -> bool:
341341

342342
def model_copies(self) -> int:
343343
return self._model_copies
344+
345+
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
346+
"""
347+
Currently, this model handler does not support inference args. Given that,
348+
we will throw if any are passed in.
349+
"""
350+
if inference_args:
351+
raise ValueError(
352+
'inference_args were provided, but should be None because this '
353+
'framework does not expect extra arguments on inferences.')

sdks/python/apache_beam/ml/inference/vertex_ai_inference.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -207,8 +207,5 @@ def request(
207207
return utils._convert_to_result(
208208
batch, prediction.predictions, prediction.deployed_model_id)
209209

210-
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
211-
pass
212-
213210
def batch_elements_kwargs(self) -> Mapping[str, Any]:
214211
return self._batching_kwargs

0 commit comments

Comments (0)