Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,19 @@ def __init__(
processor_configs: list[ProcessorConfig],
resource_provider: ResourceProvider,
registry: DataDesignerRegistry | None = None,
enable_early_shutdown: bool = True,
shutdown_error_rate: float = 0.5,
shutdown_error_window: int = 10,
):
self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
self._resource_provider = resource_provider
self._records_to_drop: set[int] = set()
self._registry = registry or DataDesignerRegistry()
self._column_configs = column_configs
self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
self._enable_early_shutdown = enable_early_shutdown
self._shutdown_error_rate = shutdown_error_rate
self._shutdown_error_window = shutdown_error_window
self._validate_column_configs()

@property
Expand Down Expand Up @@ -222,6 +228,8 @@ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int
column_name=generator.config.name,
result_callback=self._worker_result_callback,
error_callback=self._worker_error_callback,
shutdown_error_rate=self._shutdown_error_rate if self._enable_early_shutdown else 1.0,
shutdown_error_window=self._shutdown_error_window,
) as executor:
for i, record in self.batch_manager.iter_current_batch():
executor.submit(lambda record: generator.generate(record), record, context={"index": i})
Expand Down
30 changes: 28 additions & 2 deletions src/data_designer/interface/data_designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def create(
*,
num_records: int = DEFAULT_NUM_RECORDS,
dataset_name: str = "dataset",
enable_early_shutdown: bool = True,
shutdown_error_rate: float = 0.5,
shutdown_error_window: int = 10,
) -> DatasetCreationResults:
"""Create dataset and save results to the local artifact storage.

Expand All @@ -171,6 +174,15 @@ def create(
a datetime stamp. For example, if the dataset name is "awesome_dataset" and a directory
with the same name already exists, the dataset will be saved to a new directory
with the name "awesome_dataset_2025-01-01_12-00-00".
enable_early_shutdown: If True (default), dataset generation will terminate
early if the error rate exceeds `shutdown_error_rate` after
`shutdown_error_window` tasks complete. Set to False to disable
early shutdown entirely (ignores `shutdown_error_rate` and
`shutdown_error_window`).
shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early
shutdown. Only used when `enable_early_shutdown=True`. Default is 0.5 (50%).
shutdown_error_window: Minimum number of completed tasks before error rate
monitoring begins. Only used when `enable_early_shutdown=True`. Default is 10.

Returns:
DatasetCreationResults object with methods for loading the generated dataset,
Expand All @@ -184,7 +196,13 @@ def create(

resource_provider = self._create_resource_provider(dataset_name, config_builder)

builder = self._create_dataset_builder(config_builder, resource_provider)
builder = self._create_dataset_builder(
config_builder,
resource_provider,
enable_early_shutdown=enable_early_shutdown,
shutdown_error_rate=shutdown_error_rate,
shutdown_error_window=shutdown_error_window,
)

try:
builder.build(num_records=num_records, buffer_size=self._buffer_size)
Expand Down Expand Up @@ -334,12 +352,20 @@ def _resolve_model_providers(self, model_providers: list[ModelProvider] | None)
return model_providers or []

def _create_dataset_builder(
    self,
    config_builder: DataDesignerConfigBuilder,
    resource_provider: ResourceProvider,
    enable_early_shutdown: bool = True,
    shutdown_error_rate: float = 0.5,
    shutdown_error_window: int = 10,
) -> ColumnWiseDatasetBuilder:
    """Construct the column-wise dataset builder for a configuration.

    The early-shutdown settings are not interpreted here; they are passed
    through unchanged to ``ColumnWiseDatasetBuilder``.

    Args:
        config_builder: Source of the column configs (via ``build``) and the
            processor configs (via ``get_processor_configs``).
        resource_provider: Resource provider handed to the dataset builder.
        enable_early_shutdown: Whether the builder may stop generation early
            on a high error rate.
        shutdown_error_rate: Error rate threshold forwarded to the builder.
        shutdown_error_window: Completed-task window forwarded to the builder.

    Returns:
        A ``ColumnWiseDatasetBuilder`` configured from the arguments above.
    """
    # Build the config first, then compile it into builder column configs;
    # processor configs are read afterwards, matching the original call order.
    built_config = config_builder.build(raise_exceptions=True)
    column_configs = compile_dataset_builder_column_configs(built_config)

    return ColumnWiseDatasetBuilder(
        column_configs=column_configs,
        processor_configs=config_builder.get_processor_configs(),
        resource_provider=resource_provider,
        enable_early_shutdown=enable_early_shutdown,
        shutdown_error_rate=shutdown_error_rate,
        shutdown_error_window=shutdown_error_window,
    )

def _create_dataset_profiler(
Expand Down
Loading