Skip to content

Commit 82c280b

Browse files
committed
Remove redundant bucketing forwarding tests
1 parent 78c936d commit 82c280b

File tree

1 file changed

+0
-230
lines changed

1 file changed

+0
-230
lines changed

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 0 additions & 230 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
#
1717

1818
"""Tests for apache_beam.ml.base."""
19-
import contextlib
20-
import importlib
2119
import math
2220
import multiprocessing
2321
import os
@@ -2352,234 +2350,6 @@ def test_run_inference_with_length_aware_batch_elements(self):
23522350
assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))
23532351

23542352

2355-
class HandlerBucketingKwargsForwardingTest(unittest.TestCase):
2356-
"""Verify each concrete ModelHandler forwards batch_length_fn and
2357-
batch_bucket_boundaries through to batch_elements_kwargs()."""
2358-
_BUCKETING_KWARGS = {
2359-
'batch_length_fn': len,
2360-
'batch_bucket_boundaries': [32],
2361-
}
2362-
2363-
def _assert_bucketing_kwargs_forwarded(self, handler):
2364-
kwargs = handler.batch_elements_kwargs()
2365-
self.assertIs(kwargs['length_fn'], len)
2366-
self.assertEqual(kwargs['bucket_boundaries'], [32])
2367-
2368-
def _load_handler_class(self, case):
2369-
try:
2370-
module = importlib.import_module(case['module_name'])
2371-
except ImportError:
2372-
raise unittest.SkipTest(case['skip_message'])
2373-
return getattr(module, case['class_name'])
2374-
2375-
@contextlib.contextmanager
2376-
def _handler_setup(self, case):
2377-
if not case.get('mock_aiplatform'):
2378-
yield
2379-
return
2380-
2381-
with unittest.mock.patch(
2382-
'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip:
2383-
mock_aip.init.return_value = None
2384-
mock_endpoint = unittest.mock.MagicMock()
2385-
mock_endpoint.list_models.return_value = ['fake-model']
2386-
mock_aip.Endpoint.return_value = mock_endpoint
2387-
yield
2388-
2389-
def _assert_handler_cases(self, cases):
2390-
for case in cases:
2391-
with self.subTest(handler=case['name']):
2392-
handler_cls = self._load_handler_class(case)
2393-
init_kwargs = dict(case['init_kwargs'])
2394-
init_kwargs.update(self._BUCKETING_KWARGS)
2395-
2396-
with self._handler_setup(case):
2397-
handler = handler_cls(**init_kwargs)
2398-
2399-
self._assert_bucketing_kwargs_forwarded(handler)
2400-
2401-
def test_pytorch_handlers(self):
2402-
self._assert_handler_cases((
2403-
{
2404-
'name': 'pytorch_tensor',
2405-
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2406-
'class_name': 'PytorchModelHandlerTensor',
2407-
'skip_message': 'PyTorch not available',
2408-
'init_kwargs': {},
2409-
},
2410-
{
2411-
'name': 'pytorch_keyed_tensor',
2412-
'module_name': 'apache_beam.ml.inference.pytorch_inference',
2413-
'class_name': 'PytorchModelHandlerKeyedTensor',
2414-
'skip_message': 'PyTorch not available',
2415-
'init_kwargs': {},
2416-
},
2417-
))
2418-
2419-
def test_huggingface_handlers(self):
2420-
self._assert_handler_cases((
2421-
{
2422-
'name': 'huggingface_keyed_tensor',
2423-
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2424-
'class_name': 'HuggingFaceModelHandlerKeyedTensor',
2425-
'skip_message': 'HuggingFace transformers not available',
2426-
'init_kwargs': {
2427-
'model_uri': 'unused',
2428-
'model_class': object,
2429-
'framework': 'pt',
2430-
},
2431-
},
2432-
{
2433-
'name': 'huggingface_tensor',
2434-
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2435-
'class_name': 'HuggingFaceModelHandlerTensor',
2436-
'skip_message': 'HuggingFace transformers not available',
2437-
'init_kwargs': {
2438-
'model_uri': 'unused',
2439-
'model_class': object,
2440-
},
2441-
},
2442-
{
2443-
'name': 'huggingface_pipeline',
2444-
'module_name': 'apache_beam.ml.inference.huggingface_inference',
2445-
'class_name': 'HuggingFacePipelineModelHandler',
2446-
'skip_message': 'HuggingFace transformers not available',
2447-
'init_kwargs': {
2448-
'task': 'text-classification',
2449-
},
2450-
},
2451-
))
2452-
2453-
def test_sklearn_handlers(self):
2454-
self._assert_handler_cases((
2455-
{
2456-
'name': 'sklearn_numpy',
2457-
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2458-
'class_name': 'SklearnModelHandlerNumpy',
2459-
'skip_message': 'scikit-learn not available',
2460-
'init_kwargs': {
2461-
'model_uri': 'unused',
2462-
},
2463-
},
2464-
{
2465-
'name': 'sklearn_pandas',
2466-
'module_name': 'apache_beam.ml.inference.sklearn_inference',
2467-
'class_name': 'SklearnModelHandlerPandas',
2468-
'skip_message': 'scikit-learn not available',
2469-
'init_kwargs': {
2470-
'model_uri': 'unused',
2471-
},
2472-
},
2473-
))
2474-
2475-
def test_tensorflow_handlers(self):
2476-
self._assert_handler_cases((
2477-
{
2478-
'name': 'tensorflow_numpy',
2479-
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2480-
'class_name': 'TFModelHandlerNumpy',
2481-
'skip_message': 'TensorFlow not available',
2482-
'init_kwargs': {
2483-
'model_uri': 'unused',
2484-
},
2485-
},
2486-
{
2487-
'name': 'tensorflow_tensor',
2488-
'module_name': 'apache_beam.ml.inference.tensorflow_inference',
2489-
'class_name': 'TFModelHandlerTensor',
2490-
'skip_message': 'TensorFlow not available',
2491-
'init_kwargs': {
2492-
'model_uri': 'unused',
2493-
},
2494-
},
2495-
))
2496-
2497-
def test_onnx_handler(self):
2498-
self._assert_handler_cases(({
2499-
'name': 'onnx_numpy',
2500-
'module_name': 'apache_beam.ml.inference.onnx_inference',
2501-
'class_name': 'OnnxModelHandlerNumpy',
2502-
'skip_message': 'ONNX Runtime not available',
2503-
'init_kwargs': {
2504-
'model_uri': 'unused',
2505-
},
2506-
}, ))
2507-
2508-
def test_xgboost_handler(self):
2509-
self._assert_handler_cases(({
2510-
'name': 'xgboost_numpy',
2511-
'module_name': 'apache_beam.ml.inference.xgboost_inference',
2512-
'class_name': 'XGBoostModelHandlerNumpy',
2513-
'skip_message': 'XGBoost not available',
2514-
'init_kwargs': {
2515-
'model_class': object,
2516-
'model_state': 'unused',
2517-
},
2518-
}, ))
2519-
2520-
def test_tensorrt_handler(self):
2521-
self._assert_handler_cases(({
2522-
'name': 'tensorrt_numpy',
2523-
'module_name': 'apache_beam.ml.inference.tensorrt_inference',
2524-
'class_name': 'TensorRTEngineHandlerNumPy',
2525-
'skip_message': 'TensorRT not available',
2526-
'init_kwargs': {
2527-
'min_batch_size': 1,
2528-
'max_batch_size': 8,
2529-
},
2530-
}, ))
2531-
2532-
def test_vllm_handlers(self):
2533-
self._assert_handler_cases((
2534-
{
2535-
'name': 'vllm_completions',
2536-
'module_name': 'apache_beam.ml.inference.vllm_inference',
2537-
'class_name': 'VLLMCompletionsModelHandler',
2538-
'skip_message': 'vLLM not available',
2539-
'init_kwargs': {
2540-
'model_name': 'unused',
2541-
},
2542-
},
2543-
{
2544-
'name': 'vllm_chat',
2545-
'module_name': 'apache_beam.ml.inference.vllm_inference',
2546-
'class_name': 'VLLMChatModelHandler',
2547-
'skip_message': 'vLLM not available',
2548-
'init_kwargs': {
2549-
'model_name': 'unused',
2550-
},
2551-
},
2552-
))
2553-
2554-
def test_vertex_ai_handler(self):
2555-
self._assert_handler_cases(({
2556-
'name': 'vertex_ai',
2557-
'module_name': 'apache_beam.ml.inference.vertex_ai_inference',
2558-
'class_name': 'VertexAIModelHandlerJSON',
2559-
'skip_message': 'Vertex AI SDK not available',
2560-
'init_kwargs': {
2561-
'endpoint_id': 'unused',
2562-
'project': 'unused',
2563-
'location': 'unused',
2564-
},
2565-
'mock_aiplatform': True,
2566-
}, ))
2567-
2568-
def test_gemini_handler(self):
2569-
self._assert_handler_cases(({
2570-
'name': 'gemini',
2571-
'module_name': 'apache_beam.ml.inference.gemini_inference',
2572-
'class_name': 'GeminiModelHandler',
2573-
'skip_message': 'Google GenAI SDK not available',
2574-
'init_kwargs': {
2575-
'model_name': 'unused',
2576-
'request_fn': lambda *args: None,
2577-
'project': 'unused',
2578-
'location': 'unused',
2579-
},
2580-
}, ))
2581-
2582-
25832353
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
25842354
def load_model(self):
25852355
return FakeModel()

0 commit comments

Comments
 (0)