Commit f2860fa

use utils._convert_to_result for huggingface_inference (#36593)
* fix(huggingface_inference): use utils._convert_to_result for batch processing

  The internal _convert_to_result function was incorrectly handling batches with multiple elements by wrapping predictions in a list, which caused all predictions to be grouped into a single result. Replace it with utils._convert_to_result, which processes each element in the batch individually. Added a test case to verify correct batch processing behavior.

* test(huggingface): add batched examples test for tf inference

  Add a test case to verify batch processing with TensorFlow examples in huggingface inference.
1 parent ef07e40 commit f2860fa
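
To make the reported bug concrete, here is a minimal, self-contained sketch (plain Python with tuples standing in for PredictionResult, not the Beam helpers themselves) contrasting the removed zip(batch, [predictions]) pattern with the element-wise pairing the new test expects from utils._convert_to_result:

    # Sketch only: hypothetical stand-in values, mirroring the test added in this commit.
    batch = ["input1", "input2"]
    predictions = [{"translation_text": "output1"}, {"translation_text": "output2"}]

    # Removed local helper: wrapping predictions in a list makes zip stop after
    # one pair, so the entire prediction list is attached to the first example.
    buggy = [(x, y) for x, y in zip(batch, [predictions])]
    assert len(buggy) == 1  # one grouped result instead of two

    # Element-wise pairing, as verified by test_convert_to_result_batch_processing.
    fixed = [(x, y) for x, y in zip(batch, predictions)]
    assert len(fixed) == 2  # one result per batch element
    assert fixed[0] == ("input1", {"translation_text": "output1"})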

File tree: 2 files changed (+29 / -17 lines)

sdks/python/apache_beam/ml/inference/huggingface_inference.py

Lines changed: 1 addition & 11 deletions
@@ -563,16 +563,6 @@ def get_metrics_namespace(self) -> str:
     return 'BeamML_HuggingFaceModelHandler_Tensor'
 
 
-def _convert_to_result(
-    batch: Iterable,
-    predictions: Union[Iterable, dict[Any, Iterable]],
-    model_id: Optional[str] = None,
-) -> Iterable[PredictionResult]:
-  return [
-      PredictionResult(x, y, model_id) for x, y in zip(batch, [predictions])
-  ]
-
-
 def _default_pipeline_inference_fn(
     batch, pipeline, inference_args) -> Iterable[PredictionResult]:
   predicitons = pipeline(batch, **inference_args)
@@ -715,7 +705,7 @@ def run_inference(
     """
     inference_args = {} if not inference_args else inference_args
     predictions = self._inference_fn(batch, pipeline, inference_args)
-    return _convert_to_result(batch, predictions)
+    return utils._convert_to_result(batch, predictions)
 
   def update_model_path(self, model_path: Optional[str] = None):
     """

sdks/python/apache_beam/ml/inference/huggingface_inference_test.py

Lines changed: 28 additions & 6 deletions
@@ -121,12 +121,34 @@ def test_framework_detection_tensorflow(self):
     inference_runner = HuggingFaceModelHandlerTensor(
         model_uri='unused',
         model_class=TFAutoModel,
-        inference_fn=fake_inference_fn_tensor,
-        inference_args={"add": True})
-    batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])]
-    inference_runner.run_inference(
-        batched_examples, fake_model, inference_args={"add": True})
-    self.assertEqual(inference_runner._framework, "tf")
+        inference_fn=fake_inference_fn_tensor)
+    batched_examples = [tf.constant(1), tf.constant(10), tf.constant(100)]
+    inference_runner.run_inference(batched_examples, fake_model)
+    self.assertEqual(inference_runner._framework, 'tf')
+
+  def test_convert_to_result_batch_processing(self):
+    """Test that utils._convert_to_result correctly handles
+    batches with multiple elements."""
+
+    # Test case that reproduces the bug: batch size > 1
+    batch = ["input1", "input2"]
+    predictions = [{
+        "translation_text": "output1"
+    }, {
+        "translation_text": "output2"
+    }]
+
+    results = list(utils._convert_to_result(batch, predictions))
+
+    # Should return 2 results, not 1
+    self.assertEqual(
+        len(results), 2, "Should return one result per batch element")
+
+    # Check that each result has the correct input and output
+    self.assertEqual(results[0].example, "input1")
+    self.assertEqual(results[0].inference, {"translation_text": "output1"})
+    self.assertEqual(results[1].example, "input2")
+    self.assertEqual(results[1].inference, {"translation_text": "output2"})
 
 
 if __name__ == '__main__':
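
For a quick local check (this invocation is an assumption about a typical development setup, not part of the commit), the new test can be run on its own via pytest's keyword filter:

    # Sketch; assumes pytest is installed and the Beam Python SDK sources are on the path.
    import pytest
    pytest.main([
        "sdks/python/apache_beam/ml/inference/huggingface_inference_test.py",
        "-k", "test_convert_to_result_batch_processing",
    ])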
