|
9 | 9 | from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig |
10 | 10 | from data_designer.config.dataset_builders import BuildStage |
11 | 11 | from data_designer.config.processors import DropColumnsProcessorConfig |
| 12 | +from data_designer.engine.column_generators.generators.base import GenerationStrategy |
12 | 13 | from data_designer.engine.dataset_builders.column_wise_builder import ( |
13 | 14 | MAX_CONCURRENCY_PER_NON_LLM_GENERATOR, |
14 | 15 | ColumnWiseDatasetBuilder, |
@@ -306,3 +307,78 @@ def test_emit_batch_inference_events_handles_multiple_models( |
306 | 307 | events = [call[0][0] for call in mock_handler_instance.enqueue.call_args_list] |
307 | 308 | model_names = {e.model for e in events} |
308 | 309 | assert model_names == {"model-a", "model-b"} |
| 310 | + |
| 311 | + |
| 312 | +@pytest.mark.parametrize( |
| 313 | + "enable_early_shutdown,shutdown_error_rate,shutdown_error_window", |
| 314 | + [ |
| 315 | + (True, 0.5, 10), # defaults |
| 316 | + (False, 0.8, 25), # custom values |
| 317 | + ], |
| 318 | +) |
| 319 | +def test_column_wise_dataset_builder_stores_early_shutdown_params( |
| 320 | + stub_resource_provider: Mock, |
| 321 | + stub_test_column_configs: list, |
| 322 | + stub_test_processor_configs: list, |
| 323 | + enable_early_shutdown: bool, |
| 324 | + shutdown_error_rate: float, |
| 325 | + shutdown_error_window: int, |
| 326 | +) -> None: |
| 327 | + """Test that ColumnWiseDatasetBuilder stores early shutdown parameters.""" |
| 328 | + builder = ColumnWiseDatasetBuilder( |
| 329 | + column_configs=stub_test_column_configs, |
| 330 | + processor_configs=stub_test_processor_configs, |
| 331 | + resource_provider=stub_resource_provider, |
| 332 | + enable_early_shutdown=enable_early_shutdown, |
| 333 | + shutdown_error_rate=shutdown_error_rate, |
| 334 | + shutdown_error_window=shutdown_error_window, |
| 335 | + ) |
| 336 | + |
| 337 | + assert builder._enable_early_shutdown is enable_early_shutdown |
| 338 | + assert builder._shutdown_error_rate == shutdown_error_rate |
| 339 | + assert builder._shutdown_error_window == shutdown_error_window |
| 340 | + |
| 341 | + |
| 342 | +@pytest.mark.parametrize( |
| 343 | + "enable_early_shutdown,configured_rate,expected_rate", |
| 344 | + [ |
| 345 | + (True, 0.7, 0.7), # enabled: use configured rate |
| 346 | + (False, 0.7, 1.0), # disabled: use 1.0 to effectively disable |
| 347 | + ], |
| 348 | +) |
| 349 | +@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor") |
| 350 | +def test_fan_out_with_threads_respects_enable_early_shutdown_flag( |
| 351 | + mock_executor_class: Mock, |
| 352 | + stub_resource_provider: Mock, |
| 353 | + stub_test_column_configs: list, |
| 354 | + stub_test_processor_configs: list, |
| 355 | + enable_early_shutdown: bool, |
| 356 | + configured_rate: float, |
| 357 | + expected_rate: float, |
| 358 | +) -> None: |
| 359 | + """Test that _fan_out_with_threads passes correct shutdown_error_rate based on enable flag.""" |
| 360 | + builder = ColumnWiseDatasetBuilder( |
| 361 | + column_configs=stub_test_column_configs, |
| 362 | + processor_configs=stub_test_processor_configs, |
| 363 | + resource_provider=stub_resource_provider, |
| 364 | + enable_early_shutdown=enable_early_shutdown, |
| 365 | + shutdown_error_rate=configured_rate, |
| 366 | + shutdown_error_window=20, |
| 367 | + ) |
| 368 | + |
| 369 | + mock_executor_class.return_value.__enter__ = Mock(return_value=Mock()) |
| 370 | + mock_executor_class.return_value.__exit__ = Mock(return_value=False) |
| 371 | + |
| 372 | + mock_generator = Mock() |
| 373 | + mock_generator.generation_strategy = GenerationStrategy.CELL_BY_CELL |
| 374 | + mock_generator.config.name = "test" |
| 375 | + mock_generator.config.column_type = "llm_text" |
| 376 | + |
| 377 | + builder.batch_manager = Mock() |
| 378 | + builder.batch_manager.iter_current_batch.return_value = [] |
| 379 | + |
| 380 | + builder._fan_out_with_threads(mock_generator, max_workers=4) |
| 381 | + |
| 382 | + call_kwargs = mock_executor_class.call_args[1] |
| 383 | + assert call_kwargs["shutdown_error_rate"] == expected_rate |
| 384 | + assert call_kwargs["shutdown_error_window"] == 20 |
0 commit comments