Skip to content

Commit e2f4c44

Browse files
committed
Change to a config type and helper settings
1 parent 69109b5 commit e2f4c44

File tree

9 files changed

+140
-73
lines changed

9 files changed

+140
-73
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Self
5+
6+
from pydantic import Field, model_validator
7+
8+
from data_designer.config.base import ConfigBase
9+
10+
11+
class RunConfig(ConfigBase):
12+
"""Runtime configuration for dataset generation.
13+
14+
Groups configuration options that control generation behavior but aren't
15+
part of the dataset configuration itself.
16+
17+
Attributes:
18+
disable_early_shutdown: If True, disables early shutdown entirely. Generation
19+
will continue regardless of error rate. Default is False.
20+
shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown.
21+
When disable_early_shutdown is True, any supplied value is overridden and normalized to 1.0. Default is 0.5.
22+
shutdown_error_window: Minimum number of completed tasks before error rate
23+
monitoring begins. Must be >= 0. Default is 10.
24+
"""
25+
26+
disable_early_shutdown: bool = False
27+
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
28+
shutdown_error_window: int = Field(default=10, ge=0)
29+
30+
@model_validator(mode="after")
31+
def normalize_shutdown_settings(self) -> Self:
32+
"""Set shutdown_error_rate to 1.0 when early shutdown is disabled."""
33+
if self.disable_early_shutdown:
34+
self.shutdown_error_rate = 1.0
35+
return self

src/data_designer/engine/column_generators/generators/validation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,14 @@ def result_callback(result: ValidationResult, context: dict):
123123
def error_callback(error: Exception, context: dict):
124124
outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))
125125

126+
settings = self.resource_provider.run_config
126127
with ConcurrentThreadExecutor(
127128
max_workers=self.config.validator_params.max_parallel_requests,
128129
column_name=self.config.name,
129130
result_callback=result_callback,
130131
error_callback=error_callback,
132+
shutdown_error_rate=settings.shutdown_error_rate,
133+
shutdown_error_window=settings.shutdown_error_window,
131134
) as executor:
132135
for i, batch in enumerate(batched_records):
133136
executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})

src/data_designer/engine/dataset_builders/column_wise_builder.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,13 @@ def __init__(
6060
processor_configs: list[ProcessorConfig],
6161
resource_provider: ResourceProvider,
6262
registry: DataDesignerRegistry | None = None,
63-
enable_early_shutdown: bool = True,
64-
shutdown_error_rate: float = 0.5,
65-
shutdown_error_window: int = 10,
6663
):
6764
self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
6865
self._resource_provider = resource_provider
6966
self._records_to_drop: set[int] = set()
7067
self._registry = registry or DataDesignerRegistry()
7168
self._column_configs = column_configs
7269
self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
73-
self._enable_early_shutdown = enable_early_shutdown
74-
self._shutdown_error_rate = shutdown_error_rate
75-
self._shutdown_error_window = shutdown_error_window
7670
self._validate_column_configs()
7771

7872
@property
@@ -223,13 +217,14 @@ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int
223217
f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
224218
f"with {max_workers} concurrent workers"
225219
)
220+
settings = self._resource_provider.run_config
226221
with ConcurrentThreadExecutor(
227222
max_workers=max_workers,
228223
column_name=generator.config.name,
229224
result_callback=self._worker_result_callback,
230225
error_callback=self._worker_error_callback,
231-
shutdown_error_rate=self._shutdown_error_rate if self._enable_early_shutdown else 1.0,
232-
shutdown_error_window=self._shutdown_error_window,
226+
shutdown_error_rate=settings.shutdown_error_rate,
227+
shutdown_error_window=settings.shutdown_error_window,
233228
) as executor:
234229
for i, record in self.batch_manager.iter_current_batch():
235230
executor.submit(lambda record: generator.generate(record), record, context={"index": i})

src/data_designer/engine/resources/resource_provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from data_designer.config.base import ConfigBase
55
from data_designer.config.models import ModelConfig
6+
from data_designer.config.run_settings import RunConfig
67
from data_designer.config.utils.type_helpers import StrEnum
78
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
89
from data_designer.engine.model_provider import ModelProviderRegistry
@@ -23,6 +24,7 @@ class ResourceProvider(ConfigBase):
2324
blob_storage: ManagedBlobStorage | None = None
2425
datastore: SeedDatasetDataStore | None = None
2526
model_registry: ModelRegistry | None = None
27+
run_config: RunConfig = RunConfig()
2628

