Skip to content

Commit e375b45

Browse files
committed
Enable model manager on the it test
1 parent 6a0bd89 commit e375b45

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

sdks/python/apache_beam/ml/inference/model_manager_it_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def test_sentiment_analysis_on_gpu_large_input(self):
6262

6363
pcoll = pipeline | 'CreateInputs' >> beam.Create(examples)
6464

65-
predictions = pcoll | 'RunInference' >> RunInference(model_handler)
65+
predictions = pcoll | 'RunInference' >> RunInference(
66+
model_handler, use_model_manager=True)
6667

6768
actual_labels = predictions | beam.Map(lambda x: x.inference['label'])
6869

@@ -112,7 +113,8 @@ def test_sentiment_analysis_large_roberta_gpu(self):
112113
] * DUPLICATE_FACTOR
113114

114115
pcoll = pipeline | 'CreateInputs' >> beam.Create(examples)
115-
predictions = pcoll | 'RunInference' >> RunInference(model_handler)
116+
predictions = pcoll | 'RunInference' >> RunInference(
117+
model_handler, use_model_manager=True)
116118
actual_labels = predictions | beam.Map(lambda x: x.inference['label'])
117119

118120
expected_labels = [
@@ -171,10 +173,12 @@ def test_parallel_inference_branches(self):
171173
inputs = pipeline | 'CreateInputs' >> beam.Create(examples)
172174
_ = (
173175
inputs
174-
| 'RunTranslation' >> RunInference(translator_handler)
176+
| 'RunTranslation' >> RunInference(
177+
translator_handler, use_model_manager=True)
175178
| 'ExtractSpanish' >>
176179
beam.Map(lambda x: x.inference['translation_text']))
177180
_ = (
178181
inputs
179-
| 'RunSentiment' >> RunInference(sentiment_handler)
182+
| 'RunSentiment' >> RunInference(
183+
sentiment_handler, use_model_manager=True)
180184
| 'ExtractLabel' >> beam.Map(lambda x: x.inference['label']))

0 commit comments

Comments
 (0)