Skip to content
34 changes: 34 additions & 0 deletions src/data_designer/config/run_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pydantic import Field, model_validator
from typing_extensions import Self

from data_designer.config.base import ConfigBase


class RunConfig(ConfigBase):
    """Runtime knobs for a dataset generation run.

    These options steer how generation behaves while it executes. They are
    kept apart from the dataset configuration because they do not describe
    the dataset itself.

    Attributes:
        disable_early_shutdown: Turns early shutdown off entirely when True, so
            generation keeps going regardless of the error rate. Defaults to
            False.
        shutdown_error_rate: Error-rate threshold that triggers early shutdown;
            validated to lie within [0.0, 1.0]. Overwritten with 1.0 during
            validation whenever ``disable_early_shutdown`` is True. Defaults
            to 0.5.
        shutdown_error_window: Minimum number of completed tasks before error
            rate monitoring begins; validated to be >= 0. Defaults to 10.
    """

    disable_early_shutdown: bool = False
    shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
    shutdown_error_window: int = Field(default=10, ge=0)

    @model_validator(mode="after")
    def normalize_shutdown_settings(self) -> Self:
        """Force ``shutdown_error_rate`` to 1.0 if early shutdown is disabled.

        Returns:
            This same instance, as required of an "after" model validator.
        """
        if not self.disable_early_shutdown:
            # Early shutdown stays active: keep the caller's threshold as given.
            return self

        # The flag wins over whatever rate was supplied alongside it.
        self.shutdown_error_rate = 1.0
        return self
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,14 @@ def result_callback(result: ValidationResult, context: dict):
def error_callback(error: Exception, context: dict):
outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))

settings = self.resource_provider.run_config
with ConcurrentThreadExecutor(
max_workers=self.config.validator_params.max_parallel_requests,
column_name=self.config.name,
result_callback=result_callback,
error_callback=error_callback,
shutdown_error_rate=settings.shutdown_error_rate,
shutdown_error_window=settings.shutdown_error_window,
) as executor:
for i, batch in enumerate(batched_records):
executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,14 @@ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int
f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
f"with {max_workers} concurrent workers"
)
settings = self._resource_provider.run_config
with ConcurrentThreadExecutor(
max_workers=max_workers,
column_name=generator.config.name,
result_callback=self._worker_result_callback,
error_callback=self._worker_error_callback,
shutdown_error_rate=settings.shutdown_error_rate,
shutdown_error_window=settings.shutdown_error_window,
) as executor:
for i, record in self.batch_manager.iter_current_batch():
executor.submit(lambda record: generator.generate(record), record, context={"index": i})
Expand Down
4 changes: 4 additions & 0 deletions src/data_designer/engine/resources/resource_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from data_designer.config.base import ConfigBase
from data_designer.config.models import ModelConfig
from data_designer.config.run_config import RunConfig
from data_designer.config.seed_source import SeedSource
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
Expand All @@ -23,6 +24,7 @@ class ResourceProvider(ConfigBase):
artifact_storage: ArtifactStorage
blob_storage: ManagedBlobStorage | None = None
model_registry: ModelRegistry | None = None
run_config: RunConfig = RunConfig()
seed_reader: SeedReader | None = None


Expand All @@ -35,6 +37,7 @@ def create_resource_provider(
seed_reader_registry: SeedReaderRegistry,
blob_storage: ManagedBlobStorage | None = None,
seed_dataset_source: SeedSource | None = None,
run_config: RunConfig | None = None,
) -> ResourceProvider:
seed_reader = None
if seed_dataset_source:
Expand All @@ -51,4 +54,5 @@ def create_resource_provider(
),
blob_storage=blob_storage or init_managed_blob_storage(),
seed_reader=seed_reader,
run_config=run_config or RunConfig(),
)
2 changes: 2 additions & 0 deletions src/data_designer/essentials/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from data_designer.config.default_model_settings import resolve_seed_default_model_settings
from data_designer.config.exports import * # noqa: F403
from data_designer.config.run_config import RunConfig
from data_designer.config.validator_params import LocalCallableValidatorParams
from data_designer.interface.data_designer import DataDesigner
from data_designer.logging import LoggingConfig, configure_logging
Expand All @@ -21,6 +22,7 @@ def get_essentials_exports() -> list[str]:
local = [
DataDesigner.__name__,
LocalCallableValidatorParams.__name__,
RunConfig.__name__,
]

