diff --git a/README.md b/README.md index 21610ba1..74f56bb5 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ ModelConfig( top_p=0.9, max_tokens=4096, ), - ) +) ``` The value `openai/gpt-oss-20b` would be collected. diff --git a/docs/plugins/example.md b/docs/plugins/example.md index cb398f88..4d52f19f 100644 --- a/docs/plugins/example.md +++ b/docs/plugins/example.md @@ -81,7 +81,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig name="index-multiplier", description="Generates values by multiplying the row index by a user-specified multiplier", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=None, ) def generate(self, data: pd.DataFrame) -> pd.DataFrame: @@ -110,7 +109,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig - `metadata()` describes your generator and its requirements - `generation_strategy` can be `FULL_COLUMN`, `CELL_BY_CELL` - You have access to the configuration parameters via `self.config` -- `required_resources` lists any required resources (models, artifact storages, etc.). This parameter will evolve in the near future, so keeping it as `None` is safe for now. That said, if your task will use the model registry, adding `data_designer.engine.resources.ResourceType.MODEL_REGISTRY` will enable automatic model health checking for your column generation task. !!! info "Understanding generation_strategy" The `generation_strategy` specifies how the column generator will generate data. @@ -179,7 +177,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig name="index-multiplier", description="Generates values by multiplying the row index by a user-specified multiplier", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=None, ) def generate(self, data: pd.DataFrame) -> pd.DataFrame: diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index 7583d4ac..93b233c9 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -62,41 +62,6 @@ ) -def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool: - """Return True if the column type is used in the workflow execution DAG.""" - column_type = resolve_string_enum(column_type, DataDesignerColumnType) - dag_column_types = { - DataDesignerColumnType.EXPRESSION, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.VALIDATION, - DataDesignerColumnType.EMBEDDING, - } - dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) - return column_type in dag_column_types - - -def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool: - """Return True if the column type is a model-generated column.""" - column_type = resolve_string_enum(column_type, DataDesignerColumnType) - model_generated_column_types = { - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.EMBEDDING, - } - model_generated_column_types.update( - plugin_manager.get_plugin_column_types( - DataDesignerColumnType, - required_resources=["model_registry"], - ) - ) - return column_type in model_generated_column_types - - def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: """Create a Data Designer column config object from kwargs. diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 1852dbc9..52bdb582 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -18,7 +18,6 @@ from data_designer.config.column_types import ( ColumnConfigT, DataDesignerColumnType, - column_type_is_model_generated, get_column_config_from_kwargs, get_column_display_order, ) @@ -422,23 +421,6 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]: """ return [c for c in self._constraints if c.target_column == target_column] - def get_llm_gen_columns(self) -> list[ColumnConfigT]: - """Get all model-generated column configurations. - - Returns: - A list of column configurations that use model generation. - """ - logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.") - return self.get_model_gen_columns() - - def get_model_gen_columns(self) -> list[ColumnConfigT]: - """Get all model-generated column configurations. - - Returns: - A list of column configurations that use model generation. - """ - return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)] - def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]: """Get all column configurations of the specified type. diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py index 3a411f23..366cdc1f 100644 --- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py +++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py @@ -31,7 +31,6 @@ ) from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe -from data_designer.engine.resources.resource_provider import ResourceType logger = logging.getLogger(__name__) @@ -42,7 +41,6 @@ def metadata() -> ColumnProfilerMetadata: return ColumnProfilerMetadata( name="judge_score_profiler", description="Analyzes LLM-as-judge score distributions in a Data Designer dataset.", - required_resources=[ResourceType.MODEL_REGISTRY], applicable_column_types=[DataDesignerColumnType.LLM_JUDGE], ) diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index fa477b85..c1897263 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from data_designer.config.models import BaseInferenceParams, ModelConfig from data_designer.engine.models.facade import ModelFacade + from data_designer.engine.models.registry import ModelRegistry logger = logging.getLogger(__name__) @@ -72,27 +73,39 @@ def can_generate_from_scratch(self) -> bool: def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... -class WithModelGeneration: +class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): + @property + def model_registry(self) -> ModelRegistry: + return self.resource_provider.model_registry + + def get_model(self, model_alias: str) -> ModelFacade: + return self.model_registry.get_model(model_alias=model_alias) + + def get_model_config(self, model_alias: str) -> ModelConfig: + return self.model_registry.get_model_config(model_alias=model_alias) + + def get_model_provider_name(self, model_alias: str) -> str: + provider = self.model_registry.get_model_provider(model_alias=model_alias) + return provider.name + + +class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC): @functools.cached_property def model(self) -> ModelFacade: - return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + return self.get_model(model_alias=self.config.model_alias) @functools.cached_property def model_config(self) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + return self.get_model_config(model_alias=self.config.model_alias) @functools.cached_property def inference_parameters(self) -> BaseInferenceParams: return self.model_config.inference_parameters def log_pre_generation(self) -> None: - logger.info(f"Preparing {self.config.column_type} column generation") - logger.info(f" |-- column name: {self.config.name!r}") - logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") - if self.model_config.provider is None: - logger.info(f" |-- default model provider: {self._get_provider_name()!r}") - - def _get_provider_name(self) -> str: - model_alias = self.model_config.alias - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return provider.name + logger.info(f"{self.config.column_type} model configuration for generating column '{self.config.name}'") + logger.info(f" |-- model: {self.model_config.model!r}") + logger.info(f" |-- model alias: {self.config.model_alias!r}") + logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}") + logger.info(f" |-- generation type: {self.model_config.generation_type.value!r}") + logger.info(f" |-- inference parameters: {self.inference_parameters.format_for_display()}") diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py index a8623db5..e85bb9f0 100644 --- a/src/data_designer/engine/column_generators/generators/embedding.py +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -6,13 +6,11 @@ from data_designer.config.column_configs import EmbeddingColumnConfig from data_designer.engine.column_generators.generators.base import ( - ColumnGenerator, + ColumnGeneratorWithModel, GenerationStrategy, GeneratorMetadata, - WithModelGeneration, ) from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string -from data_designer.engine.resources.resource_provider import ResourceType class EmbeddingGenerationResult(BaseModel): @@ -27,14 +25,13 @@ def dimension(self) -> int: return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0 -class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]): +class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( name="embedding_cell_generator", description="Generate embeddings for a text column.", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], ) def generate(self, data: dict) -> dict: diff --git a/src/data_designer/engine/column_generators/generators/expression.py b/src/data_designer/engine/column_generators/generators/expression.py index 00ce4771..6b5e40fb 100644 --- a/src/data_designer/engine/column_generators/generators/expression.py +++ b/src/data_designer/engine/column_generators/generators/expression.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging import pandas as pd @@ -25,7 +27,6 @@ def metadata() -> GeneratorMetadata: name="expression_generator", description="Generate a column from a jinja2 expression.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=None, ) def generate(self, data: pd.DataFrame) -> pd.DataFrame: diff --git a/src/data_designer/engine/column_generators/generators/llm_completion.py b/src/data_designer/engine/column_generators/generators/llm_completion.py index ef47ed01..ebd82885 100644 --- a/src/data_designer/engine/column_generators/generators/llm_completion.py +++ b/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -12,19 +12,18 @@ ) from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( - ColumnGenerator, + ColumnGeneratorWithModel, GenerationStrategy, GeneratorMetadata, - WithModelGeneration, ) from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, ) +from data_designer.engine.configurable_task import TaskConfigT from data_designer.engine.models.recipes.base import ResponseRecipe from data_designer.engine.processing.utils import deserialize_json_values -from data_designer.engine.resources.resource_provider import ResourceType logger = logging.getLogger(__name__) @@ -33,7 +32,7 @@ DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 -class WithChatCompletionGeneration(WithModelGeneration): +class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]): @functools.cached_property def response_recipe(self) -> ResponseRecipe: return create_response_recipe(self.config, self.model_config) @@ -92,47 +91,43 @@ def generate(self, data: dict) -> dict: return data -class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]): +class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( name="llm_text_generator", description="Generate a new dataset cell from a prompt template", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], ) -class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]): +class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( name="llm_code_generator", description="Generate a new dataset cell from a prompt template", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], ) -class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]): +class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( name="llm_structured_generator", description="Generate a new dataset cell from a prompt template", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], ) -class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]): +class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( name="llm_judge_generator", description="Judge a new dataset cell based on a set of rubrics", generation_strategy=GenerationStrategy.CELL_BY_CELL, - required_resources=[ResourceType.MODEL_REGISTRY], ) @property diff --git a/src/data_designer/engine/column_generators/generators/samplers.py b/src/data_designer/engine/column_generators/generators/samplers.py index f57624eb..2c635ff3 100644 --- a/src/data_designer/engine/column_generators/generators/samplers.py +++ b/src/data_designer/engine/column_generators/generators/samplers.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging import random from functools import partial @@ -17,7 +19,6 @@ from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig from data_designer.engine.processing.utils import concat_datasets from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator -from data_designer.engine.resources.resource_provider import ResourceType from data_designer.engine.sampling_gen.data_sources.sources import SamplerType from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator @@ -32,7 +33,6 @@ def metadata() -> GeneratorMetadata: name="sampler_column_generator", description="Generate columns using sampling-based method.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=[ResourceType.BLOB_STORAGE], ) def generate(self, data: pd.DataFrame) -> pd.DataFrame: diff --git a/src/data_designer/engine/column_generators/generators/seed_dataset.py b/src/data_designer/engine/column_generators/generators/seed_dataset.py index 08c2fed6..b89326cf 100644 --- a/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -1,6 +1,9 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + import functools import logging @@ -16,7 +19,6 @@ from data_designer.engine.column_generators.utils.errors import SeedDatasetError from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig from data_designer.engine.processing.utils import concat_datasets -from data_designer.engine.resources.resource_provider import ResourceType MAX_ZERO_RECORD_RESPONSE_FACTOR = 2 @@ -30,7 +32,6 @@ def metadata() -> GeneratorMetadata: name="seed_dataset_column_generator", description="Sample columns from a seed dataset.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=[ResourceType.SEED_READER], ) @property diff --git a/src/data_designer/engine/column_generators/generators/validation.py b/src/data_designer/engine/column_generators/generators/validation.py index a2e2b3c9..d33c9d6e 100644 --- a/src/data_designer/engine/column_generators/generators/validation.py +++ b/src/data_designer/engine/column_generators/generators/validation.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging import pandas as pd @@ -50,7 +52,6 @@ def metadata() -> GeneratorMetadata: name="validate", description="Validate data.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=None, ) def generate(self, data: pd.DataFrame) -> pd.DataFrame: diff --git a/src/data_designer/engine/column_generators/utils/generator_classification.py b/src/data_designer/engine/column_generators/utils/generator_classification.py new file mode 100644 index 00000000..d87311a0 --- /dev/null +++ b/src/data_designer/engine/column_generators/utils/generator_classification.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from data_designer.config.column_types import DataDesignerColumnType +from data_designer.config.utils.type_helpers import resolve_string_enum +from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry +from data_designer.plugin_manager import PluginManager + +plugin_manager = PluginManager() + + +def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool: + """Return True if the column type is used in the workflow execution DAG.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + dag_column_types = { + DataDesignerColumnType.EXPRESSION, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EMBEDDING, + } + dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) + return column_type in dag_column_types + + +def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool: + """Return True if the column type is a model-generated column.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + model_generated_column_types = { + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, + } + for plugin in plugin_manager.get_column_generator_plugins(): + if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry): + model_generated_column_types.add(plugin.name) + return column_type in model_generated_column_types diff --git a/src/data_designer/engine/configurable_task.py b/src/data_designer/engine/configurable_task.py index 38b2748f..19f39543 100644 --- a/src/data_designer/engine/configurable_task.py +++ b/src/data_designer/engine/configurable_task.py @@ -9,7 +9,7 @@ from data_designer.config.base import ConfigBase from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage -from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType +from data_designer.engine.resources.resource_provider import ResourceProvider DataT = TypeVar("DataT", dict, pd.DataFrame) TaskConfigT = TypeVar("ConfigT", bound=ConfigBase) @@ -18,14 +18,12 @@ class ConfigurableTaskMetadata(ConfigBase): name: str description: str - required_resources: list[ResourceType] | None class ConfigurableTask(ABC, Generic[TaskConfigT]): - def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | None): + def __init__(self, config: TaskConfigT, resource_provider: ResourceProvider): self._config = self.get_config_type().model_validate(config) self._resource_provider = resource_provider - self._validate_resources() self._validate() self._initialize() @@ -61,8 +59,6 @@ def config(self) -> TaskConfigT: @property def resource_provider(self) -> ResourceProvider: - if self._resource_provider is None: - raise ValueError(f"No resource provider provided for the `{self.metadata().name}` task.") return self._resource_provider @staticmethod @@ -74,9 +70,3 @@ def _initialize(self) -> None: def _validate(self) -> None: """An internal method for custom validation logic, which will be called in the constructor.""" - - def _validate_resources(self) -> None: - for resource in self.metadata().required_resources or []: - if resource is not None: - if getattr(self.resource_provider, ResourceType(resource).value) is None: - raise ValueError(f"Resource {resource} is required for the `{self.metadata().name}`") diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index 0f54ddf1..c22c3142 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 + from __future__ import annotations import functools @@ -13,7 +14,7 @@ import pandas as pd -from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated +from data_designer.config.column_types import ColumnConfigT from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, @@ -22,9 +23,10 @@ ) from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, + ColumnGeneratorWithModel, GenerationStrategy, - WithModelGeneration, ) +from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -45,6 +47,7 @@ from data_designer.engine.resources.resource_provider import ResourceProvider if TYPE_CHECKING: + from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry from data_designer.engine.models.usage import ModelUsageStats logger = logging.getLogger(__name__) @@ -192,7 +195,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR - if isinstance(generator, WithModelGeneration): + if isinstance(generator, ColumnGeneratorWithModel): max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) @@ -206,7 +209,7 @@ def _run_model_health_check_if_needed(self) -> bool: list(set(config.model_alias for config in self.llm_generated_column_configs)) ) - def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None: + def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} " diff --git a/src/data_designer/engine/dataset_builders/utils/dag.py b/src/data_designer/engine/dataset_builders/utils/dag.py index 6c056d11..0d653fcc 100644 --- a/src/data_designer/engine/dataset_builders/utils/dag.py +++ b/src/data_designer/engine/dataset_builders/utils/dag.py @@ -5,7 +5,8 @@ import networkx as nx -from data_designer.config.column_types import ColumnConfigT, column_type_used_in_execution_dag +from data_designer.config.column_types import ColumnConfigT +from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/processing/processors/base.py b/src/data_designer/engine/processing/processors/base.py index 48c06576..d2b1f16b 100644 --- a/src/data_designer/engine/processing/processors/base.py +++ b/src/data_designer/engine/processing/processors/base.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from abc import ABC, abstractmethod from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT diff --git a/src/data_designer/engine/processing/processors/drop_columns.py b/src/data_designer/engine/processing/processors/drop_columns.py index d5a137fe..a39c0aa9 100644 --- a/src/data_designer/engine/processing/processors/drop_columns.py +++ b/src/data_designer/engine/processing/processors/drop_columns.py @@ -19,7 +19,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="drop_columns_processor", description="Drop columns from the input dataset.", - required_resources=None, ) def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame: diff --git a/src/data_designer/engine/processing/processors/schema_transform.py b/src/data_designer/engine/processing/processors/schema_transform.py index 2927177f..f8994939 100644 --- a/src/data_designer/engine/processing/processors/schema_transform.py +++ b/src/data_designer/engine/processing/processors/schema_transform.py @@ -22,7 +22,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="schema_transform_processor", description="Generate dataset with transformed schema using a Jinja2 template.", - required_resources=None, ) @property diff --git a/src/data_designer/engine/validation.py b/src/data_designer/engine/validation.py index 4e951e65..9705cc6f 100644 --- a/src/data_designer/engine/validation.py +++ b/src/data_designer/engine/validation.py @@ -14,7 +14,7 @@ from rich.padding import Padding from rich.panel import Panel -from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.processors import ProcessorConfigT, ProcessorType from data_designer.config.utils.constants import RICH_CONSOLE_THEME from data_designer.config.utils.misc import ( @@ -22,6 +22,7 @@ get_prompt_template_keywords, ) from data_designer.config.validator_params import ValidatorType +from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated class ViolationType(str, Enum): diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py index d5d9433d..19bca2ea 100644 --- a/src/data_designer/plugin_manager.py +++ b/src/data_designer/plugin_manager.py @@ -37,22 +37,17 @@ def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | No if self._plugin_registry.plugin_exists(plugin_name): return self._plugin_registry.get_plugin(plugin_name) - def get_plugin_column_types(self, enum_type: type[Enum], required_resources: list[str] | None = None) -> list[Enum]: + def get_plugin_column_types(self, enum_type: type[Enum]) -> list[Enum]: """Get a list of plugin column types. Args: enum_type: The enum type to use for plugin entries. - required_resources: If provided, only return plugins with the required resources. Returns: A list of plugin column types. """ type_list = [] for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR): - if required_resources: - task_required_resources = plugin.impl_cls.metadata().required_resources or [] - if not all(resource in task_required_resources for resource in required_resources): - continue type_list.append(enum_type(plugin.name)) return type_list diff --git a/src/data_designer/plugins/testing/stubs.py b/src/data_designer/plugins/testing/stubs.py index d6db1688..c624826e 100644 --- a/src/data_designer/plugins/testing/stubs.py +++ b/src/data_designer/plugins/testing/stubs.py @@ -6,7 +6,6 @@ from data_designer.config.base import ConfigBase from data_designer.config.column_configs import SingleColumnConfig from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata -from data_designer.engine.resources.resource_provider import ResourceType from data_designer.plugins.plugin import Plugin, PluginType MODULE_NAME = __name__ @@ -27,7 +26,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_generator", description="Test generator", - required_resources=None, ) @@ -61,7 +59,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_plugin_a", description="Test plugin A", - required_resources=None, ) @@ -71,7 +68,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_plugin_b", description="Test plugin B", - required_resources=None, ) @@ -96,7 +92,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_plugin_models", description="Test plugin requiring models", - required_resources=[ResourceType.MODEL_REGISTRY], ) @@ -106,7 +101,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_plugin_models_and_blobs", description="Test plugin requiring models and blobs", - required_resources=[ResourceType.MODEL_REGISTRY, ResourceType.BLOB_STORAGE], ) @@ -116,7 +110,6 @@ def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( name="test_plugin_blobs_and_seeds", description="Test plugin requiring blobs and seeds", - required_resources=[ResourceType.BLOB_STORAGE, ResourceType.SEED_READER], ) diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index cbc95756..24986d22 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -18,8 +18,6 @@ ) from data_designer.config.column_types import ( DataDesignerColumnType, - column_type_is_model_generated, - column_type_used_in_execution_dag, get_column_config_from_kwargs, get_column_display_order, ) @@ -56,30 +54,6 @@ def test_data_designer_column_type_get_display_order(): ] -def test_data_designer_column_type_is_llm_generated(): - assert column_type_is_model_generated(DataDesignerColumnType.LLM_TEXT) - assert column_type_is_model_generated(DataDesignerColumnType.LLM_CODE) - assert column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED) - assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE) - assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING) - assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER) - assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION) - assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION) - assert not column_type_is_model_generated(DataDesignerColumnType.SEED_DATASET) - - -def test_data_designer_column_type_is_in_dag(): - assert column_type_used_in_execution_dag(DataDesignerColumnType.EXPRESSION) - assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_CODE) - assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_JUDGE) - assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) - assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) - assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) - assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) - assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) - assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) - - def test_sampler_column_config(): sampler_column_config = SamplerColumnConfig( name="test_sampler", diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 549d316b..8f47de80 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -406,7 +406,6 @@ def test_getters(stub_data_designer_builder): assert len(stub_data_designer_builder.get_column_configs()) == 8 assert stub_data_designer_builder.get_column_config(name="code_id").name == "code_id" assert len(stub_data_designer_builder.get_constraints(target_column="age")) == 1 - assert len(stub_data_designer_builder.get_llm_gen_columns()) == 3 assert len(stub_data_designer_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)) == 4 assert len(stub_data_designer_builder.get_columns_excluding_type(DataDesignerColumnType.SAMPLER)) == 4 assert stub_data_designer_builder.get_seed_config().source.path == "datasets/test-repo/testing/data.csv" diff --git a/tests/engine/analysis/column_profilers/test_base.py b/tests/engine/analysis/column_profilers/test_base.py index c133e492..ea98624c 100644 --- a/tests/engine/analysis/column_profilers/test_base.py +++ b/tests/engine/analysis/column_profilers/test_base.py @@ -56,7 +56,6 @@ def test_column_profiler_metadata_creation(): name="test_profiler", description="Test profiler", applicable_column_types=[DataDesignerColumnType.SAMPLER, DataDesignerColumnType.LLM_TEXT], - required_resources=None, ) assert metadata.name == "test_profiler" diff --git a/tests/engine/column_generators/generators/test_column_generator_base.py b/tests/engine/column_generators/generators/test_column_generator_base.py index 511b1c0f..af3a7d70 100644 --- a/tests/engine/column_generators/generators/test_column_generator_base.py +++ b/tests/engine/column_generators/generators/test_column_generator_base.py @@ -21,7 +21,6 @@ def _create_test_metadata(name="test", description="test", strategy=GenerationSt name=name, description=description, generation_strategy=strategy, - required_resources=None, ) diff --git a/tests/engine/column_generators/generators/test_llm_completion_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py index 6cba3bdc..2e637ff2 100644 --- a/tests/engine/column_generators/generators/test_llm_completion_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -31,12 +31,15 @@ def _create_generator_with_mocks(config_class=LLMTextColumnConfig, **config_kwar mock_inference_params = Mock() mock_prompt_renderer = Mock() mock_response_recipe = Mock() + mock_provider = Mock() mock_resource_provider.model_registry = mock_model_registry mock_model_registry.get_model.return_value = mock_model mock_model_registry.get_model_config.return_value = mock_model_config + mock_model_registry.get_model_provider.return_value = mock_provider mock_model_config.inference_parameters = mock_inference_params mock_model_config.alias = "test_model" + mock_provider.name = "test_provider" mock_inference_params.generate_kwargs = {"temperature": 0.7, "max_tokens": 100} @@ -95,25 +98,32 @@ def test_generate_method(): @patch("data_designer.engine.column_generators.generators.base.logger", autospec=True) -def test_log_pre_generation(mock_logger): - generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks() - mock_model_config.model_dump_json.return_value = '{"test": "config"}' +def test_log_pre_generation(mock_logger: Mock) -> None: + generator, mock_resource_provider, _, mock_model_config, mock_inference_params, _, _ = ( + _create_generator_with_mocks() + ) + mock_model_config.model = "meta/llama-3.1-8b-instruct" + mock_model_config.generation_type.value = "chat-completion" + mock_inference_params.format_for_display.return_value = "temperature=0.70, max_tokens=100" generator.log_pre_generation() - assert mock_logger.info.call_count == 3 - mock_logger.info.assert_any_call("Preparing llm-text column generation") - mock_logger.info.assert_any_call(" |-- column name: 'test_column'") - mock_logger.info.assert_any_call(' |-- model config:\n{"test": "config"}') + assert mock_logger.info.call_count == 6 + mock_logger.info.assert_any_call("llm-text model configuration for generating column 'test_column'") + mock_logger.info.assert_any_call(" |-- model: 'meta/llama-3.1-8b-instruct'") + mock_logger.info.assert_any_call(" |-- model alias: 'test_model'") + mock_logger.info.assert_any_call(" |-- model provider: 'test_provider'") + mock_logger.info.assert_any_call(" |-- generation type: 'chat-completion'") + mock_logger.info.assert_any_call(" |-- inference parameters: temperature=0.70, max_tokens=100") - # Test with provider - mock_model_config.provider = None + # Test with different provider + mock_logger.reset_mock() mock_provider = Mock() - mock_provider.name = "test_provider" + mock_provider.name = "test_provider_2" mock_resource_provider.model_registry.get_model_provider.return_value = mock_provider generator.log_pre_generation() - mock_logger.info.assert_any_call(" |-- default model provider: 'test_provider'") + mock_logger.info.assert_any_call(" |-- model provider: 'test_provider_2'") @pytest.mark.parametrize( diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index 792632c2..84f9354d 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -19,7 +19,7 @@ ) from data_designer.engine.column_generators.utils.errors import SeedDatasetError from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig -from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType +from data_designer.engine.resources.resource_provider import ResourceProvider @pytest.fixture @@ -111,7 +111,6 @@ def seed_dataset_jsonl(sample_dataframe): def test_seed_dataset_column_generator_metadata(): metadata = SeedDatasetColumnGenerator.metadata() assert metadata.generation_strategy == GenerationStrategy.FULL_COLUMN - assert ResourceType.SEED_READER in metadata.required_resources def test_seed_dataset_column_generator_config_structure(): diff --git a/tests/engine/column_generators/utils/test_generator_classification.py b/tests/engine/column_generators/utils/test_generator_classification.py new file mode 100644 index 00000000..9617b418 --- /dev/null +++ b/tests/engine/column_generators/utils/test_generator_classification.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.config.column_types import DataDesignerColumnType +from data_designer.engine.column_generators.utils.generator_classification import ( + column_type_is_model_generated, + column_type_used_in_execution_dag, +) + + +def test_column_type_is_model_generated() -> None: + assert column_type_is_model_generated(DataDesignerColumnType.LLM_TEXT) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_CODE) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE) + assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING) + assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER) + assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION) + assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION) + assert not column_type_is_model_generated(DataDesignerColumnType.SEED_DATASET) + + +def test_column_type_used_in_execution_dag() -> None: + assert column_type_used_in_execution_dag(DataDesignerColumnType.EXPRESSION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_CODE) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_JUDGE) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) + assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) + assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) + assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) diff --git a/tests/engine/processing/processors/test_drop_columns.py b/tests/engine/processing/processors/test_drop_columns.py index c571ec96..6e754cb6 100644 --- a/tests/engine/processing/processors/test_drop_columns.py +++ b/tests/engine/processing/processors/test_drop_columns.py @@ -42,7 +42,6 @@ def test_metadata(): assert metadata.name == "drop_columns_processor" assert metadata.description == "Drop columns from the input dataset." - assert metadata.required_resources is None @pytest.mark.parametrize( diff --git a/tests/engine/processing/processors/test_schema_transform.py b/tests/engine/processing/processors/test_schema_transform.py index 4649e875..ac65d820 100644 --- a/tests/engine/processing/processors/test_schema_transform.py +++ b/tests/engine/processing/processors/test_schema_transform.py @@ -52,7 +52,6 @@ def test_metadata() -> None: assert metadata.name == "schema_transform_processor" assert metadata.description == "Generate dataset with transformed schema using a Jinja2 template." - assert metadata.required_resources is None def test_process_returns_original_dataframe( diff --git a/tests/engine/test_configurable_task.py b/tests/engine/test_configurable_task.py index 54ec9547..9896a598 100644 --- a/tests/engine/test_configurable_task.py +++ b/tests/engine/test_configurable_task.py @@ -15,25 +15,21 @@ ) from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.models.registry import ModelRegistry -from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType +from data_designer.engine.resources.resource_provider import ResourceProvider def test_configurable_task_metadata_creation(): - metadata = ConfigurableTaskMetadata( - name="test_task", description="Test task description", required_resources=[ResourceType.MODEL_REGISTRY] - ) + metadata = ConfigurableTaskMetadata(name="test_task", description="Test task description") assert metadata.name == "test_task" assert metadata.description == "Test task description" - assert metadata.required_resources == [ResourceType.MODEL_REGISTRY] def test_configurable_task_metadata_with_no_resources(): - metadata = ConfigurableTaskMetadata(name="test_task", description="Test task description", required_resources=None) + metadata = ConfigurableTaskMetadata(name="test_task", description="Test task description") assert metadata.name == "test_task" assert metadata.description == "Test task description" - assert metadata.required_resources is None def test_configurable_task_generic_type_variables(): @@ -53,7 +49,7 @@ def get_config_type(cls) -> type[TestConfig]: @classmethod def metadata(cls) -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata(name="test_task", description="Test task", required_resources=None) + return ConfigurableTaskMetadata(name="test_task", description="Test task") def _validate(self) -> None: pass @@ -87,7 +83,7 @@ def get_config_type(cls) -> type[TestConfig]: @classmethod def metadata(cls) -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata(name="test_task", description="Test task", required_resources=None) + return ConfigurableTaskMetadata(name="test_task", description="Test task") def _validate(self) -> None: if self._config.value == "invalid": @@ -121,9 +117,7 @@ def get_config_type(cls) -> type[TestConfig]: @classmethod def metadata(cls) -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="test_task", description="Test task", required_resources=[ResourceType.MODEL_REGISTRY] - ) + return ConfigurableTaskMetadata(name="test_task", description="Test task") def _validate(self) -> None: pass @@ -152,7 +146,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod def metadata(cls) -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata(name="test_task", description="Test task", required_resources=None) + return ConfigurableTaskMetadata(name="test_task", description="Test task") def _validate(self) -> None: pass diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 0040eab7..3ea11caa 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -6,7 +6,6 @@ from enum import Enum from unittest.mock import patch -from data_designer.engine.resources.resource_provider import ResourceType from data_designer.plugin_manager import PluginManager from data_designer.plugins.plugin import Plugin from data_designer.plugins.registry import PluginRegistry @@ -15,7 +14,6 @@ plugin_blobs_and_seeds, plugin_models, plugin_models_and_blobs, - plugin_none, ) @@ -97,30 +95,6 @@ def test_get_plugin_column_types_with_plugins() -> None: assert all(isinstance(item, TestEnum) for item in result) -def test_get_plugin_column_types_with_resource_filtering() -> None: - """Test filtering plugins by required resources.""" - all_plugins = [plugin_models, plugin_models_and_blobs, plugin_blobs_and_seeds] - TestEnum = make_test_enum(all_plugins) - - with mock_plugin_system(all_plugins): - manager = PluginManager() - result = manager.get_plugin_column_types(TestEnum, required_resources=[ResourceType.MODEL_REGISTRY]) - - assert len(result) == 2 - assert set(result) == {plugin_models.name, plugin_models_and_blobs.name} - - -def test_get_plugin_column_types_filters_none_resources() -> None: - """Test filtering when plugin has None for required_resources.""" - TestEnum = make_test_enum([plugin_none]) - - with mock_plugin_system([plugin_none]): - manager = PluginManager() - result = manager.get_plugin_column_types(TestEnum, required_resources=[ResourceType.MODEL_REGISTRY]) - - assert result == [] - - def test_get_plugin_column_types_empty() -> None: """Test getting plugin column types when no plugins are registered.""" TestEnum = make_test_enum([])