2729

2830
def create_resource_provider(

src/data_designer/essentials/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
55
from data_designer.config.exports import * # noqa: F403
6+
from data_designer.config.run_settings import RunConfig
67
from data_designer.config.validator_params import LocalCallableValidatorParams
78
from data_designer.interface.data_designer import DataDesigner
89
from data_designer.logging import LoggingConfig, configure_logging
@@ -21,6 +22,7 @@ def get_essentials_exports() -> list[str]:
2122
local = [
2223
DataDesigner.__name__,
2324
LocalCallableValidatorParams.__name__,
25+
RunConfig.__name__,
2426
]
2527

2628
return logging + local + get_config_exports() # noqa: F405

src/data_designer/interface/data_designer.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ModelProvider,
2121
)
2222
from data_designer.config.preview_results import PreviewResults
23+
from data_designer.config.run_settings import RunConfig
2324
from data_designer.config.seed import LocalSeedDatasetReference
2425
from data_designer.config.utils.constants import (
2526
DEFAULT_NUM_RECORDS,
@@ -97,6 +98,7 @@ def __init__(
9798
self._secret_resolver = secret_resolver or CompositeResolver([EnvironmentResolver(), PlaintextResolver()])
9899
self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
99100
self._buffer_size = DEFAULT_BUFFER_SIZE
101+
self._run_config = RunConfig()
100102
self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
101103
self._model_providers = self._resolve_model_providers(model_providers)
102104
self._model_provider_registry = resolve_model_provider_registry(
@@ -154,9 +156,6 @@ def create(
154156
*,
155157
num_records: int = DEFAULT_NUM_RECORDS,
156158
dataset_name: str = "dataset",
157-
enable_early_shutdown: bool = True,
158-
shutdown_error_rate: float = 0.5,
159-
shutdown_error_window: int = 10,
160159
) -> DatasetCreationResults:
161160
"""Create dataset and save results to the local artifact storage.
162161
@@ -174,15 +173,6 @@ def create(
174173
a datetime stamp. For example, if the dataset name is "awesome_dataset" and a directory
175174
with the same name already exists, the dataset will be saved to a new directory
176175
with the name "awesome_dataset_2025-01-01_12-00-00".
177-
enable_early_shutdown: If True (default), dataset generation will terminate
178-
early if the error rate exceeds `shutdown_error_rate` after
179-
`shutdown_error_window` tasks complete. Set to False to disable
180-
early shutdown entirely (ignores `shutdown_error_rate` and
181-
`shutdown_error_window`).
182-
shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early
183-
shutdown. Only used when `enable_early_shutdown=True`. Default is 0.5 (50%).
184-
shutdown_error_window: Minimum number of completed tasks before error rate
185-
monitoring begins. Only used when `enable_early_shutdown=True`. Default is 10.
186176
187177
Returns:
188178
DatasetCreationResults object with methods for loading the generated dataset,
@@ -196,13 +186,7 @@ def create(
196186

197187
resource_provider = self._create_resource_provider(dataset_name, config_builder)
198188

199-
builder = self._create_dataset_builder(
200-
config_builder,
201-
resource_provider,
202-
enable_early_shutdown=enable_early_shutdown,
203-
shutdown_error_rate=shutdown_error_rate,
204-
shutdown_error_window=shutdown_error_window,
205-
)
189+
builder = self._create_dataset_builder(config_builder, resource_provider)
206190

207191
try:
208192
builder.build(num_records=num_records, buffer_size=self._buffer_size)
@@ -336,6 +320,20 @@ def set_buffer_size(self, buffer_size: int) -> None:
336320
raise InvalidBufferValueError("Buffer size must be greater than 0.")
337321
self._buffer_size = buffer_size
338322

323+
def set_run_config(self, run_config: RunConfig) -> None:
324+
"""Set the runtime configuration for dataset generation.
325+
326+
Args:
327+
run_config: A RunConfig instance containing runtime settings such as
328+
early shutdown behavior. Import RunConfig from data_designer.essentials.
329+
330+
Example:
331+
>>> from data_designer.essentials import DataDesigner, RunConfig
332+
>>> dd = DataDesigner()
333+
>>> dd.set_run_config(RunConfig(disable_early_shutdown=True))
334+
"""
335+
self._run_config = run_config
336+
339337
def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
340338
if model_providers is None:
341339
model_providers = get_default_providers()
@@ -355,17 +353,11 @@ def _create_dataset_builder(
355353
self,
356354
config_builder: DataDesignerConfigBuilder,
357355
resource_provider: ResourceProvider,
358-
enable_early_shutdown: bool = True,
359-
shutdown_error_rate: float = 0.5,
360-
shutdown_error_window: int = 10,
361356
) -> ColumnWiseDatasetBuilder:
362357
return ColumnWiseDatasetBuilder(
363358
column_configs=compile_dataset_builder_column_configs(config_builder.build(raise_exceptions=True)),
364359
processor_configs=config_builder.get_processor_configs(),
365360
resource_provider=resource_provider,
366-
enable_early_shutdown=enable_early_shutdown,
367-
shutdown_error_rate=shutdown_error_rate,
368-
shutdown_error_window=shutdown_error_window,
369361
)
370362

371363
def _create_dataset_profiler(
@@ -400,6 +392,7 @@ def _create_resource_provider(
400392
token=settings.token,
401393
)
402394
),
395+
run_config=self._run_config,
403396
)
404397

405398
def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:

tests/engine/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88

9+
from data_designer.config.run_settings import RunConfig
910
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
1011
from data_designer.engine.models.facade import ModelFacade
1112
from data_designer.engine.models.registry import ModelRegistry
@@ -36,6 +37,7 @@ def stub_resource_provider(tmp_path, stub_model_facade):
3637
mock_provider.artifact_storage = ArtifactStorage(artifact_path=tmp_path)
3738
mock_provider.blob_storage = Mock(spec=ManagedBlobStorage)
3839
mock_provider.datastore = Mock()
40+
mock_provider.run_config = RunConfig()
3941
return mock_provider
4042

4143

tests/engine/dataset_builders/test_column_wise_builder.py

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -310,60 +310,37 @@ def test_emit_batch_inference_events_handles_multiple_models(
310310

311311

312312
@pytest.mark.parametrize(
313-
"enable_early_shutdown,shutdown_error_rate,shutdown_error_window",
313+
"disable_early_shutdown,configured_rate,expected_rate,shutdown_error_window",
314314
[
315-
(True, 0.5, 10), # defaults
316-
(False, 0.8, 25), # custom values
317-
],
318-
)
319-
def test_column_wise_dataset_builder_stores_early_shutdown_params(
320-
stub_resource_provider: Mock,
321-
stub_test_column_configs: list,
322-
stub_test_processor_configs: list,
323-
enable_early_shutdown: bool,
324-
shutdown_error_rate: float,
325-
shutdown_error_window: int,
326-
) -> None:
327-
"""Test that ColumnWiseDatasetBuilder stores early shutdown parameters."""
328-
builder = ColumnWiseDatasetBuilder(
329-
column_configs=stub_test_column_configs,
330-
processor_configs=stub_test_processor_configs,
331-
resource_provider=stub_resource_provider,
332-
enable_early_shutdown=enable_early_shutdown,
333-
shutdown_error_rate=shutdown_error_rate,
334-
shutdown_error_window=shutdown_error_window,
335-
)
336-
337-
assert builder._enable_early_shutdown is enable_early_shutdown
338-
assert builder._shutdown_error_rate == shutdown_error_rate
339-
assert builder._shutdown_error_window == shutdown_error_window
340-
341-
342-
@pytest.mark.parametrize(
343-
"enable_early_shutdown,configured_rate,expected_rate",
344-
[
345-
(True, 0.7, 0.7), # enabled: use configured rate
346-
(False, 0.7, 1.0), # disabled: use 1.0 to effectively disable
315+
(False, 0.7, 0.7, 20), # early shutdown enabled (disable flag False): use configured rate
316+
(True, 0.7, 1.0, 20), # early shutdown disabled (disable flag True): rate normalized to 1.0
317+
(False, 0.5, 0.5, 10), # defaults
347318
],
348319
)
349320
@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor")
350-
def test_fan_out_with_threads_respects_enable_early_shutdown_flag(
321+
def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provider(
351322
mock_executor_class: Mock,
352323
stub_resource_provider: Mock,
353324
stub_test_column_configs: list,
354325
stub_test_processor_configs: list,
355-
enable_early_shutdown: bool,
326+
disable_early_shutdown: bool,
356327
configured_rate: float,
357328
expected_rate: float,
329+
shutdown_error_window: int,
358330
) -> None:
359-
"""Test that _fan_out_with_threads passes correct shutdown_error_rate based on enable flag."""
331+
"""Test that _fan_out_with_threads uses run settings from resource_provider."""
332+
from data_designer.config.run_settings import RunConfig
333+
334+
stub_resource_provider.run_config = RunConfig(
335+
disable_early_shutdown=disable_early_shutdown,
336+
shutdown_error_rate=configured_rate,
337+
shutdown_error_window=shutdown_error_window,
338+
)
339+
360340
builder = ColumnWiseDatasetBuilder(
361341
column_configs=stub_test_column_configs,
362342
processor_configs=stub_test_processor_configs,
363343
resource_provider=stub_resource_provider,
364-
enable_early_shutdown=enable_early_shutdown,
365-
shutdown_error_rate=configured_rate,
366-
shutdown_error_window=20,
367344
)
368345

369346
mock_executor_class.return_value.__enter__ = Mock(return_value=Mock())
@@ -381,4 +358,4 @@ def test_fan_out_with_threads_respects_enable_early_shutdown_flag(
381358

382359
call_kwargs = mock_executor_class.call_args[1]
383360
assert call_kwargs["shutdown_error_rate"] == expected_rate
384-
assert call_kwargs["shutdown_error_window"] == 20
361+
assert call_kwargs["shutdown_error_window"] == shutdown_error_window

tests/interface/test_data_designer.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from data_designer.config.dataset_builders import BuildStage
1414
from data_designer.config.errors import InvalidFileFormatError
1515
from data_designer.config.processors import DropColumnsProcessorConfig
16+
from data_designer.config.run_settings import RunConfig
1617
from data_designer.config.seed import LocalSeedDatasetReference
1718
from data_designer.engine.model_provider import ModelProvider
1819
from data_designer.engine.secret_resolver import CompositeResolver, EnvironmentResolver, PlaintextResolver
@@ -268,6 +269,63 @@ def test_set_buffer_size_raises_error_for_invalid_buffer_size(stub_artifact_path
268269
data_designer.set_buffer_size(0)
269270

270271

272+
def test_run_config_setting_persists(stub_artifact_path, stub_model_providers):
273+
"""Test that run config setting persists across multiple calls."""
274+
data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)
275+
276+
# Test default values
277+
assert data_designer._run_config.disable_early_shutdown is False
278+
assert data_designer._run_config.shutdown_error_rate == 0.5
279+
assert data_designer._run_config.shutdown_error_window == 10
280+
281+
# Test setting custom values (note: shutdown_error_rate is normalized to 1.0 when disabled)
282+
data_designer.set_run_config(
283+
RunConfig(
284+
disable_early_shutdown=True,
285+
shutdown_error_rate=0.8,
286+
shutdown_error_window=25,
287+
)
288+
)
289+
assert data_designer._run_config.disable_early_shutdown is True
290+
assert data_designer._run_config.shutdown_error_rate == 1.0 # normalized when disabled
291+
assert data_designer._run_config.shutdown_error_window == 25
292+
293+
# Test updating values
294+
data_designer.set_run_config(
295+
RunConfig(
296+
disable_early_shutdown=False,
297+
shutdown_error_rate=0.3,
298+
shutdown_error_window=5,
299+
)
300+
)
301+
assert data_designer._run_config.disable_early_shutdown is False
302+
assert data_designer._run_config.shutdown_error_rate == 0.3
303+
assert data_designer._run_config.shutdown_error_window == 5
304+
305+
306+
def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub_model_providers):
307+
"""Test that shutdown_error_rate is normalized to 1.0 when disabled."""
308+
data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)
309+
310+
# When early shutdown is enabled (the default), shutdown_error_rate should use the configured value
311+
data_designer.set_run_config(
312+
RunConfig(
313+
disable_early_shutdown=False,
314+
shutdown_error_rate=0.7,
315+
)
316+
)
317+
assert data_designer._run_config.shutdown_error_rate == 0.7
318+
319+
# When early shutdown is disabled, shutdown_error_rate should be normalized to 1.0
320+
data_designer.set_run_config(
321+
RunConfig(
322+
disable_early_shutdown=True,
323+
shutdown_error_rate=0.7,
324+
)
325+
)
326+
assert data_designer._run_config.shutdown_error_rate == 1.0
327+
328+
271329
def test_multiple_seed_references_can_be_created():
272330
"""Test that multiple seed references can be created from different sources."""
273331
with tempfile.TemporaryDirectory() as temp_dir:

0 commit comments

Comments
 (0)