return logging + local + get_config_exports() # noqa: F405
Expand Down
21 changes: 20 additions & 1 deletion src/data_designer/interface/data_designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ModelProvider,
)
from data_designer.config.preview_results import PreviewResults
from data_designer.config.run_config import RunConfig
from data_designer.config.utils.constants import (
DEFAULT_NUM_RECORDS,
MANAGED_ASSETS_PATH,
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(
self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER
self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
self._buffer_size = DEFAULT_BUFFER_SIZE
self._run_config = RunConfig()
self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
self._model_providers = self._resolve_model_providers(model_providers)
self._model_provider_registry = resolve_model_provider_registry(
Expand Down Expand Up @@ -311,6 +313,20 @@ def set_buffer_size(self, buffer_size: int) -> None:
raise InvalidBufferValueError("Buffer size must be greater than 0.")
self._buffer_size = buffer_size

def set_run_config(self, run_config: RunConfig) -> None:
    """Replace the runtime configuration stored on this instance.

    The given object is stored as-is on ``self._run_config``; no validation
    or copying happens here.

    Args:
        run_config: The RunConfig to use from now on. It carries runtime
            settings such as early shutdown behavior and can be imported
            from data_designer.essentials.

    Example:
        >>> from data_designer.essentials import DataDesigner, RunConfig
        >>> dd = DataDesigner()
        >>> dd.set_run_config(RunConfig(disable_early_shutdown=True))
    """
    self._run_config = run_config

def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
if model_providers is None:
model_providers = get_default_providers()
Expand All @@ -327,7 +343,9 @@ def _resolve_model_providers(self, model_providers: list[ModelProvider] | None)
return model_providers or []

def _create_dataset_builder(
self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
self,
config_builder: DataDesignerConfigBuilder,
resource_provider: ResourceProvider,
) -> ColumnWiseDatasetBuilder:
config = compile_data_designer_config(config_builder, resource_provider)

Expand Down Expand Up @@ -365,6 +383,7 @@ def _create_resource_provider(
blob_storage=init_managed_blob_storage(str(self._managed_assets_path)),
seed_dataset_source=seed_dataset_source,
seed_reader_registry=self._seed_reader_registry,
run_config=self._run_config,
)

def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:
Expand Down
2 changes: 2 additions & 0 deletions tests/engine/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from data_designer.config.run_config import RunConfig
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.registry import ModelRegistry
Expand Down Expand Up @@ -36,6 +37,7 @@ def stub_resource_provider(tmp_path, stub_model_facade):
mock_provider.artifact_storage = ArtifactStorage(artifact_path=tmp_path)
mock_provider.blob_storage = Mock(spec=ManagedBlobStorage)
mock_provider.seed_reader = Mock()
mock_provider.run_config = RunConfig()
return mock_provider


Expand Down
53 changes: 53 additions & 0 deletions tests/engine/dataset_builders/test_column_wise_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.processors import DropColumnsProcessorConfig
from data_designer.engine.column_generators.generators.base import GenerationStrategy
from data_designer.engine.dataset_builders.column_wise_builder import (
MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
ColumnWiseDatasetBuilder,
Expand Down Expand Up @@ -306,3 +307,55 @@ def test_emit_batch_inference_events_handles_multiple_models(
events = [call[0][0] for call in mock_handler_instance.enqueue.call_args_list]
model_names = {e.model for e in events}
assert model_names == {"model-a", "model-b"}


@pytest.mark.parametrize(
    "disable_early_shutdown,configured_rate,expected_rate,shutdown_error_window",
    [
        (False, 0.7, 0.7, 20),  # early shutdown active: configured rate passes through
        (True, 0.7, 1.0, 20),  # early shutdown disabled: rate is normalized to 1.0
        (False, 0.5, 0.5, 10),  # same values as the RunConfig defaults
    ],
)
@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor")
def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provider(
    mock_executor_class: Mock,
    stub_resource_provider: Mock,
    stub_test_column_configs: list,
    stub_test_processor_configs: list,
    disable_early_shutdown: bool,
    configured_rate: float,
    expected_rate: float,
    shutdown_error_window: int,
) -> None:
    """Check that _fan_out_with_threads hands the provider's run settings to the executor."""
    from data_designer.config.run_config import RunConfig

    # Install the parametrized run config on the provider before the builder is created.
    run_config = RunConfig(
        disable_early_shutdown=disable_early_shutdown,
        shutdown_error_rate=configured_rate,
        shutdown_error_window=shutdown_error_window,
    )
    stub_resource_provider.run_config = run_config

    builder = ColumnWiseDatasetBuilder(
        column_configs=stub_test_column_configs,
        processor_configs=stub_test_processor_configs,
        resource_provider=stub_resource_provider,
    )

    # Make the patched executor usable in a ``with`` statement.
    executor_instance = mock_executor_class.return_value
    executor_instance.__enter__ = Mock(return_value=Mock())
    executor_instance.__exit__ = Mock(return_value=False)

    mock_generator = Mock()
    mock_generator.generation_strategy = GenerationStrategy.CELL_BY_CELL
    mock_generator.config.name = "test"
    mock_generator.config.column_type = "llm_text"

    # An empty batch: only the executor's constructor arguments are under test.
    batch_manager = Mock()
    batch_manager.iter_current_batch.return_value = []
    builder.batch_manager = batch_manager

    builder._fan_out_with_threads(mock_generator, max_workers=4)

    executor_kwargs = mock_executor_class.call_args.kwargs
    assert executor_kwargs["shutdown_error_rate"] == expected_rate
    assert executor_kwargs["shutdown_error_window"] == shutdown_error_window
58 changes: 58 additions & 0 deletions tests/interface/test_data_designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from data_designer.config.errors import InvalidConfigError
from data_designer.config.models import ModelProvider
from data_designer.config.processors import DropColumnsProcessorConfig
from data_designer.config.run_config import RunConfig
from data_designer.config.sampler_params import CategorySamplerParams, SamplerType
from data_designer.config.seed_source import HuggingFaceSeedSource
from data_designer.engine.secret_resolver import CompositeResolver, EnvironmentResolver, PlaintextResolver
Expand Down Expand Up @@ -103,6 +104,63 @@ def test_set_buffer_size_raises_error_for_invalid_buffer_size(stub_artifact_path
data_designer.set_buffer_size(0)


def test_run_config_setting_persists(stub_artifact_path, stub_model_providers):
    """Check that each set_run_config call replaces the stored run config."""
    data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)

    # A fresh instance starts out with the RunConfig defaults.
    initial = data_designer._run_config
    assert initial.disable_early_shutdown is False
    assert initial.shutdown_error_rate == 0.5
    assert initial.shutdown_error_window == 10

    # First override: early shutdown disabled, so the rate is normalized.
    disabled_config = RunConfig(
        disable_early_shutdown=True,
        shutdown_error_rate=0.8,
        shutdown_error_window=25,
    )
    data_designer.set_run_config(disabled_config)
    after_first = data_designer._run_config
    assert after_first.disable_early_shutdown is True
    assert after_first.shutdown_error_rate == 1.0  # 0.8 was replaced because shutdown is disabled
    assert after_first.shutdown_error_window == 25

    # Second override: early shutdown enabled again with a custom threshold.
    enabled_config = RunConfig(
        disable_early_shutdown=False,
        shutdown_error_rate=0.3,
        shutdown_error_window=5,
    )
    data_designer.set_run_config(enabled_config)
    after_second = data_designer._run_config
    assert after_second.disable_early_shutdown is False
    assert after_second.shutdown_error_rate == 0.3
    assert after_second.shutdown_error_window == 5


def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub_model_providers):
    """Check that shutdown_error_rate becomes 1.0 only when early shutdown is disabled."""
    data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)

    # The same configured rate is used both times; only the flag changes.
    # Enabled keeps the configured value, disabled normalizes it to 1.0.
    scenarios = (
        (False, 0.7),
        (True, 1.0),
    )
    for disable_early_shutdown, expected_rate in scenarios:
        data_designer.set_run_config(
            RunConfig(
                disable_early_shutdown=disable_early_shutdown,
                shutdown_error_rate=0.7,
            )
        )
        assert data_designer._run_config.shutdown_error_rate == expected_rate


def test_create_dataset_e2e_using_only_sampler_columns(
stub_sampler_only_config_builder, stub_artifact_path, stub_model_providers, stub_managed_assets_path
):
Expand Down