class RunConfig(ConfigBase):
    """Runtime knobs for a dataset generation run.

    These options steer how generation executes and are kept apart from the
    dataset configuration itself.

    Attributes:
        disable_early_shutdown: When True, early shutdown is switched off and
            ``shutdown_error_rate`` is overwritten with 1.0 during validation.
            Defaults to False.
        shutdown_error_rate: Error-rate threshold that triggers early
            shutdown. Must lie in the closed interval [0.0, 1.0].
            Defaults to 0.5.
        shutdown_error_window: Minimum number of completed tasks before error
            rate monitoring begins. Must be >= 0. Defaults to 10.
    """

    disable_early_shutdown: bool = False
    shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
    shutdown_error_window: int = Field(default=10, ge=0)

    @model_validator(mode="after")
    def normalize_shutdown_settings(self) -> Self:
        """Force ``shutdown_error_rate`` to 1.0 whenever early shutdown is disabled.

        Returns:
            The same model instance, possibly with its rate overwritten.
        """
        if not self.disable_early_shutdown:
            # Early shutdown stays active: keep the rate the caller supplied.
            return self

        # Any caller-supplied rate is discarded once early shutdown is disabled.
        self.shutdown_error_rate = 1.0
        return self
concurrent workers" ) + settings = self._resource_provider.run_config with ConcurrentThreadExecutor( max_workers=max_workers, column_name=generator.config.name, result_callback=self._worker_result_callback, error_callback=self._worker_error_callback, + shutdown_error_rate=settings.shutdown_error_rate, + shutdown_error_window=settings.shutdown_error_window, ) as executor: for i, record in self.batch_manager.iter_current_batch(): executor.submit(lambda record: generator.generate(record), record, context={"index": i}) diff --git a/src/data_designer/engine/resources/resource_provider.py b/src/data_designer/engine/resources/resource_provider.py index c28225d3..9c7c85cc 100644 --- a/src/data_designer/engine/resources/resource_provider.py +++ b/src/data_designer/engine/resources/resource_provider.py @@ -3,6 +3,7 @@ from data_designer.config.base import ConfigBase from data_designer.config.models import ModelConfig +from data_designer.config.run_config import RunConfig from data_designer.config.seed_source import SeedSource from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage @@ -23,6 +24,7 @@ class ResourceProvider(ConfigBase): artifact_storage: ArtifactStorage blob_storage: ManagedBlobStorage | None = None model_registry: ModelRegistry | None = None + run_config: RunConfig = RunConfig() seed_reader: SeedReader | None = None @@ -35,6 +37,7 @@ def create_resource_provider( seed_reader_registry: SeedReaderRegistry, blob_storage: ManagedBlobStorage | None = None, seed_dataset_source: SeedSource | None = None, + run_config: RunConfig | None = None, ) -> ResourceProvider: seed_reader = None if seed_dataset_source: @@ -51,4 +54,5 @@ def create_resource_provider( ), blob_storage=blob_storage or init_managed_blob_storage(), seed_reader=seed_reader, + run_config=run_config or RunConfig(), ) diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 
def set_run_config(self, run_config: RunConfig) -> None:
    """Replace the runtime configuration stored on this instance.

    The given object is stored as-is on ``self._run_config``; no validation
    or copying happens here.

    Args:
        run_config: A RunConfig instance carrying runtime settings such as
            early shutdown behavior. RunConfig can be imported from
            data_designer.essentials.

    Example:
        >>> from data_designer.essentials import DataDesigner, RunConfig
        >>> dd = DataDesigner()
        >>> dd.set_run_config(RunConfig(disable_early_shutdown=True))
    """
    self._run_config = run_config
@pytest.mark.parametrize(
    "disable_early_shutdown,configured_rate,expected_rate,shutdown_error_window",
    [
        (False, 0.7, 0.7, 20),  # enabled: use configured rate
        (True, 0.7, 1.0, 20),  # disabled: use 1.0 to effectively disable
        (False, 0.5, 0.5, 10),  # defaults
    ],
)
@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor")
def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provider(
    mock_executor_class: Mock,
    stub_resource_provider: Mock,
    stub_test_column_configs: list,
    stub_test_processor_configs: list,
    disable_early_shutdown: bool,
    configured_rate: float,
    expected_rate: float,
    shutdown_error_window: int,
) -> None:
    """Check that the executor receives the shutdown settings held by the resource provider."""
    from data_designer.config.run_config import RunConfig

    # Arrange: put the parametrized run config on the stubbed provider.
    run_config = RunConfig(
        disable_early_shutdown=disable_early_shutdown,
        shutdown_error_rate=configured_rate,
        shutdown_error_window=shutdown_error_window,
    )
    stub_resource_provider.run_config = run_config

    builder = ColumnWiseDatasetBuilder(
        column_configs=stub_test_column_configs,
        processor_configs=stub_test_processor_configs,
        resource_provider=stub_resource_provider,
    )

    # The patched executor class must be usable as a context manager.
    executor_instance = mock_executor_class.return_value
    executor_instance.__enter__ = Mock(return_value=Mock())
    executor_instance.__exit__ = Mock(return_value=False)

    generator = Mock()
    generator.generation_strategy = GenerationStrategy.CELL_BY_CELL
    generator.config.name = "test"
    generator.config.column_type = "llm_text"

    # An empty batch means nothing is submitted; only the constructor call is inspected.
    batch_manager = Mock()
    batch_manager.iter_current_batch.return_value = []
    builder.batch_manager = batch_manager

    # Act
    builder._fan_out_with_threads(generator, max_workers=4)

    # Assert: inspect the keyword arguments the executor class was built with.
    _, call_kwargs = mock_executor_class.call_args
    assert call_kwargs["shutdown_error_rate"] == expected_rate
    assert call_kwargs["shutdown_error_window"] == shutdown_error_window
def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub_model_providers):
    """Check that a configured rate of 0.7 is kept when enabled and becomes 1.0 when disabled."""
    data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)

    # (disable_early_shutdown, rate expected on the stored config), applied in order.
    scenarios = [
        (False, 0.7),  # enabled: the configured value is used
        (True, 1.0),  # disabled: the value is normalized to 1.0
    ]
    for disabled, expected_rate in scenarios:
        run_config = RunConfig(
            disable_early_shutdown=disabled,
            shutdown_error_rate=0.7,
        )
        data_designer.set_run_config(run_config)
        assert data_designer._run_config.shutdown_error_rate == expected_rate
stub_artifact_path, stub_model_providers, stub_managed_assets_path ):