Skip to content

Commit 69109b5

Browse files
committed
Tests to verify passthrough of early shutdown properties
1 parent 4edb96d commit 69109b5

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tests/engine/dataset_builders/test_column_wise_builder.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig
1010
from data_designer.config.dataset_builders import BuildStage
1111
from data_designer.config.processors import DropColumnsProcessorConfig
12+
from data_designer.engine.column_generators.generators.base import GenerationStrategy
1213
from data_designer.engine.dataset_builders.column_wise_builder import (
1314
MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
1415
ColumnWiseDatasetBuilder,
@@ -306,3 +307,78 @@ def test_emit_batch_inference_events_handles_multiple_models(
306307
events = [call[0][0] for call in mock_handler_instance.enqueue.call_args_list]
307308
model_names = {e.model for e in events}
308309
assert model_names == {"model-a", "model-b"}
310+
311+
312+
@pytest.mark.parametrize(
313+
"enable_early_shutdown,shutdown_error_rate,shutdown_error_window",
314+
[
315+
(True, 0.5, 10), # defaults
316+
(False, 0.8, 25), # custom values
317+
],
318+
)
319+
def test_column_wise_dataset_builder_stores_early_shutdown_params(
320+
stub_resource_provider: Mock,
321+
stub_test_column_configs: list,
322+
stub_test_processor_configs: list,
323+
enable_early_shutdown: bool,
324+
shutdown_error_rate: float,
325+
shutdown_error_window: int,
326+
) -> None:
327+
"""Test that ColumnWiseDatasetBuilder stores early shutdown parameters."""
328+
builder = ColumnWiseDatasetBuilder(
329+
column_configs=stub_test_column_configs,
330+
processor_configs=stub_test_processor_configs,
331+
resource_provider=stub_resource_provider,
332+
enable_early_shutdown=enable_early_shutdown,
333+
shutdown_error_rate=shutdown_error_rate,
334+
shutdown_error_window=shutdown_error_window,
335+
)
336+
337+
assert builder._enable_early_shutdown is enable_early_shutdown
338+
assert builder._shutdown_error_rate == shutdown_error_rate
339+
assert builder._shutdown_error_window == shutdown_error_window
340+
341+
342+
@pytest.mark.parametrize(
343+
"enable_early_shutdown,configured_rate,expected_rate",
344+
[
345+
(True, 0.7, 0.7), # enabled: use configured rate
346+
(False, 0.7, 1.0), # disabled: use 1.0 to effectively disable
347+
],
348+
)
349+
@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor")
350+
def test_fan_out_with_threads_respects_enable_early_shutdown_flag(
351+
mock_executor_class: Mock,
352+
stub_resource_provider: Mock,
353+
stub_test_column_configs: list,
354+
stub_test_processor_configs: list,
355+
enable_early_shutdown: bool,
356+
configured_rate: float,
357+
expected_rate: float,
358+
) -> None:
359+
"""Test that _fan_out_with_threads passes correct shutdown_error_rate based on enable flag."""
360+
builder = ColumnWiseDatasetBuilder(
361+
column_configs=stub_test_column_configs,
362+
processor_configs=stub_test_processor_configs,
363+
resource_provider=stub_resource_provider,
364+
enable_early_shutdown=enable_early_shutdown,
365+
shutdown_error_rate=configured_rate,
366+
shutdown_error_window=20,
367+
)
368+
369+
mock_executor_class.return_value.__enter__ = Mock(return_value=Mock())
370+
mock_executor_class.return_value.__exit__ = Mock(return_value=False)
371+
372+
mock_generator = Mock()
373+
mock_generator.generation_strategy = GenerationStrategy.CELL_BY_CELL
374+
mock_generator.config.name = "test"
375+
mock_generator.config.column_type = "llm_text"
376+
377+
builder.batch_manager = Mock()
378+
builder.batch_manager.iter_current_batch.return_value = []
379+
380+
builder._fan_out_with_threads(mock_generator, max_workers=4)
381+
382+
call_kwargs = mock_executor_class.call_args[1]
383+
assert call_kwargs["shutdown_error_rate"] == expected_rate
384+
assert call_kwargs["shutdown_error_window"] == 20

0 commit comments

Comments
 (0)