|
16 | 16 | # |
17 | 17 |
|
18 | 18 | """Tests for apache_beam.ml.base.""" |
| 19 | +import contextlib |
| 20 | +import importlib |
19 | 21 | import math |
20 | 22 | import multiprocessing |
21 | 23 | import os |
|
25 | 27 | import tempfile |
26 | 28 | import time |
27 | 29 | import unittest |
| 30 | +import unittest.mock |
28 | 31 | from collections.abc import Iterable |
29 | 32 | from collections.abc import Mapping |
30 | 33 | from collections.abc import Sequence |
@@ -2319,6 +2322,264 @@ def test_batching_kwargs_none_values_omitted(self): |
2319 | 2322 | self.assertEqual(kwargs['min_batch_size'], 5) |
2320 | 2323 |
|
2321 | 2324 |
|
class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
                                                           FakeModel]):
  """Reports each element with the max length of the batch it ran in.

  Used to observe length-aware batching: every prediction is the input
  string tagged with the length of the longest string in its batch, so a
  test can tell which bucket each element was grouped into.
  """
  def load_model(self):
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    # Tag each element with the longest element length in this batch.
    # (Previously the computed max_len was unused and a corrupted literal
    # '21,938' was emitted instead.)
    max_len = max(len(s) for s in batch)
    return [f'{s}:{max_len}' for s in batch]
| 2335 | + |
class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
  """End-to-end coverage for PR2 length-aware batching in RunInference."""
  def test_run_inference_with_length_aware_batch_elements(self):
    # A bucket boundary at 5 separates the short strings ('a', 'bb') from
    # the long ones ('cccccc', 'ddddddd'); with batch size fixed at 2 each
    # bucket is batched on its own, so the reported max lengths are 2 and 7.
    handler = PaddingReportingStringModelHandler(
        max_batch_duration_secs=60,
        min_batch_size=2,
        max_batch_size=2,
        batch_bucket_boundaries=[5],
        batch_length_fn=len)

    inputs = ['a', 'cccccc', 'bb', 'ddddddd']
    with TestPipeline('FnApiRunner') as pipeline:
      predictions = (
          pipeline
          | beam.Create(inputs, reshuffle=False)
          | base.RunInference(handler))
      assert_that(
          predictions, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))
