Skip to content

Commit 1b18b73

Browse files
committed
move classification functions to engine; remove required resources
1 parent 55453da commit 1b18b73

File tree

23 files changed

+61
-177
lines changed

23 files changed

+61
-177
lines changed

docs/plugins/example.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
8181
name="index-multiplier",
8282
description="Generates values by multiplying the row index by a user-specified multiplier",
8383
generation_strategy=GenerationStrategy.FULL_COLUMN,
84-
required_resources=None,
8584
)
8685

8786
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
@@ -110,7 +109,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
110109
- `metadata()` describes your generator and its requirements
111110
- `generation_strategy` can be `FULL_COLUMN`, `CELL_BY_CELL`
112111
- You have access to the configuration parameters via `self.config`
113-
- `required_resources` lists any required resources (models, artifact storages, etc.). This parameter will evolve in the near future, so keeping it as `None` is safe for now. That said, if your task will use the model registry, adding `data_designer.engine.resources.ResourceType.MODEL_REGISTRY` will enable automatic model health checking for your column generation task.
114112

115113
!!! info "Understanding generation_strategy"
116114
The `generation_strategy` specifies how the column generator will generate data.
@@ -179,7 +177,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
179177
name="index-multiplier",
180178
description="Generates values by multiplying the row index by a user-specified multiplier",
181179
generation_strategy=GenerationStrategy.FULL_COLUMN,
182-
required_resources=None,
183180
)
184181

185182
def generate(self, data: pd.DataFrame) -> pd.DataFrame:

src/data_designer/config/column_types.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,6 @@
6262
)
6363

6464

65-
def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
66-
"""Return True if the column type is used in the workflow execution DAG."""
67-
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
68-
dag_column_types = {
69-
DataDesignerColumnType.EXPRESSION,
70-
DataDesignerColumnType.LLM_CODE,
71-
DataDesignerColumnType.LLM_JUDGE,
72-
DataDesignerColumnType.LLM_STRUCTURED,
73-
DataDesignerColumnType.LLM_TEXT,
74-
DataDesignerColumnType.VALIDATION,
75-
DataDesignerColumnType.EMBEDDING,
76-
}
77-
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
78-
return column_type in dag_column_types
79-
80-
81-
def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
82-
"""Return True if the column type is a model-generated column."""
83-
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
84-
model_generated_column_types = {
85-
DataDesignerColumnType.LLM_TEXT,
86-
DataDesignerColumnType.LLM_CODE,
87-
DataDesignerColumnType.LLM_STRUCTURED,
88-
DataDesignerColumnType.LLM_JUDGE,
89-
DataDesignerColumnType.EMBEDDING,
90-
}
91-
model_generated_column_types.update(
92-
plugin_manager.get_plugin_column_types(
93-
DataDesignerColumnType,
94-
required_resources=["model_registry"],
95-
)
96-
)
97-
return column_type in model_generated_column_types
98-
99-
10065
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
10166
"""Create a Data Designer column config object from kwargs.
10267

src/data_designer/config/config_builder.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from data_designer.config.column_types import (
1919
ColumnConfigT,
2020
DataDesignerColumnType,
21-
column_type_is_model_generated,
2221
get_column_config_from_kwargs,
2322
get_column_display_order,
2423
)
@@ -422,23 +421,6 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]:
422421
"""
423422
return [c for c in self._constraints if c.target_column == target_column]
424423

425-
def get_llm_gen_columns(self) -> list[ColumnConfigT]:
426-
"""Get all model-generated column configurations.
427-
428-
Returns:
429-
A list of column configurations that use model generation.
430-
"""
431-
logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.")
432-
return self.get_model_gen_columns()
433-
434-
def get_model_gen_columns(self) -> list[ColumnConfigT]:
435-
"""Get all model-generated column configurations.
436-
437-
Returns:
438-
A list of column configurations that use model generation.
439-
"""
440-
return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)]
441-
442424
def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
443425
"""Get all column configurations of the specified type.
444426

src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,11 @@
3131
)
3232
from data_designer.engine.models.facade import ModelFacade
3333
from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe
34-
from data_designer.engine.resources.resource_provider import ResourceType
3534

3635
logger = logging.getLogger(__name__)
3736

3837

3938
class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
40-
@staticmethod
41-
def get_required_resources() -> list[ResourceType]:
42-
return [ResourceType.MODEL_REGISTRY]
43-
4439
@staticmethod
4540
def metadata() -> ColumnProfilerMetadata:
4641
return ColumnProfilerMetadata(

src/data_designer/engine/column_generators/generators/base.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import pandas as pd
1313

1414
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
15-
from data_designer.engine.resources.resource_provider import ResourceType
1615

1716
if TYPE_CHECKING:
1817
from data_designer.config.models import BaseInferenceParams, ModelConfig
@@ -56,10 +55,6 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...
5655
@abstractmethod
5756
def generate(self, data: DataT) -> DataT: ...
5857

59-
@staticmethod
60-
def get_required_resources() -> list[ResourceType]:
61-
return []
62-
6358
def log_pre_generation(self) -> None:
6459
"""A shared method to log info before the generator's `generate` method is called.
6560
@@ -79,17 +74,10 @@ def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
7974

8075

8176
class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
82-
@staticmethod
83-
def get_required_resources() -> list[ResourceType]:
84-
return [ResourceType.MODEL_REGISTRY]
85-
8677
@property
8778
def model_registry(self) -> ModelRegistry:
8879
return self.resource_provider.model_registry
8980

