|
16 | 16 | # |
17 | 17 |
|
18 | 18 | """Tests for apache_beam.ml.base.""" |
19 | | -import contextlib |
20 | | -import importlib |
21 | 19 | import math |
22 | 20 | import multiprocessing |
23 | 21 | import os |
@@ -2352,234 +2350,6 @@ def test_run_inference_with_length_aware_batch_elements(self): |
2352 | 2350 | assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7'])) |
2353 | 2351 |
|
2354 | 2352 |
|
2355 | | -class HandlerBucketingKwargsForwardingTest(unittest.TestCase): |
2356 | | - """Verify each concrete ModelHandler forwards batch_length_fn and |
2357 | | - batch_bucket_boundaries through to batch_elements_kwargs().""" |
2358 | | - _BUCKETING_KWARGS = { |
2359 | | - 'batch_length_fn': len, |
2360 | | - 'batch_bucket_boundaries': [32], |
2361 | | - } |
2362 | | - |
2363 | | - def _assert_bucketing_kwargs_forwarded(self, handler): |
2364 | | - kwargs = handler.batch_elements_kwargs() |
2365 | | - self.assertIs(kwargs['length_fn'], len) |
2366 | | - self.assertEqual(kwargs['bucket_boundaries'], [32]) |
2367 | | - |
2368 | | - def _load_handler_class(self, case): |
2369 | | - try: |
2370 | | - module = importlib.import_module(case['module_name']) |
2371 | | - except ImportError: |
2372 | | - raise unittest.SkipTest(case['skip_message']) |
2373 | | - return getattr(module, case['class_name']) |
2374 | | - |
2375 | | - @contextlib.contextmanager |
2376 | | - def _handler_setup(self, case): |
2377 | | - if not case.get('mock_aiplatform'): |
2378 | | - yield |
2379 | | - return |
2380 | | - |
2381 | | - with unittest.mock.patch( |
2382 | | - 'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip: |
2383 | | - mock_aip.init.return_value = None |
2384 | | - mock_endpoint = unittest.mock.MagicMock() |
2385 | | - mock_endpoint.list_models.return_value = ['fake-model'] |
2386 | | - mock_aip.Endpoint.return_value = mock_endpoint |
2387 | | - yield |
2388 | | - |
2389 | | - def _assert_handler_cases(self, cases): |
2390 | | - for case in cases: |
2391 | | - with self.subTest(handler=case['name']): |
2392 | | - handler_cls = self._load_handler_class(case) |
2393 | | - init_kwargs = dict(case['init_kwargs']) |
2394 | | - init_kwargs.update(self._BUCKETING_KWARGS) |
2395 | | - |
2396 | | - with self._handler_setup(case): |
2397 | | - handler = handler_cls(**init_kwargs) |
2398 | | - |
2399 | | - self._assert_bucketing_kwargs_forwarded(handler) |
2400 | | - |
2401 | | - def test_pytorch_handlers(self): |
2402 | | - self._assert_handler_cases(( |
2403 | | - { |
2404 | | - 'name': 'pytorch_tensor', |
2405 | | - 'module_name': 'apache_beam.ml.inference.pytorch_inference', |
2406 | | - 'class_name': 'PytorchModelHandlerTensor', |
2407 | | - 'skip_message': 'PyTorch not available', |
2408 | | - 'init_kwargs': {}, |
2409 | | - }, |
2410 | | - { |
2411 | | - 'name': 'pytorch_keyed_tensor', |
2412 | | - 'module_name': 'apache_beam.ml.inference.pytorch_inference', |
2413 | | - 'class_name': 'PytorchModelHandlerKeyedTensor', |
2414 | | - 'skip_message': 'PyTorch not available', |
2415 | | - 'init_kwargs': {}, |
2416 | | - }, |
2417 | | - )) |
2418 | | - |
2419 | | - def test_huggingface_handlers(self): |
2420 | | - self._assert_handler_cases(( |
2421 | | - { |
2422 | | - 'name': 'huggingface_keyed_tensor', |
2423 | | - 'module_name': 'apache_beam.ml.inference.huggingface_inference', |
2424 | | - 'class_name': 'HuggingFaceModelHandlerKeyedTensor', |
2425 | | - 'skip_message': 'HuggingFace transformers not available', |
2426 | | - 'init_kwargs': { |
2427 | | - 'model_uri': 'unused', |
2428 | | - 'model_class': object, |
2429 | | - 'framework': 'pt', |
2430 | | - }, |
2431 | | - }, |
2432 | | - { |
2433 | | - 'name': 'huggingface_tensor', |
2434 | | - 'module_name': 'apache_beam.ml.inference.huggingface_inference', |
2435 | | - 'class_name': 'HuggingFaceModelHandlerTensor', |
2436 | | - 'skip_message': 'HuggingFace transformers not available', |
2437 | | - 'init_kwargs': { |
2438 | | - 'model_uri': 'unused', |
2439 | | - 'model_class': object, |
2440 | | - }, |
2441 | | - }, |
2442 | | - { |
2443 | | - 'name': 'huggingface_pipeline', |
2444 | | - 'module_name': 'apache_beam.ml.inference.huggingface_inference', |
2445 | | - 'class_name': 'HuggingFacePipelineModelHandler', |
2446 | | - 'skip_message': 'HuggingFace transformers not available', |
2447 | | - 'init_kwargs': { |
2448 | | - 'task': 'text-classification', |
2449 | | - }, |
2450 | | - }, |
2451 | | - )) |
2452 | | - |
2453 | | - def test_sklearn_handlers(self): |
2454 | | - self._assert_handler_cases(( |
2455 | | - { |
2456 | | - 'name': 'sklearn_numpy', |
2457 | | - 'module_name': 'apache_beam.ml.inference.sklearn_inference', |
2458 | | - 'class_name': 'SklearnModelHandlerNumpy', |
2459 | | - 'skip_message': 'scikit-learn not available', |
2460 | | - 'init_kwargs': { |
2461 | | - 'model_uri': 'unused', |
2462 | | - }, |
2463 | | - }, |
2464 | | - { |
2465 | | - 'name': 'sklearn_pandas', |
2466 | | - 'module_name': 'apache_beam.ml.inference.sklearn_inference', |
2467 | | - 'class_name': 'SklearnModelHandlerPandas', |
2468 | | - 'skip_message': 'scikit-learn not available', |
2469 | | - 'init_kwargs': { |
2470 | | - 'model_uri': 'unused', |
2471 | | - }, |
2472 | | - }, |
2473 | | - )) |
2474 | | - |
2475 | | - def test_tensorflow_handlers(self): |
2476 | | - self._assert_handler_cases(( |
2477 | | - { |
2478 | | - 'name': 'tensorflow_numpy', |
2479 | | - 'module_name': 'apache_beam.ml.inference.tensorflow_inference', |
2480 | | - 'class_name': 'TFModelHandlerNumpy', |
2481 | | - 'skip_message': 'TensorFlow not available', |
2482 | | - 'init_kwargs': { |
2483 | | - 'model_uri': 'unused', |
2484 | | - }, |
2485 | | - }, |
2486 | | - { |
2487 | | - 'name': 'tensorflow_tensor', |
2488 | | - 'module_name': 'apache_beam.ml.inference.tensorflow_inference', |
2489 | | - 'class_name': 'TFModelHandlerTensor', |
2490 | | - 'skip_message': 'TensorFlow not available', |
2491 | | - 'init_kwargs': { |
2492 | | - 'model_uri': 'unused', |
2493 | | - }, |
2494 | | - }, |
2495 | | - )) |
2496 | | - |
2497 | | - def test_onnx_handler(self): |
2498 | | - self._assert_handler_cases(({ |
2499 | | - 'name': 'onnx_numpy', |
2500 | | - 'module_name': 'apache_beam.ml.inference.onnx_inference', |
2501 | | - 'class_name': 'OnnxModelHandlerNumpy', |
2502 | | - 'skip_message': 'ONNX Runtime not available', |
2503 | | - 'init_kwargs': { |
2504 | | - 'model_uri': 'unused', |
2505 | | - }, |
2506 | | - }, )) |
2507 | | - |
2508 | | - def test_xgboost_handler(self): |
2509 | | - self._assert_handler_cases(({ |
2510 | | - 'name': 'xgboost_numpy', |
2511 | | - 'module_name': 'apache_beam.ml.inference.xgboost_inference', |
2512 | | - 'class_name': 'XGBoostModelHandlerNumpy', |
2513 | | - 'skip_message': 'XGBoost not available', |
2514 | | - 'init_kwargs': { |
2515 | | - 'model_class': object, |
2516 | | - 'model_state': 'unused', |
2517 | | - }, |
2518 | | - }, )) |
2519 | | - |
2520 | | - def test_tensorrt_handler(self): |
2521 | | - self._assert_handler_cases(({ |
2522 | | - 'name': 'tensorrt_numpy', |
2523 | | - 'module_name': 'apache_beam.ml.inference.tensorrt_inference', |
2524 | | - 'class_name': 'TensorRTEngineHandlerNumPy', |
2525 | | - 'skip_message': 'TensorRT not available', |
2526 | | - 'init_kwargs': { |
2527 | | - 'min_batch_size': 1, |
2528 | | - 'max_batch_size': 8, |
2529 | | - }, |
2530 | | - }, )) |
2531 | | - |
2532 | | - def test_vllm_handlers(self): |
2533 | | - self._assert_handler_cases(( |
2534 | | - { |
2535 | | - 'name': 'vllm_completions', |
2536 | | - 'module_name': 'apache_beam.ml.inference.vllm_inference', |
2537 | | - 'class_name': 'VLLMCompletionsModelHandler', |
2538 | | - 'skip_message': 'vLLM not available', |
2539 | | - 'init_kwargs': { |
2540 | | - 'model_name': 'unused', |
2541 | | - }, |
2542 | | - }, |
2543 | | - { |
2544 | | - 'name': 'vllm_chat', |
2545 | | - 'module_name': 'apache_beam.ml.inference.vllm_inference', |
2546 | | - 'class_name': 'VLLMChatModelHandler', |
2547 | | - 'skip_message': 'vLLM not available', |
2548 | | - 'init_kwargs': { |
2549 | | - 'model_name': 'unused', |
2550 | | - }, |
2551 | | - }, |
2552 | | - )) |
2553 | | - |
2554 | | - def test_vertex_ai_handler(self): |
2555 | | - self._assert_handler_cases(({ |
2556 | | - 'name': 'vertex_ai', |
2557 | | - 'module_name': 'apache_beam.ml.inference.vertex_ai_inference', |
2558 | | - 'class_name': 'VertexAIModelHandlerJSON', |
2559 | | - 'skip_message': 'Vertex AI SDK not available', |
2560 | | - 'init_kwargs': { |
2561 | | - 'endpoint_id': 'unused', |
2562 | | - 'project': 'unused', |
2563 | | - 'location': 'unused', |
2564 | | - }, |
2565 | | - 'mock_aiplatform': True, |
2566 | | - }, )) |
2567 | | - |
2568 | | - def test_gemini_handler(self): |
2569 | | - self._assert_handler_cases(({ |
2570 | | - 'name': 'gemini', |
2571 | | - 'module_name': 'apache_beam.ml.inference.gemini_inference', |
2572 | | - 'class_name': 'GeminiModelHandler', |
2573 | | - 'skip_message': 'Google GenAI SDK not available', |
2574 | | - 'init_kwargs': { |
2575 | | - 'model_name': 'unused', |
2576 | | - 'request_fn': lambda *args: None, |
2577 | | - 'project': 'unused', |
2578 | | - 'location': 'unused', |
2579 | | - }, |
2580 | | - }, )) |
2581 | | - |
2582 | | - |
2583 | 2353 | class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): |
2584 | 2354 | def load_model(self): |
2585 | 2355 | return FakeModel() |
|
0 commit comments