| 2353 | + |
| 2354 | + |
class HandlerBucketingKwargsForwardingTest(unittest.TestCase):
  """Verify each concrete ModelHandler forwards batch_length_fn and
  batch_bucket_boundaries through to batch_elements_kwargs()."""
  # Bucketing options merged into every handler's constructor kwargs.
  _BUCKETING_KWARGS = {
      'batch_length_fn': len,
      'batch_bucket_boundaries': [32],
  }

  def _assert_bucketing_kwargs_forwarded(self, handler):
    # The constructor options must surface in batch_elements_kwargs()
    # under the keys 'length_fn' / 'bucket_boundaries'.
    kwargs = handler.batch_elements_kwargs()
    self.assertIs(kwargs['length_fn'], len)
    self.assertEqual(kwargs['bucket_boundaries'], [32])

  def _load_handler_class(self, case):
    # Import lazily so a missing optional framework skips the subtest
    # instead of failing the whole module at import time.
    try:
      module = importlib.import_module(case['module_name'])
    except ImportError:
      raise unittest.SkipTest(case['skip_message'])
    return getattr(module, case['class_name'])

  @contextlib.contextmanager
  def _handler_setup(self, case):
    # Most handlers need no setup; only cases flagged with
    # 'mock_aiplatform' patch the Vertex AI SDK so construction does not
    # reach out to the real service.
    if not case.get('mock_aiplatform'):
      yield
      return

    with unittest.mock.patch(
        'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip:
      mock_aip.init.return_value = None
      mock_endpoint = unittest.mock.MagicMock()
      mock_endpoint.list_models.return_value = ['fake-model']
      mock_aip.Endpoint.return_value = mock_endpoint
      yield

  def _assert_handler_cases(self, cases):
    # Shared driver: for each case, import the handler class, construct it
    # with the case's kwargs plus the bucketing kwargs, and check that the
    # bucketing options are forwarded.
    for case in cases:
      with self.subTest(handler=case['name']):
        handler_cls = self._load_handler_class(case)
        init_kwargs = dict(case['init_kwargs'])
        init_kwargs.update(self._BUCKETING_KWARGS)

        with self._handler_setup(case):
          handler = handler_cls(**init_kwargs)

        self._assert_bucketing_kwargs_forwarded(handler)

  def test_pytorch_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'pytorch_tensor',
            'module_name': 'apache_beam.ml.inference.pytorch_inference',
            'class_name': 'PytorchModelHandlerTensor',
            'skip_message': 'PyTorch not available',
            'init_kwargs': {},
        },
        {
            'name': 'pytorch_keyed_tensor',
            'module_name': 'apache_beam.ml.inference.pytorch_inference',
            'class_name': 'PytorchModelHandlerKeyedTensor',
            'skip_message': 'PyTorch not available',
            'init_kwargs': {},
        },
    ))

  def test_huggingface_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'huggingface_keyed_tensor',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFaceModelHandlerKeyedTensor',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'model_uri': 'unused',
                'model_class': object,
                'framework': 'pt',
            },
        },
        {
            'name': 'huggingface_tensor',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFaceModelHandlerTensor',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'model_uri': 'unused',
                'model_class': object,
            },
        },
        {
            'name': 'huggingface_pipeline',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFacePipelineModelHandler',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'task': 'text-classification',
            },
        },
    ))

  def test_sklearn_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'sklearn_numpy',
            'module_name': 'apache_beam.ml.inference.sklearn_inference',
            'class_name': 'SklearnModelHandlerNumpy',
            'skip_message': 'scikit-learn not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
        {
            'name': 'sklearn_pandas',
            'module_name': 'apache_beam.ml.inference.sklearn_inference',
            'class_name': 'SklearnModelHandlerPandas',
            'skip_message': 'scikit-learn not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
    ))

  def test_tensorflow_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'tensorflow_numpy',
            'module_name': 'apache_beam.ml.inference.tensorflow_inference',
            'class_name': 'TFModelHandlerNumpy',
            'skip_message': 'TensorFlow not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
        {
            'name': 'tensorflow_tensor',
            'module_name': 'apache_beam.ml.inference.tensorflow_inference',
            'class_name': 'TFModelHandlerTensor',
            'skip_message': 'TensorFlow not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
    ))

  def test_onnx_handler(self):
    self._assert_handler_cases(({
        'name': 'onnx_numpy',
        'module_name': 'apache_beam.ml.inference.onnx_inference',
        'class_name': 'OnnxModelHandlerNumpy',
        'skip_message': 'ONNX Runtime not available',
        'init_kwargs': {
            'model_uri': 'unused',
        },
    }, ))

  def test_xgboost_handler(self):
    self._assert_handler_cases(({
        'name': 'xgboost_numpy',
        'module_name': 'apache_beam.ml.inference.xgboost_inference',
        'class_name': 'XGBoostModelHandlerNumpy',
        'skip_message': 'XGBoost not available',
        'init_kwargs': {
            'model_class': object,
            'model_state': 'unused',
        },
    }, ))

  def test_tensorrt_handler(self):
    self._assert_handler_cases(({
        'name': 'tensorrt_numpy',
        'module_name': 'apache_beam.ml.inference.tensorrt_inference',
        'class_name': 'TensorRTEngineHandlerNumPy',
        'skip_message': 'TensorRT not available',
        'init_kwargs': {
            'min_batch_size': 1,
            'max_batch_size': 8,
        },
    }, ))

  def test_vllm_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'vllm_completions',
            'module_name': 'apache_beam.ml.inference.vllm_inference',
            'class_name': 'VLLMCompletionsModelHandler',
            'skip_message': 'vLLM not available',
            'init_kwargs': {
                'model_name': 'unused',
            },
        },
        {
            'name': 'vllm_chat',
            'module_name': 'apache_beam.ml.inference.vllm_inference',
            'class_name': 'VLLMChatModelHandler',
            'skip_message': 'vLLM not available',
            'init_kwargs': {
                'model_name': 'unused',
            },
        },
    ))

  def test_vertex_ai_handler(self):
    # Needs the aiplatform mock: the handler's constructor talks to the
    # SDK (init / Endpoint.list_models) — see _handler_setup.
    self._assert_handler_cases(({
        'name': 'vertex_ai',
        'module_name': 'apache_beam.ml.inference.vertex_ai_inference',
        'class_name': 'VertexAIModelHandlerJSON',
        'skip_message': 'Vertex AI SDK not available',
        'init_kwargs': {
            'endpoint_id': 'unused',
            'project': 'unused',
            'location': 'unused',
        },
        'mock_aiplatform': True,
    }, ))

  def test_gemini_handler(self):
    self._assert_handler_cases(({
        'name': 'gemini',
        'module_name': 'apache_beam.ml.inference.gemini_inference',
        'class_name': 'GeminiModelHandler',
        'skip_message': 'Google GenAI SDK not available',
        'init_kwargs': {
            'model_name': 'unused',
            'request_fn': lambda *args: None,
            'project': 'unused',
            'location': 'unused',
        },
    }, ))
| 2581 | + |
| 2582 | + |
2322 | 2583 | class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): |
2323 | 2584 | def load_model(self): |
2324 | 2585 | return FakeModel() |
|
0 commit comments