90-
def get_inference_parameters(self, model_alias: str) -> BaseInferenceParams:
91-
return self.get_model_config(model_alias=model_alias).inference_parameters
92-
9381
def get_model(self, model_alias: str) -> ModelFacade:
9482
return self.model_registry.get_model(model_alias=model_alias)
9583

@@ -112,7 +100,7 @@ def model_config(self) -> ModelConfig:
112100

113101
@functools.cached_property
114102
def inference_parameters(self) -> BaseInferenceParams:
115-
return self.get_inference_parameters(model_alias=self.config.model_alias)
103+
return self.model_config.inference_parameters
116104

117105
def log_pre_generation(self) -> None:
118106
logger.info(f"Preparing {self.config.column_type} column generation")

src/data_designer/engine/column_generators/generators/samplers.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import random
88
from functools import partial
9-
from typing import TYPE_CHECKING, Callable
9+
from typing import Callable
1010

1111
import pandas as pd
1212

@@ -23,17 +23,10 @@
2323
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
2424
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
2525

26-
if TYPE_CHECKING:
27-
from data_designer.engine.resources.resource_provider import ResourceType
28-
2926
logger = logging.getLogger(__name__)
3027

3128

3229
class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
33-
@staticmethod
34-
def get_required_resources() -> list[ResourceType]:
35-
return [ResourceType.BLOB_STORAGE]
36-
3730
@staticmethod
3831
def metadata() -> GeneratorMetadata:
3932
return GeneratorMetadata(

src/data_designer/engine/column_generators/generators/seed_dataset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,13 @@
1919
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
2020
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
2121
from data_designer.engine.processing.utils import concat_datasets
22-
from data_designer.engine.resources.resource_provider import ResourceType
2322

2423
MAX_ZERO_RECORD_RESPONSE_FACTOR = 2
2524

2625
logger = logging.getLogger(__name__)
2726

2827

2928
class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColumnConfig]):
30-
@staticmethod
31-
def get_required_resources() -> list[ResourceType]:
32-
return [ResourceType.DATASTORE]
33-
3429
@staticmethod
3530
def metadata() -> GeneratorMetadata:
3631
return GeneratorMetadata(
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from data_designer.config.column_types import DataDesignerColumnType
6+
from data_designer.config.utils.type_helpers import resolve_string_enum
7+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
8+
from data_designer.plugin_manager import PluginManager
9+
10+
plugin_manager = PluginManager()
11+
12+
13+
def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
14+
"""Return True if the column type is used in the workflow execution DAG."""
15+
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
16+
dag_column_types = {
17+
DataDesignerColumnType.EXPRESSION,
18+
DataDesignerColumnType.LLM_CODE,
19+
DataDesignerColumnType.LLM_JUDGE,
20+
DataDesignerColumnType.LLM_STRUCTURED,
21+
DataDesignerColumnType.LLM_TEXT,
22+
DataDesignerColumnType.VALIDATION,
23+
DataDesignerColumnType.EMBEDDING,
24+
}
25+
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
26+
return column_type in dag_column_types
27+
28+
29+
def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
30+
"""Return True if the column type is a model-generated column."""
31+
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
32+
model_generated_column_types = {
33+
DataDesignerColumnType.LLM_TEXT,
34+
DataDesignerColumnType.LLM_CODE,
35+
DataDesignerColumnType.LLM_STRUCTURED,
36+
DataDesignerColumnType.LLM_JUDGE,
37+
DataDesignerColumnType.EMBEDDING,
38+
}
39+
for plugin in plugin_manager.get_column_generator_plugins():
40+
if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry):
41+
model_generated_column_types.add(plugin.name)
42+
return column_type in model_generated_column_types

src/data_designer/engine/configurable_task.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from data_designer.config.base import ConfigBase
1111
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
12-
from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType
12+
from data_designer.engine.resources.resource_provider import ResourceProvider
1313

1414
DataT = TypeVar("DataT", dict, pd.DataFrame)
1515
TaskConfigT = TypeVar("ConfigT", bound=ConfigBase)
@@ -65,10 +65,6 @@ def resource_provider(self) -> ResourceProvider:
6565
@abstractmethod
6666
def metadata() -> ConfigurableTaskMetadata: ...
6767

68-
@staticmethod
69-
@abstractmethod
70-
def get_required_resources() -> list[ResourceType]: ...
71-
7268
def _initialize(self) -> None:
7369
"""An internal method for custom initialization logic, which will be called in the constructor."""
7470

src/data_designer/engine/dataset_builders/column_wise_builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import pandas as pd
1515

16-
from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
16+
from data_designer.config.column_types import ColumnConfigT
1717
from data_designer.config.dataset_builders import BuildStage
1818
from data_designer.config.processors import (
1919
DropColumnsProcessorConfig,
@@ -25,6 +25,7 @@
2525
ColumnGeneratorWithSingleModel,
2626
GenerationStrategy,
2727
)
28+
from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
2829
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
2930
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
3031
from data_designer.engine.dataset_builders.multi_column_configs import (
@@ -42,7 +43,7 @@
4243
from data_designer.engine.processing.processors.base import Processor
4344
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
4445
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
45-
from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType
46+
from data_designer.engine.resources.resource_provider import ResourceProvider
4647

4748
if TYPE_CHECKING:
4849
from data_designer.engine.models.usage import ModelUsageStats
@@ -192,7 +193,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None
192193

193194
def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
194195
max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
195-
if ResourceType.MODEL_REGISTRY in generator.get_required_resources():
196+
if isinstance(generator, ColumnGeneratorWithSingleModel):
196197
max_workers = generator.inference_parameters.max_parallel_requests
197198
self._fan_out_with_threads(generator, max_workers=max_workers)
198199

0 commit comments

Comments
 (0)