diff --git a/pyproject.toml b/pyproject.toml index 362b5298..c4fb25b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "pytest>=8.3.3", "pytest-asyncio>=0.24.0", "pytest-cov>=7.0.0", + "pytest-env>=1.2.0", "pytest-httpx>=0.35.0", ] docs = [ @@ -89,6 +90,9 @@ version-file = "src/data_designer/_version.py" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "session" +env = [ + "DISABLE_DATA_DESIGNER_PLUGINS=true", +] [tool.uv] package = true diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 991e41b9..c39dedfb 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -5,13 +5,14 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pandas import Series -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator from typing_extensions import Self, TypeAlias -from ..columns import DataDesignerColumnType +from ...plugin_manager import PluginManager +from ..column_types import DataDesignerColumnType from ..sampler_params import SamplerType from ..utils.constants import EPSILON from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting @@ -238,17 +239,41 @@ def from_series(cls, series: Series) -> Self: ) -ColumnStatisticsT: TypeAlias = Annotated[ - Union[ - GeneralColumnStatistics, - LLMTextColumnStatistics, - LLMCodeColumnStatistics, - LLMStructuredColumnStatistics, - LLMJudgedColumnStatistics, - SamplerColumnStatistics, - SeedDatasetColumnStatistics, - ValidationColumnStatistics, - ExpressionColumnStatistics, - ], - Field(discriminator="column_type"), +ColumnStatisticsT: TypeAlias = Union[ + GeneralColumnStatistics, + LLMTextColumnStatistics, + LLMCodeColumnStatistics, + LLMStructuredColumnStatistics, + LLMJudgedColumnStatistics, + SamplerColumnStatistics, + SeedDatasetColumnStatistics, + ValidationColumnStatistics, + ExpressionColumnStatistics, ] + + +DEFAULT_COLUMN_STATISTICS_MAP = { + DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics, + DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics, + DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics, + DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics, + DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics, + DataDesignerColumnType.SAMPLER: SamplerColumnStatistics, + DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics, + DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, +} + +for plugin in PluginManager().get_column_generator_plugins(): + # Dynamically create a statistics class for this plugin using Pydantic's create_model + plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics" + + # Create the class with proper Pydantic field + plugin_stats_cls = create_model( + plugin_stats_cls_name, + __base__=GeneralColumnStatistics, + column_type=(Literal[plugin.name], plugin.name), + ) + + # Add the plugin statistics class to the union + ColumnStatisticsT |= plugin_stats_cls + DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index aa2b638f..72f4d528 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -3,11 +3,11 @@ from functools import cached_property from pathlib import Path -from typing import Optional, Union +from typing import Annotated, Optional, Union from pydantic import BaseModel, Field, field_validator -from ..columns import DataDesignerColumnType, get_column_display_order +from ..column_types import DataDesignerColumnType, get_column_display_order from ..utils.constants import EPSILON from ..utils.numerical_helpers import prepare_number_for_reporting from .column_profilers import ColumnProfilerResultsT @@ -18,7 +18,7 @@ class DatasetProfilerResults(BaseModel): num_records: int target_num_records: int - column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1) + column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1) side_effect_column_names: Optional[list[str]] = None column_profiles: Optional[list[ColumnProfilerResultsT]] = None diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index e4df4190..e1d2c2cf 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -15,7 +15,7 @@ from rich.text import Text from ...analysis.column_statistics import CategoricalHistogramData -from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order +from ...column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order from ...utils.visualization import ( ColorPalette, convert_to_row_element, @@ -27,7 +27,6 @@ if TYPE_CHECKING: from ...analysis.dataset_profiler import DatasetProfilerResults - HEADER_STYLE = "dim" RULE_STYLE = f"bold {ColorPalette.NVIDIA_GREEN.value}" ACCENT_STYLE = f"bold {ColorPalette.BLUE.value}" diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py new file mode 100644 index 00000000..2840d13d --- /dev/null +++ b/src/data_designer/config/column_configs.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from typing import Literal, Optional, Type, Union + +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self + +from .base import ConfigBase +from .errors import InvalidConfigError +from .models import ImageContext +from .sampler_params import SamplerParamsT, SamplerType +from .utils.code_lang import CodeLang +from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX +from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords +from .validator_params import ValidatorParamsT, ValidatorType + + +class SingleColumnConfig(ConfigBase, ABC): + name: str + drop: bool = False + column_type: str + + @property + def required_columns(self) -> list[str]: + return [] + + @property + def side_effect_columns(self) -> list[str]: + return [] + + +class SamplerColumnConfig(SingleColumnConfig): + sampler_type: SamplerType + params: SamplerParamsT + conditional_params: dict[str, SamplerParamsT] = {} + convert_to: Optional[str] = None + column_type: Literal["sampler"] = "sampler" + + +class LLMTextColumnConfig(SingleColumnConfig): + prompt: str + model_alias: str + system_prompt: Optional[str] = None + multi_modal_context: Optional[list[ImageContext]] = None + column_type: Literal["llm-text"] = "llm-text" + + @property + def required_columns(self) -> list[str]: + required_cols = list(get_prompt_template_keywords(self.prompt)) + if self.system_prompt: + required_cols.extend(list(get_prompt_template_keywords(self.system_prompt))) + return list(set(required_cols)) + + @property + def side_effect_columns(self) -> list[str]: + return [f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"] + + @model_validator(mode="after") + def assert_prompt_valid_jinja(self) -> Self: + assert_valid_jinja2_template(self.prompt) + if self.system_prompt: + assert_valid_jinja2_template(self.system_prompt) + return self + + +class LLMCodeColumnConfig(LLMTextColumnConfig): + code_lang: CodeLang + column_type: Literal["llm-code"] = "llm-code" + + +class LLMStructuredColumnConfig(LLMTextColumnConfig): + output_format: Union[dict, Type[BaseModel]] + column_type: Literal["llm-structured"] = "llm-structured" + + @model_validator(mode="after") + def validate_output_format(self) -> Self: + if not isinstance(self.output_format, dict) and issubclass(self.output_format, BaseModel): + self.output_format = self.output_format.model_json_schema() + return self + + +class Score(ConfigBase): + name: str = Field(..., description="A clear name for this score.") + description: str = Field(..., description="An informative and detailed assessment guide for using this score.") + options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") + + +class LLMJudgeColumnConfig(LLMTextColumnConfig): + scores: list[Score] = Field(..., min_length=1) + column_type: Literal["llm-judge"] = "llm-judge" + + +class ExpressionColumnConfig(SingleColumnConfig): + name: str + expr: str + dtype: Literal["int", "float", "str", "bool"] = "str" + column_type: Literal["expression"] = "expression" + + @property + def required_columns(self) -> list[str]: + return list(get_prompt_template_keywords(self.expr)) + + @model_validator(mode="after") + def assert_expression_valid_jinja(self) -> Self: + if not self.expr.strip(): + raise InvalidConfigError( + f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " + f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " + "or remove this column if not needed." + ) + assert_valid_jinja2_template(self.expr) + return self + + +class ValidationColumnConfig(SingleColumnConfig): + target_columns: list[str] + validator_type: ValidatorType + validator_params: ValidatorParamsT + batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") + column_type: Literal["validation"] = "validation" + + @property + def required_columns(self) -> list[str]: + return self.target_columns + + +class SeedDatasetColumnConfig(SingleColumnConfig): + column_type: Literal["seed-dataset"] = "seed-dataset" diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py new file mode 100644 index 00000000..50ba498d --- /dev/null +++ b/src/data_designer/config/column_types.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from typing_extensions import TypeAlias + +from ..plugin_manager import PluginManager +from .column_configs import ( + ExpressionColumnConfig, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + SeedDatasetColumnConfig, + ValidationColumnConfig, +) +from .errors import InvalidColumnTypeError, InvalidConfigError +from .sampler_params import SamplerType +from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum + +plugin_manager = PluginManager() + +ColumnConfigT: TypeAlias = Union[ + ExpressionColumnConfig, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + SeedDatasetColumnConfig, + ValidationColumnConfig, +] +ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) + +DataDesignerColumnType = create_str_enum_from_discriminated_type_union( + enum_name="DataDesignerColumnType", + type_union=ColumnConfigT, + discriminator_field_name="column_type", +) + +COLUMN_TYPE_EMOJI_MAP = { + "general": "⚛️", # possible analysis column type + DataDesignerColumnType.EXPRESSION: "🧩", + DataDesignerColumnType.LLM_CODE: "💻", + DataDesignerColumnType.LLM_JUDGE: "⚖️", + DataDesignerColumnType.LLM_STRUCTURED: "🗂️", + DataDesignerColumnType.LLM_TEXT: "📝", + DataDesignerColumnType.SEED_DATASET: "🌱", + DataDesignerColumnType.SAMPLER: "🎲", + DataDesignerColumnType.VALIDATION: "🔍", +} +COLUMN_TYPE_EMOJI_MAP.update( + {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} +) + + +def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: + """Return True if the column type is used in the workflow execution DAG.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + dag_column_types = { + DataDesignerColumnType.EXPRESSION, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.VALIDATION, + } + dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) + return column_type in dag_column_types + + +def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: + """Return True if the column type is an LLM-generated column.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + llm_generated_column_types = { + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_JUDGE, + } + llm_generated_column_types.update( + plugin_manager.get_plugin_column_types( + DataDesignerColumnType, + required_resources=["model_registry"], + ) + ) + return column_type in llm_generated_column_types + + +def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: + """Create a Data Designer column config object from kwargs. + + Args: + name: Name of the column. + column_type: Type of the column. + **kwargs: Keyword arguments to pass to the column constructor. + + Returns: + Data Designer column object of the appropriate type. + """ + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + if column_type == DataDesignerColumnType.LLM_TEXT: + return LLMTextColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.LLM_CODE: + return LLMCodeColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.LLM_STRUCTURED: + return LLMStructuredColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.LLM_JUDGE: + return LLMJudgeColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.VALIDATION: + return ValidationColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.EXPRESSION: + return ExpressionColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.SAMPLER: + return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) + if column_type == DataDesignerColumnType.SEED_DATASET: + return SeedDatasetColumnConfig(name=name, **kwargs) + if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): + return plugin.config_cls(name=name, **kwargs) + raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover + + +def get_column_display_order() -> list[DataDesignerColumnType]: + """Return the preferred display order of the column types.""" + display_order = [ + DataDesignerColumnType.SEED_DATASET, + DataDesignerColumnType.SAMPLER, + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EXPRESSION, + ] + display_order.extend(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) + return display_order + + +def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: + if "sampler_type" not in kwargs: + raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.") + sampler_type = resolve_string_enum(kwargs["sampler_type"], SamplerType) + + # Handle params - it could be a dict or already a concrete object + params_value = kwargs.get("params", {}) + expected_params_class = SAMPLER_PARAMS[sampler_type.value] + + if isinstance(params_value, expected_params_class): + # params is already a concrete object of the right type + params = params_value + elif isinstance(params_value, dict): + # params is a dictionary, create new instance + params = expected_params_class(**params_value) + else: + # params is neither dict nor expected type + raise InvalidConfigError( + f"🛑 Invalid params for sampler column '{name}'. " + f"Expected a dictionary or an instance of {expected_params_class.__name__}. " + f"You provided {params_value=}." + ) + + return { + "sampler_type": sampler_type, + "params": params, + **{k: v for k, v in kwargs.items() if k not in ["sampler_type", "params"]}, + } diff --git a/src/data_designer/config/columns.py b/src/data_designer/config/columns.py deleted file mode 100644 index 8886cb09..00000000 --- a/src/data_designer/config/columns.py +++ /dev/null @@ -1,260 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC -from typing import Literal, Optional, Type, Union - -from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self, TypeAlias - -from .base import ConfigBase -from .errors import InvalidColumnTypeError, InvalidConfigError -from .models import ImageContext -from .sampler_params import SamplerParamsT, SamplerType -from .utils.code_lang import CodeLang -from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX -from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords -from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum -from .validator_params import ValidatorParamsT, ValidatorType - - -class SingleColumnConfig(ConfigBase, ABC): - name: str - drop: bool = False - column_type: str - - @property - def required_columns(self) -> list[str]: - return [] - - @property - def side_effect_columns(self) -> list[str]: - return [] - - -class SamplerColumnConfig(SingleColumnConfig): - sampler_type: SamplerType - params: SamplerParamsT - conditional_params: dict[str, SamplerParamsT] = {} - convert_to: Optional[str] = None - column_type: Literal["sampler"] = "sampler" - - -class LLMTextColumnConfig(SingleColumnConfig): - prompt: str - model_alias: str - system_prompt: Optional[str] = None - multi_modal_context: Optional[list[ImageContext]] = None - column_type: Literal["llm-text"] = "llm-text" - - @property - def required_columns(self) -> list[str]: - required_cols = list(get_prompt_template_keywords(self.prompt)) - if self.system_prompt: - required_cols.extend(list(get_prompt_template_keywords(self.system_prompt))) - return list(set(required_cols)) - - @property - def side_effect_columns(self) -> list[str]: - return [f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"] - - @model_validator(mode="after") - def assert_prompt_valid_jinja(self) -> Self: - assert_valid_jinja2_template(self.prompt) - if self.system_prompt: - assert_valid_jinja2_template(self.system_prompt) - return self - - -class LLMCodeColumnConfig(LLMTextColumnConfig): - code_lang: CodeLang - column_type: Literal["llm-code"] = "llm-code" - - -class LLMStructuredColumnConfig(LLMTextColumnConfig): - output_format: Union[dict, Type[BaseModel]] - column_type: Literal["llm-structured"] = "llm-structured" - - @model_validator(mode="after") - def validate_output_format(self) -> Self: - if not isinstance(self.output_format, dict) and issubclass(self.output_format, BaseModel): - self.output_format = self.output_format.model_json_schema() - return self - - -class Score(ConfigBase): - name: str = Field(..., description="A clear name for this score.") - description: str = Field(..., description="An informative and detailed assessment guide for using this score.") - options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") - - -class LLMJudgeColumnConfig(LLMTextColumnConfig): - scores: list[Score] = Field(..., min_length=1) - column_type: Literal["llm-judge"] = "llm-judge" - - -class ExpressionColumnConfig(SingleColumnConfig): - name: str - expr: str - dtype: Literal["int", "float", "str", "bool"] = "str" - column_type: Literal["expression"] = "expression" - - @property - def required_columns(self) -> list[str]: - return list(get_prompt_template_keywords(self.expr)) - - @model_validator(mode="after") - def assert_expression_valid_jinja(self) -> Self: - if not self.expr.strip(): - raise InvalidConfigError( - f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " - f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " - "or remove this column if not needed." - ) - assert_valid_jinja2_template(self.expr) - return self - - -class ValidationColumnConfig(SingleColumnConfig): - target_columns: list[str] - validator_type: ValidatorType - validator_params: ValidatorParamsT - batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") - column_type: Literal["validation"] = "validation" - - @property - def required_columns(self) -> list[str]: - return self.target_columns - - -class SeedDatasetColumnConfig(SingleColumnConfig): - column_type: Literal["seed-dataset"] = "seed-dataset" - - -ColumnConfigT: TypeAlias = Union[ - ExpressionColumnConfig, - LLMCodeColumnConfig, - LLMJudgeColumnConfig, - LLMStructuredColumnConfig, - LLMTextColumnConfig, - SamplerColumnConfig, - SeedDatasetColumnConfig, - ValidationColumnConfig, -] - - -DataDesignerColumnType = create_str_enum_from_discriminated_type_union( - enum_name="DataDesignerColumnType", - type_union=ColumnConfigT, - discriminator_field_name="column_type", -) - - -COLUMN_TYPE_EMOJI_MAP = { - "general": "⚛️", # possible analysis column type - DataDesignerColumnType.EXPRESSION: "🧩", - DataDesignerColumnType.LLM_CODE: "💻", - DataDesignerColumnType.LLM_JUDGE: "⚖️", - DataDesignerColumnType.LLM_STRUCTURED: "🗂️", - DataDesignerColumnType.LLM_TEXT: "📝", - DataDesignerColumnType.SEED_DATASET: "🌱", - DataDesignerColumnType.SAMPLER: "🎲", - DataDesignerColumnType.VALIDATION: "🔍", -} - - -def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: - """Return True if the column type is used in the workflow execution DAG.""" - column_type = resolve_string_enum(column_type, DataDesignerColumnType) - return column_type in { - DataDesignerColumnType.EXPRESSION, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.VALIDATION, - } - - -def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: - """Return True if the column type is an LLM-generated column.""" - column_type = resolve_string_enum(column_type, DataDesignerColumnType) - return column_type in { - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - } - - -def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: - """Create a Data Designer column config object from kwargs. - - Args: - name: Name of the column. - column_type: Type of the column. - **kwargs: Keyword arguments to pass to the column constructor. - - Returns: - Data Designer column object of the appropriate type. - """ - column_type = resolve_string_enum(column_type, DataDesignerColumnType) - if column_type == DataDesignerColumnType.LLM_TEXT: - return LLMTextColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_CODE: - return LLMCodeColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_STRUCTURED: - return LLMStructuredColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_JUDGE: - return LLMJudgeColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.VALIDATION: - return ValidationColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.EXPRESSION: - return ExpressionColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.SAMPLER: - return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) - elif column_type == DataDesignerColumnType.SEED_DATASET: - return SeedDatasetColumnConfig(name=name, **kwargs) - raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover - - -def get_column_display_order() -> list[DataDesignerColumnType]: - """Return the preferred display order of the column types.""" - return [ - DataDesignerColumnType.SEED_DATASET, - DataDesignerColumnType.SAMPLER, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.VALIDATION, - DataDesignerColumnType.EXPRESSION, - ] - - -def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: - if "sampler_type" not in kwargs: - raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.") - sampler_type = resolve_string_enum(kwargs["sampler_type"], SamplerType) - - # Handle params - it could be a dict or already a concrete object - params_value = kwargs.get("params", {}) - expected_params_class = SAMPLER_PARAMS[sampler_type.value] - - if isinstance(params_value, expected_params_class): - # params is already a concrete object of the right type - params = params_value - elif isinstance(params_value, dict): - # params is a dictionary, create new instance - params = expected_params_class(**params_value) - else: - # params is neither dict nor expected type - raise InvalidConfigError( - f"🛑 Invalid params for sampler column '{name}'. Expected a dictionary or an instance of {expected_params_class.__name__}." - ) - - return { - "sampler_type": sampler_type, - "params": params, - **{k: v for k, v in kwargs.items() if k not in ["sampler_type", "params"]}, - } diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index eca394cf..9f1eee86 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -15,12 +15,13 @@ from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase -from .columns import ( +from .column_configs import SeedDatasetColumnConfig +from .column_types import ( ColumnConfigT, DataDesignerColumnType, - SeedDatasetColumnConfig, column_type_is_llm_generated, get_column_config_from_kwargs, + get_column_display_order, ) from .data_designer_config import DataDesignerConfig from .dataset_builders import BuildStage @@ -46,11 +47,7 @@ from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE from .utils.info import ConfigBuilderInfo from .utils.io_helpers import serialize_data, smart_load_yaml -from .utils.misc import ( - can_run_data_designer_locally, - json_indent_list_of_strings, - kebab_to_snake, -) +from .utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake from .utils.type_helpers import resolve_string_enum from .utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config @@ -624,16 +621,7 @@ def __repr__(self) -> str: "seed_dataset": (None if self._seed_config is None else f"'{self._seed_config.dataset}'"), } - for column_type in [ - DataDesignerColumnType.SEED_DATASET, - DataDesignerColumnType.SAMPLER, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.VALIDATION, - DataDesignerColumnType.EXPRESSION, - ]: + for column_type in get_column_display_order(): columns = self.get_columns_of_type(column_type) if len(columns) > 0: column_label = f"{kebab_to_snake(column_type.value)}_columns" diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index 24d791b8..3575f1a7 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -9,7 +9,7 @@ from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase -from .columns import ColumnConfigT +from .column_types import ColumnConfigT from .models import ModelConfig from .processors import ProcessorConfig from .sampler_constraints import ColumnConstraintT diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index 428c575c..f1a0bef1 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -15,7 +15,7 @@ from rich.padding import Padding from rich.panel import Panel -from ..columns import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated +from ..column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated from ..processors import ProcessorConfig, ProcessorType from ..validator_params import ValidatorType from .constants import RICH_CONSOLE_THEME diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 4b245369..3d7420ad 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -21,7 +21,7 @@ from rich.text import Text from ..base import ConfigBase -from ..columns import DataDesignerColumnType +from ..column_types import DataDesignerColumnType from ..models import ModelConfig, ModelProvider, get_nvidia_api_key, get_openai_api_key from ..sampler_params import SamplerType from .code_lang import code_lang_to_syntax_lexer diff --git a/src/data_designer/engine/analysis/column_profilers/base.py b/src/data_designer/engine/analysis/column_profilers/base.py index 45d0d6ea..4bfcec9f 100644 --- a/src/data_designer/engine/analysis/column_profilers/base.py +++ b/src/data_designer/engine/analysis/column_profilers/base.py @@ -12,7 +12,8 @@ from typing_extensions import Self from data_designer.config.base import ConfigBase -from data_designer.config.columns import DataDesignerColumnType, SingleColumnConfig +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, TaskConfigT logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py index 37de5dbd..0c8e3569 100644 --- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py +++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py @@ -20,7 +20,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType from data_designer.engine.analysis.column_profilers.base import ( ColumnConfigWithDataFrame, ColumnProfiler, diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py index dd4fb1e9..4b3e4f0e 100644 --- a/src/data_designer/engine/analysis/column_statistics.py +++ b/src/data_designer/engine/analysis/column_statistics.py @@ -11,18 +11,11 @@ from typing_extensions import Self from data_designer.config.analysis.column_statistics import ( + DEFAULT_COLUMN_STATISTICS_MAP, ColumnStatisticsT, - ExpressionColumnStatistics, GeneralColumnStatistics, - LLMCodeColumnStatistics, - LLMJudgedColumnStatistics, - LLMStructuredColumnStatistics, - LLMTextColumnStatistics, - SamplerColumnStatistics, - SeedDatasetColumnStatistics, - ValidationColumnStatistics, ) -from data_designer.config.columns import ColumnConfigT, DataDesignerColumnType +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.sampler_params import SamplerType, is_numerical_sampler_type from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.utils.column_statistics_calculations import ( @@ -134,18 +127,6 @@ def calculate_validation_column_info(self) -> dict[str, Any]: class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ... -DEFAULT_COLUMN_STATISTICS_MAP = { - DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics, - DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics, - DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics, - DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics, - DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics, - DataDesignerColumnType.SAMPLER: SamplerColumnStatistics, - DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics, - DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, -} - - ColumnStatisticsCalculatorT: TypeAlias = Union[ ExpressionColumnStatisticsCalculator, ValidationColumnStatisticsCalculator, diff --git a/src/data_designer/engine/analysis/dataset_profiler.py b/src/data_designer/engine/analysis/dataset_profiler.py index 3dd39ed0..3f173bba 100644 --- a/src/data_designer/engine/analysis/dataset_profiler.py +++ b/src/data_designer/engine/analysis/dataset_profiler.py @@ -11,10 +11,10 @@ from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import ( COLUMN_TYPE_EMOJI_MAP, ColumnConfigT, - SingleColumnConfig, ) from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 67d72800..2b47c8cb 100644 --- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -20,7 +20,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMTextColumnConfig, SingleColumnConfig, ValidationColumnConfig, diff --git a/src/data_designer/engine/analysis/utils/judge_score_processing.py b/src/data_designer/engine/analysis/utils/judge_score_processing.py index 95b01ccc..d9bbc764 100644 --- a/src/data_designer/engine/analysis/utils/judge_score_processing.py +++ b/src/data_designer/engine/analysis/utils/judge_score_processing.py @@ -14,7 +14,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig +from data_designer.config.column_configs import LLMJudgeColumnConfig logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/column_generators/generators/expression.py b/src/data_designer/engine/column_generators/generators/expression.py index 7da80e66..00ce4771 100644 --- a/src/data_designer/engine/column_generators/generators/expression.py +++ b/src/data_designer/engine/column_generators/generators/expression.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, GenerationStrategy, diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_generators.py index 14324f36..ee0ab58a 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_generators.py @@ -4,13 +4,13 @@ import functools import logging -from data_designer.config.columns import ( - COLUMN_TYPE_EMOJI_MAP, +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig, ) +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( diff --git a/src/data_designer/engine/column_generators/generators/validation.py b/src/data_designer/engine/column_generators/generators/validation.py index 424165fc..f46296b9 100644 --- a/src/data_designer/engine/column_generators/generators/validation.py +++ b/src/data_designer/engine/column_generators/generators/validation.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ValidationColumnConfig +from data_designer.config.column_configs import ValidationColumnConfig from data_designer.config.errors import InvalidConfigError from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang from data_designer.config.validator_params import ( diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 95c6ee7a..61b43753 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -11,6 +10,7 @@ LLMTextColumnConfig, ValidationColumnConfig, ) +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.generators.llm_generators import ( @@ -27,34 +27,30 @@ SeedDatasetMultiColumnConfig, ) from data_designer.engine.registry.base import TaskRegistry +from data_designer.plugins.plugin import PluginType +from data_designer.plugins.registry import PluginRegistry class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]): ... -def create_default_column_generator_registry() -> ColumnGeneratorRegistry: +def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry: registry = ColumnGeneratorRegistry() - registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig, False) - registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig, False) - registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig, False) - registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig, False) - registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig, False) - registry.register( - DataDesignerColumnType.SEED_DATASET, - SeedDatasetColumnGenerator, - SeedDatasetMultiColumnConfig, - False, - ) - registry.register( - DataDesignerColumnType.VALIDATION, - ValidationColumnGenerator, - ValidationColumnConfig, - False, - ) - registry.register( - DataDesignerColumnType.LLM_STRUCTURED, - LLMStructuredCellGenerator, - LLMStructuredColumnConfig, - False, - ) + registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig) + registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) + registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig) + registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig) + registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig) + registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) + registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) + registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) + + if with_plugins: + for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): + registry.register( + DataDesignerColumnType(plugin.name), + plugin.task_cls, + plugin.config_cls, + ) + return registry diff --git a/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/src/data_designer/engine/column_generators/utils/judge_score_factory.py index b4d458c6..a02a86ee 100644 --- a/src/data_designer/engine/column_generators/utils/judge_score_factory.py +++ b/src/data_designer/engine/column_generators/utils/judge_score_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, create_model -from data_designer.config.columns import Score +from data_designer.config.column_configs import Score SCORING_FORMAT = "* {score}: {description}" SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}" diff --git a/src/data_designer/engine/column_generators/utils/prompt_renderer.py b/src/data_designer/engine/column_generators/utils/prompt_renderer.py index d895fe9e..0c1cd708 100644 --- a/src/data_designer/engine/column_generators/utils/prompt_renderer.py +++ b/src/data_designer/engine/column_generators/utils/prompt_renderer.py @@ -5,7 +5,8 @@ import json import logging -from data_designer.config.columns import DataDesignerColumnType, SingleColumnConfig +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.models import ModelConfig from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.misc import get_prompt_template_keywords diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index 26e794ca..14d63dce 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -10,7 +10,7 @@ import pandas as pd -from data_designer.config.columns import ColumnConfigT, column_type_is_llm_generated +from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, diff --git a/src/data_designer/engine/dataset_builders/multi_column_configs.py b/src/data_designer/engine/dataset_builders/multi_column_configs.py index b7435706..fc2e74f5 100644 --- a/src/data_designer/engine/dataset_builders/multi_column_configs.py +++ b/src/data_designer/engine/dataset_builders/multi_column_configs.py @@ -7,13 +7,8 @@ from pydantic import Field, field_validator from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( - ColumnConfigT, - DataDesignerColumnType, - SamplerColumnConfig, - SeedDatasetColumnConfig, - SingleColumnConfig, -) +from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig, SingleColumnConfig +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.sampler_constraints import ColumnConstraintT from data_designer.config.seed import SeedConfig diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py index d80ec37a..5359bfa1 100644 --- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.processors import ProcessorConfig from data_designer.engine.dataset_builders.multi_column_configs import ( diff --git a/src/data_designer/engine/dataset_builders/utils/dag.py b/src/data_designer/engine/dataset_builders/utils/dag.py index e6d9da14..6c056d11 100644 --- a/src/data_designer/engine/dataset_builders/utils/dag.py +++ b/src/data_designer/engine/dataset_builders/utils/dag.py @@ -5,7 +5,7 @@ import networkx as nx -from data_designer.config.columns import ColumnConfigT, column_type_used_in_execution_dag +from data_designer.config.column_types import ColumnConfigT, column_type_used_in_execution_dag from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/registry/base.py b/src/data_designer/engine/registry/base.py index 5f780940..47e51b85 100644 --- a/src/data_designer/engine/registry/base.py +++ b/src/data_designer/engine/registry/base.py @@ -35,7 +35,7 @@ def register( name: EnumNameT, task: Type[TaskT], config: Type[TaskConfigT], - raise_on_collision: bool = True, + raise_on_collision: bool = False, ) -> None: if cls._has_been_registered(name): if not raise_on_collision: diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 8ed2f0ba..407029c3 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -9,10 +9,7 @@ ColumnGeneratorRegistry, create_default_column_generator_registry, ) -from data_designer.engine.processing.processors.registry import ( - ProcessorRegistry, - create_default_processor_registry, -) +from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry class DataDesignerRegistry: diff --git a/src/data_designer/engine/sampling_gen/column.py b/src/data_designer/engine/sampling_gen/column.py index 12c910c7..26ada316 100644 --- a/src/data_designer/engine/sampling_gen/column.py +++ b/src/data_designer/engine/sampling_gen/column.py @@ -6,7 +6,7 @@ from pydantic import field_serializer, model_validator from typing_extensions import Self -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import SamplerParamsT, SamplerType from data_designer.engine.sampling_gen.data_sources.base import DataSource from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry diff --git a/src/data_designer/engine/sampling_gen/schema_builder.py b/src/data_designer/engine/sampling_gen/schema_builder.py index c5ee010f..3c3e3eaf 100644 --- a/src/data_designer/engine/sampling_gen/schema_builder.py +++ b/src/data_designer/engine/sampling_gen/schema_builder.py @@ -3,7 +3,7 @@ from copy import deepcopy -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_constraints import ColumnConstraintT from data_designer.config.sampler_params import SamplerParamsT from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 2d6516c6..70ec184d 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from ..logging import LoggingConfig, configure_logging + +configure_logging(LoggingConfig.default()) from ..config.analysis.column_profilers import JudgeScoreProfilerConfig -from ..config.columns import ( - DataDesignerColumnType, +from ..config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -14,6 +16,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ) +from ..config.column_types import DataDesignerColumnType from ..config.config_builder import DataDesignerConfigBuilder from ..config.data_designer_config import DataDesignerConfig from ..config.dataset_builders import BuildStage @@ -58,7 +61,6 @@ RemoteValidatorParams, ValidatorType, ) -from ..logging import LoggingConfig, configure_logging local_library_imports = [] try: @@ -131,6 +133,3 @@ ] __all__.extend(local_library_imports) - - -configure_logging(LoggingConfig.default()) diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py new file mode 100644 index 00000000..923138ea --- /dev/null +++ b/src/data_designer/plugin_manager.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Type, TypeAlias + +from .config.utils.misc import can_run_data_designer_locally + +if TYPE_CHECKING: + from data_designer.plugins.plugin import Plugin + + +if can_run_data_designer_locally(): + from data_designer.plugins.plugin import PluginType + from data_designer.plugins.registry import PluginRegistry + + +class PluginManager: + def __init__(self): + if can_run_data_designer_locally(): + self._plugins_supported = True + self._plugin_registry = PluginRegistry() + else: + self._plugins_supported = False + self._plugin_registry = None + + def get_column_generator_plugins(self) -> list[Plugin]: + """Get all column generator plugins. + + Returns: + A list of all column generator plugins. + """ + return self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR) if self._plugins_supported else [] + + def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | None: + """Get a column generator plugin by name if it exists. + + Args: + plugin_name: The name of the plugin to retrieve. + + Returns: + The plugin if found, otherwise None. + """ + if self._plugins_supported and self._plugin_registry.plugin_exists(plugin_name): + return self._plugin_registry.get_plugin(plugin_name) + return None + + def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: + """Get a list of plugin column types. + + Args: + enum_type: The enum type to use for plugin entries. + required_resources: If provided, only return plugins with the required resources. + + Returns: + A list of plugin column types. + """ + type_list = [] + if self._plugins_supported: + for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR): + if required_resources: + task_required_resources = plugin.task_cls.metadata().required_resources or [] + if not all(resource in task_required_resources for resource in required_resources): + continue + type_list.append(enum_type(plugin.name)) + return type_list + + def inject_into_column_config_type_union(self, column_config_type: Type[TypeAlias]) -> Type[TypeAlias]: + """Inject plugins into the column config type. + + Args: + column_config_type: The column config type to inject plugins into. + + Returns: + The column config type with plugins injected. + """ + if self._plugins_supported: + column_config_type = self._plugin_registry.add_plugin_types_to_union( + column_config_type, PluginType.COLUMN_GENERATOR + ) + return column_config_type diff --git a/src/data_designer/plugins/__init__.py b/src/data_designer/plugins/__init__.py new file mode 100644 index 00000000..b7acb81e --- /dev/null +++ b/src/data_designer/plugins/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.plugin import Plugin, PluginType + +__all__ = ["Plugin", "PluginType"] diff --git a/src/data_designer/plugins/errors.py b/src/data_designer/plugins/errors.py new file mode 100644 index 00000000..de6e4435 --- /dev/null +++ b/src/data_designer/plugins/errors.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.errors import DataDesignerError + + +class PluginRegistrationError(DataDesignerError): ... + + +class PluginNotFoundError(DataDesignerError): ... diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py new file mode 100644 index 00000000..886a2252 --- /dev/null +++ b/src/data_designer/plugins/plugin.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum +from typing import Literal, Type, get_origin + +from pydantic import BaseModel, model_validator +from typing_extensions import Self + +from data_designer.config.base import ConfigBase +from data_designer.engine.configurable_task import ConfigurableTask + + +class PluginType(str, Enum): + COLUMN_GENERATOR = "column-generator" + + @property + def discriminator_field(self) -> str: + if self == PluginType.COLUMN_GENERATOR: + return "column_type" + else: + raise ValueError(f"Invalid plugin type: {self.value}") + + @property + def display_name(self) -> str: + return self.value.replace("-", " ") + + +class Plugin(BaseModel): + task_cls: Type[ConfigurableTask] + config_cls: Type[ConfigBase] + plugin_type: PluginType + emoji: str = "🔌" + + @property + def config_type_as_class_name(self) -> str: + return self.enum_key_name.title().replace("_", "") + + @property + def enum_key_name(self) -> str: + return self.name.replace("-", "_").upper() + + @property + def name(self) -> str: + return self.config_cls.model_fields[self.discriminator_field].default + + @property + def discriminator_field(self) -> str: + return self.plugin_type.discriminator_field + + @model_validator(mode="after") + def validate_discriminator_field(self) -> Self: + cfg = self.config_cls.__name__ + field = self.plugin_type.discriminator_field + if field not in self.config_cls.model_fields: + raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}") + field_info = self.config_cls.model_fields[field] + if get_origin(field_info.annotation) is not Literal: + raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.") + if not isinstance(field_info.default, str): + raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.") + enum_key = field_info.default.replace("-", "_").upper() + if not enum_key.isidentifier(): + raise ValueError( + f"The default value '{field_info.default}' for discriminator field '{field}' " + f"cannot be converted to a valid enum key. The converted key '{enum_key}' " + f"must be a valid Python identifier." + ) + return self diff --git a/src/data_designer/plugins/registry.py b/src/data_designer/plugins/registry.py new file mode 100644 index 00000000..6ef465e0 --- /dev/null +++ b/src/data_designer/plugins/registry.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from importlib.metadata import entry_points +import logging +import os +import threading +from typing import Type, TypeAlias + +from typing_extensions import Self + +from data_designer.plugins.errors import PluginNotFoundError +from data_designer.plugins.plugin import Plugin, PluginType + +logger = logging.getLogger(__name__) + + +PLUGINS_DISABLED = os.getenv("DISABLE_DATA_DESIGNER_PLUGINS", "false").lower() == "true" + + +class PluginRegistry: + _instance = None + _plugins_discovered = False + _lock = threading.Lock() + + _plugins: dict[str, Plugin] = {} + + def __init__(self): + with self._lock: + if not self._plugins_discovered: + self._discover() + + @classmethod + def reset(cls) -> None: + with cls._lock: + cls._instance = None + cls._plugins_discovered = False + cls._plugins = {} + + def add_plugin_types_to_union(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: + for plugin in self.get_plugins(plugin_type): + if plugin.config_cls not in type_union.__args__: + type_union |= plugin.config_cls + return type_union + + def get_plugin(self, plugin_name: str) -> Plugin: + if plugin_name not in self._plugins: + raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.") + return self._plugins[plugin_name] + + def get_plugins(self, plugin_type: PluginType) -> list[Plugin]: + return [plugin for plugin in self._plugins.values() if plugin.plugin_type == plugin_type] + + def get_plugin_names(self, plugin_type: PluginType) -> list[str]: + return [plugin.name for plugin in self.get_plugins(plugin_type)] + + def num_plugins(self, plugin_type: PluginType) -> int: + return len(self.get_plugins(plugin_type)) + + def plugin_exists(self, plugin_name: str) -> bool: + return plugin_name in self._plugins + + def _discover(self) -> Self: + if PLUGINS_DISABLED: + return self + for ep in entry_points(group="data_designer.plugins"): + try: + plugin = ep.load() + if isinstance(plugin, Plugin): + logger.info( + f"🔌 Plugin discovered ➜ {plugin.plugin_type.display_name} " + f"{plugin.enum_key_name} is now available ⚡️" + ) + self._plugins[plugin.name] = plugin + except Exception as e: + logger.warning(f"🛑 Failed to load plugin from entry point {ep.name!r}: {e}") + self._plugins_discovered = True + return self + + def __new__(cls, *args, **kwargs): + """Plugin manager is a singleton.""" + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance diff --git a/tests/config/analysis/utils/test_reporting.py b/tests/config/analysis/utils/test_reporting.py index ec9a9912..82b5234c 100644 --- a/tests/config/analysis/utils/test_reporting.py +++ b/tests/config/analysis/utils/test_reporting.py @@ -11,7 +11,7 @@ from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults from data_designer.config.analysis.utils.errors import AnalysisReportError from data_designer.config.analysis.utils.reporting import ReportSection, generate_analysis_report -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType @pytest.fixture diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index df2d0668..2accdaa3 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -4,8 +4,7 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -15,6 +14,9 @@ Score, SeedDatasetColumnConfig, ValidationColumnConfig, +) +from data_designer.config.column_types import ( + DataDesignerColumnType, column_type_is_llm_generated, column_type_used_in_execution_dag, get_column_config_from_kwargs, diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 76ce7bc7..fef6cfb9 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -11,8 +11,7 @@ import yaml from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -21,8 +20,8 @@ SamplerColumnConfig, Score, ValidationColumnConfig, - get_column_config_from_kwargs, ) +from data_designer.config.column_types import DataDesignerColumnType, get_column_config_from_kwargs from data_designer.config.config_builder import BuilderConfig, DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings diff --git a/tests/config/utils/test_validation.py b/tests/config/utils/test_validation.py index 451afbbe..0b5c5684 100644 --- a/tests/config/utils/test_validation.py +++ b/tests/config/utils/test_validation.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, diff --git a/tests/conftest.py b/tests/conftest.py index f0e6546e..31dc0057 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from data_designer.config.analysis.column_statistics import GeneralColumnStatistics from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings diff --git a/tests/engine/analysis/column_profilers/test_base.py b/tests/engine/analysis/column_profilers/test_base.py index 966a78ab..be74f2b5 100644 --- a/tests/engine/analysis/column_profilers/test_base.py +++ b/tests/engine/analysis/column_profilers/test_base.py @@ -5,7 +5,9 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import DataDesignerColumnType, SamplerColumnConfig, SamplerType +from data_designer.config.column_configs import SamplerColumnConfig +from data_designer.config.column_types import DataDesignerColumnType +from data_designer.config.sampler_params import SamplerType from data_designer.engine.analysis.column_profilers.base import ( ColumnConfigWithDataFrame, ColumnProfilerMetadata, diff --git a/tests/engine/analysis/column_profilers/test_judge_score_profiler.py b/tests/engine/analysis/column_profilers/test_judge_score_profiler.py index a380b5c4..343f44a6 100644 --- a/tests/engine/analysis/column_profilers/test_judge_score_profiler.py +++ b/tests/engine/analysis/column_profilers/test_judge_score_profiler.py @@ -14,7 +14,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.column_profilers.judge_score_profiler import ( JudgeScoreProfiler, diff --git a/tests/engine/analysis/conftest.py b/tests/engine/analysis/conftest.py index 7d571d5f..ff64af3f 100644 --- a/tests/engine/analysis/conftest.py +++ b/tests/engine/analysis/conftest.py @@ -14,7 +14,8 @@ ColumnDistributionType, NumericalDistribution, ) -from data_designer.config.columns import ColumnConfigT, LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score +from data_designer.config.column_types import ColumnConfigT from data_designer.config.models import ModelConfig from data_designer.engine.analysis.dataset_profiler import ( DataDesignerDatasetProfiler, diff --git a/tests/engine/analysis/test_column_statistics_calculator.py b/tests/engine/analysis/test_column_statistics_calculator.py index b72c1b93..a9030d9e 100644 --- a/tests/engine/analysis/test_column_statistics_calculator.py +++ b/tests/engine/analysis/test_column_statistics_calculator.py @@ -4,7 +4,7 @@ import pandas as pd from data_designer.config.analysis.column_statistics import ColumnDistributionType -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_params import SamplerType from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator diff --git a/tests/engine/analysis/test_dataset_profiler.py b/tests/engine/analysis/test_dataset_profiler.py index 154f6360..0bd1e6cf 100644 --- a/tests/engine/analysis/test_dataset_profiler.py +++ b/tests/engine/analysis/test_dataset_profiler.py @@ -6,7 +6,7 @@ import pytest from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import CategorySamplerParams, SamplerType from data_designer.engine.analysis.column_profilers.judge_score_profiler import JudgeScoreProfilerConfig from data_designer.engine.analysis.dataset_profiler import ( diff --git a/tests/engine/analysis/utils/test_column_statistics_calculations.py b/tests/engine/analysis/utils/test_column_statistics_calculations.py index 94b20503..2f7316aa 100644 --- a/tests/engine/analysis/utils/test_column_statistics_calculations.py +++ b/tests/engine/analysis/utils/test_column_statistics_calculations.py @@ -16,7 +16,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMTextColumnConfig +from data_designer.config.column_configs import LLMTextColumnConfig from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting from data_designer.engine.analysis.utils.column_statistics_calculations import ( calculate_column_distribution, diff --git a/tests/engine/analysis/utils/test_judge_score_processing.py b/tests/engine/analysis/utils/test_judge_score_processing.py index 85663c75..e20e65f5 100644 --- a/tests/engine/analysis/utils/test_judge_score_processing.py +++ b/tests/engine/analysis/utils/test_judge_score_processing.py @@ -11,7 +11,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score from data_designer.engine.analysis.utils.judge_score_processing import ( JudgeScoreDistributions, JudgeScoreSample, diff --git a/tests/engine/column_generators/generators/test_column_generator_base.py b/tests/engine/column_generators/generators/test_column_generator_base.py index 433214f5..511b1c0f 100644 --- a/tests/engine/column_generators/generators/test_column_generator_base.py +++ b/tests/engine/column_generators/generators/test_column_generator_base.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, FromScratchColumnGenerator, diff --git a/tests/engine/column_generators/generators/test_expression.py b/tests/engine/column_generators/generators/test_expression.py index e2db903d..ec42c99d 100644 --- a/tests/engine/column_generators/generators/test_expression.py +++ b/tests/engine/column_generators/generators/test_expression.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError from data_designer.engine.resources.resource_provider import ResourceProvider diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_generators.py index 8854aa7d..259f3a08 100644 --- a/tests/engine/column_generators/generators/test_llm_generators.py +++ b/tests/engine/column_generators/generators/test_llm_generators.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, diff --git a/tests/engine/column_generators/generators/test_samplers.py b/tests/engine/column_generators/generators/test_samplers.py index 0b30b969..23e68b8e 100644 --- a/tests/engine/column_generators/generators/test_samplers.py +++ b/tests/engine/column_generators/generators/test_samplers.py @@ -3,7 +3,7 @@ import pytest -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import ( BernoulliMixtureSamplerParams, BernoulliSamplerParams, diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index ebb9c72a..37e69a26 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -9,7 +9,7 @@ import pandas as pd import pytest -from data_designer.config.columns import SeedDatasetColumnConfig +from data_designer.config.column_configs import SeedDatasetColumnConfig from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.seed_dataset import ( diff --git a/tests/engine/column_generators/generators/test_validation.py b/tests/engine/column_generators/generators/test_validation.py index e84d0335..65f9bdd4 100644 --- a/tests/engine/column_generators/generators/test_validation.py +++ b/tests/engine/column_generators/generators/test_validation.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import ValidationColumnConfig +from data_designer.config.column_configs import ValidationColumnConfig from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import ( CodeValidatorParams, diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index e658283e..f70b0d90 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.generators.llm_generators import ( LLMCodeCellGenerator, diff --git a/tests/engine/column_generators/utils/test_judge_score_factory.py b/tests/engine/column_generators/utils/test_judge_score_factory.py index d692247a..afb830f3 100644 --- a/tests/engine/column_generators/utils/test_judge_score_factory.py +++ b/tests/engine/column_generators/utils/test_judge_score_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import pytest -from data_designer.config.columns import Score +from data_designer.config.column_configs import Score from data_designer.engine.column_generators.utils.judge_score_factory import ( SCORE_FIELD_DESCRIPTION_FORMAT, SCORING_FORMAT, diff --git a/tests/engine/column_generators/utils/test_prompt_renderer.py b/tests/engine/column_generators/utils/test_prompt_renderer.py index 09c70253..cd3697b9 100644 --- a/tests/engine/column_generators/utils/test_prompt_renderer.py +++ b/tests/engine/column_generators/utils/test_prompt_renderer.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, diff --git a/tests/engine/dataset_builders/test_column_wise_builder.py b/tests/engine/dataset_builders/test_column_wise_builder.py index 487efca7..5572af44 100644 --- a/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/tests/engine/dataset_builders/test_column_wise_builder.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.dataset_builders.column_wise_builder import ( diff --git a/tests/engine/dataset_builders/test_multi_column_configs.py b/tests/engine/dataset_builders/test_multi_column_configs.py index e030c1da..2754d5c5 100644 --- a/tests/engine/dataset_builders/test_multi_column_configs.py +++ b/tests/engine/dataset_builders/test_multi_column_configs.py @@ -4,11 +4,8 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, - SamplerColumnConfig, - SeedDatasetColumnConfig, -) +from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_constraints import ( ColumnInequalityConstraint, InequalityOperator, diff --git a/tests/engine/dataset_builders/utils/test_config_compiler.py b/tests/engine/dataset_builders/utils/test_config_compiler.py index b17dab5a..00809e46 100644 --- a/tests/engine/dataset_builders/utils/test_config_compiler.py +++ b/tests/engine/dataset_builders/utils/test_config_compiler.py @@ -3,12 +3,8 @@ import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, - LLMTextColumnConfig, - SamplerColumnConfig, - SeedDatasetColumnConfig, -) +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig, SeedDatasetColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.seed import SamplingStrategy, SeedConfig from data_designer.engine.dataset_builders.utils.config_compiler import ( diff --git a/tests/engine/dataset_builders/utils/test_dag.py b/tests/engine/dataset_builders/utils/test_dag.py index dec06be8..b914ff85 100644 --- a/tests/engine/dataset_builders/utils/test_dag.py +++ b/tests/engine/dataset_builders/utils/test_dag.py @@ -3,8 +3,7 @@ import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -13,6 +12,7 @@ Score, ValidationColumnConfig, ) +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_params import SamplerType from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams diff --git a/tests/engine/registry/test_base.py b/tests/engine/registry/test_base.py index 73d6a981..3e653de1 100644 --- a/tests/engine/registry/test_base.py +++ b/tests/engine/registry/test_base.py @@ -94,12 +94,13 @@ def test_register_task_scenarios( TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) with pytest.raises(expected_error, match="task_a has already been registered!"): - TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) + TaskRegistry.register( + stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class, raise_on_collision=True + ) elif test_case == "register_task_collision_no_raise": TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) - TaskRegistry.register( - stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class, raise_on_collision=False - ) + # Default behavior is raise_on_collision=False, so no need to pass it explicitly + TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) @pytest.mark.parametrize( diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py new file mode 100644 index 00000000..c505f8bb --- /dev/null +++ b/tests/plugins/test_plugin.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from pydantic import ValidationError +import pytest + +from data_designer.config.base import ConfigBase +from data_designer.config.column_configs import SamplerColumnConfig, SingleColumnConfig +from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.plugin import Plugin, PluginType + + +class ValidTestConfig(SingleColumnConfig): + """Valid config for testing plugin creation.""" + + column_type: Literal["test-generator"] = "test-generator" + name: str + + +class ValidTestTask(ConfigurableTask[ValidTestConfig]): + """Valid task for testing plugin creation.""" + + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_generator", + description="Test generator", + required_resources=None, + ) + + +@pytest.fixture +def valid_plugin() -> Plugin: + """Fixture providing a valid plugin instance for testing.""" + return Plugin( + task_cls=ValidTestTask, + config_cls=ValidTestConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +# ============================================================================= +# PluginType Tests +# ============================================================================= + + +def test_plugin_type_discriminator_field_for_column_generator() -> None: + """Test that COLUMN_GENERATOR type returns 'column_type' as discriminator field.""" + assert PluginType.COLUMN_GENERATOR.discriminator_field == "column_type" + + +def test_plugin_type_all_types_have_discriminator_fields() -> None: + """Test that all plugin types have valid discriminator fields.""" + for plugin_type in PluginType: + assert isinstance(plugin_type.discriminator_field, str) + assert len(plugin_type.discriminator_field) > 0 + + +# ============================================================================= +# Plugin Creation and Properties Tests +# ============================================================================= + + +def test_create_plugin_with_valid_inputs(valid_plugin: Plugin) -> None: + """Test that Plugin can be created with valid task, config, and plugin type.""" + assert valid_plugin.task_cls == ValidTestTask + assert valid_plugin.config_cls == ValidTestConfig + assert valid_plugin.plugin_type == PluginType.COLUMN_GENERATOR + + +def test_plugin_name_derived_from_config_default(valid_plugin: Plugin) -> None: + """Test that plugin.name returns the discriminator field's default value.""" + assert valid_plugin.name == "test-generator" + + +def test_plugin_discriminator_field_from_type(valid_plugin: Plugin) -> None: + """Test that plugin.discriminator_field returns the correct field name.""" + assert valid_plugin.discriminator_field == "column_type" + + +def test_plugin_requires_all_fields() -> None: + """Test that Plugin creation fails without required fields.""" + with pytest.raises(ValidationError): + Plugin() # type: ignore + + with pytest.raises(ValidationError): + Plugin(task_cls=ValidTestTask) # type: ignore + + with pytest.raises(ValidationError): + Plugin(config_cls=ValidTestConfig) # type: ignore + + +# ============================================================================= +# Plugin Validation Tests +# ============================================================================= + + +def test_validation_fails_when_config_missing_discriminator_field() -> None: + """Test validation fails when config lacks the required discriminator field.""" + + class ConfigWithoutDiscriminator(ConfigBase): + some_field: str + + with pytest.raises(ValueError, match="Discriminator field 'column_type' not found in config class"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithoutDiscriminator, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_when_discriminator_field_not_literal_type() -> None: + """Test validation fails when discriminator field is not a Literal type.""" + + class ConfigWithStringField(ConfigBase): + column_type: str = "test-generator" + + with pytest.raises(ValueError, match="Field 'column_type' of .* must be a Literal type"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithStringField, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_when_discriminator_default_not_string() -> None: + """Test validation fails when discriminator field default is not a string.""" + + class ConfigWithNonStringDefault(ConfigBase): + column_type: Literal["test-generator"] = 123 # type: ignore + + with pytest.raises(ValueError, match="The default of 'column_type' must be a string"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithNonStringDefault, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_with_invalid_enum_key_conversion() -> None: + """Test validation fails when default value cannot be converted to valid Python identifier.""" + + class ConfigWithInvalidKey(ConfigBase): + column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#" + + with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithInvalidKey, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +def test_plugin_works_with_real_sampler_column_generator() -> None: + """Test that Plugin works with actual SamplerColumnGenerator from the codebase.""" + plugin = Plugin( + task_cls=SamplerColumnGenerator, + config_cls=SamplerColumnConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + assert plugin.name == "sampler" + assert plugin.discriminator_field == "column_type" + assert plugin.task_cls == SamplerColumnGenerator + assert plugin.config_cls == SamplerColumnConfig + + +def test_plugin_preserves_type_information(valid_plugin: Plugin) -> None: + """Test that Plugin correctly stores and provides access to type information.""" + assert isinstance(valid_plugin.task_cls, type) + assert isinstance(valid_plugin.config_cls, type) + assert issubclass(valid_plugin.task_cls, ConfigurableTask) + assert issubclass(valid_plugin.config_cls, ConfigBase) diff --git a/tests/plugins/test_plugin_registry.py b/tests/plugins/test_plugin_registry.py new file mode 100644 index 00000000..b3956f44 --- /dev/null +++ b/tests/plugins/test_plugin_registry.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from importlib.metadata import EntryPoint +import threading +from typing import Literal +from unittest.mock import MagicMock, patch + +import pytest + +from data_designer.config.base import ConfigBase +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.errors import PluginNotFoundError +from data_designer.plugins.plugin import Plugin, PluginType +from data_designer.plugins.registry import PluginRegistry + +# ============================================================================= +# Test Stubs +# ============================================================================= + + +class StubPluginConfigA(SingleColumnConfig): + column_type: Literal["test-plugin-a"] = "test-plugin-a" + + +class StubPluginConfigB(SingleColumnConfig): + column_type: Literal["test-plugin-b"] = "test-plugin-b" + + +class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_plugin_a", + description="Test plugin A", + required_resources=None, + ) + + +class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_plugin_b", + description="Test plugin B", + required_resources=None, + ) + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +def plugin_a() -> Plugin: + return Plugin( + task_cls=StubPluginTaskA, + config_cls=StubPluginConfigA, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +@pytest.fixture +def plugin_b() -> Plugin: + return Plugin( + task_cls=StubPluginTaskB, + config_cls=StubPluginConfigB, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +@pytest.fixture(autouse=True) +def clean_plugin_registry() -> None: + """Reset PluginRegistry singleton state before and after each test.""" + PluginRegistry.reset() + + yield + + PluginRegistry.reset() + + +@pytest.fixture +def mock_plugin_discovery(): + """Mock plugin discovery to test with specific entry points.""" + + @contextmanager + def _mock_discovery(entry_points_list): + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.registry.entry_points", return_value=entry_points_list): + yield + + return _mock_discovery + + +@pytest.fixture +def mock_entry_points(plugin_a: Plugin, plugin_b: Plugin) -> list[MagicMock]: + """Create mock entry points for plugin_a and plugin_b.""" + mock_ep_a = MagicMock(spec=EntryPoint) + mock_ep_a.name = "test-plugin-a" + mock_ep_a.load.return_value = plugin_a + + mock_ep_b = MagicMock(spec=EntryPoint) + mock_ep_b.name = "test-plugin-b" + mock_ep_b.load.return_value = plugin_b + + return [mock_ep_a, mock_ep_b] + + +# ============================================================================= +# PluginRegistry Singleton Tests +# ============================================================================= + + +def test_plugin_registry_is_singleton(mock_plugin_discovery) -> None: + """Test PluginRegistry returns same instance.""" + with mock_plugin_discovery([]): + manager1 = PluginRegistry() + manager2 = PluginRegistry() + + assert manager1 is manager2 + + +def test_plugin_registry_singleton_thread_safety(mock_plugin_discovery) -> None: + """Test PluginRegistry singleton creation is thread-safe.""" + instances: list[PluginRegistry] = [] + + with mock_plugin_discovery([]): + + def create_manager() -> None: + instances.append(PluginRegistry()) + + threads = [threading.Thread(target=create_manager) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert all(instance is instances[0] for instance in instances) + + +# ============================================================================= +# PluginRegistry Discovery Tests +# ============================================================================= + + +def test_plugin_registry_discovers_plugins( + mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin +) -> None: + """Test PluginRegistry discovers and loads plugins from entry points.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginRegistry() + + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 + assert manager.get_plugin("test-plugin-a") == plugin_a + assert manager.get_plugin("test-plugin-b") == plugin_b + + +def test_plugin_registry_skips_invalid_plugins(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginRegistry skips non-Plugin objects during discovery.""" + mock_ep_valid = MagicMock(spec=EntryPoint) + mock_ep_valid.name = "test-plugin-a" + mock_ep_valid.load.return_value = plugin_a + + mock_ep_invalid = MagicMock(spec=EntryPoint) + mock_ep_invalid.name = "invalid-plugin" + mock_ep_invalid.load.return_value = "not a plugin" + + with mock_plugin_discovery([mock_ep_valid, mock_ep_invalid]): + manager = PluginRegistry() + + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert manager.get_plugin("test-plugin-a") == plugin_a + + +def test_plugin_registry_handles_loading_errors(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginRegistry gracefully handles plugin loading errors.""" + mock_ep_valid = MagicMock(spec=EntryPoint) + mock_ep_valid.name = "test-plugin-a" + mock_ep_valid.load.return_value = plugin_a + + mock_ep_error = MagicMock(spec=EntryPoint) + mock_ep_error.name = "error-plugin" + mock_ep_error.load.side_effect = Exception("Loading failed") + + with mock_plugin_discovery([mock_ep_valid, mock_ep_error]): + manager = PluginRegistry() + + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert manager.get_plugin("test-plugin-a") == plugin_a + + +def test_plugin_registry_discovery_runs_once() -> None: + """Test discovery runs once even with multiple PluginRegistry instances.""" + mock_entry_points = MagicMock(return_value=[]) + + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.registry.entry_points", mock_entry_points): + PluginRegistry() + PluginRegistry() + PluginRegistry() + + assert mock_entry_points.call_count == 1 + + +def test_plugin_registry_respects_disabled_flag() -> None: + """Test PluginRegistry respects DISABLE_DATA_DESIGNER_PLUGINS flag.""" + mock_entry_points = MagicMock(return_value=[]) + + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", True): + with patch("data_designer.plugins.registry.entry_points", mock_entry_points): + manager = PluginRegistry() + + assert mock_entry_points.call_count == 0 + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + +# ============================================================================= +# PluginRegistry Query Methods Tests +# ============================================================================= + + +def test_plugin_registry_get_plugin_raises_error(mock_plugin_discovery) -> None: + """Test get_plugin() raises error for nonexistent plugin.""" + with mock_plugin_discovery([]): + manager = PluginRegistry() + + with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): + manager.get_plugin("nonexistent") + + +def test_plugin_registry_get_plugins_by_type( + mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin +) -> None: + """Test get_plugins() filters by plugin type.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginRegistry() + plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) + + assert len(plugins) == 2 + assert plugin_a in plugins + assert plugin_b in plugins + + +def test_plugin_registry_get_plugins_empty(mock_plugin_discovery) -> None: + """Test get_plugins() returns empty list when no plugins match.""" + with mock_plugin_discovery([]): + manager = PluginRegistry() + plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) + + assert plugins == [] + + +def test_plugin_registry_get_plugin_names(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: + """Test get_plugin_names() returns plugin names by type.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginRegistry() + names = manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + + assert set(names) == {"test-plugin-a", "test-plugin-b"} + + +# ============================================================================= +# PluginRegistry Type Union Tests +# ============================================================================= + + +def test_plugin_registry_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: + """Test update_type_union() adds plugin config types to union.""" + from typing import Union + + from typing_extensions import TypeAlias + + class DummyConfig(ConfigBase): + pass + + with mock_plugin_discovery(mock_entry_points): + manager = PluginRegistry() + + # Create a Union with at least 2 types so it has __args__ + type_union: TypeAlias = Union[ConfigBase, DummyConfig] + updated_union = manager.add_plugin_types_to_union(type_union, PluginType.COLUMN_GENERATOR) + + assert StubPluginConfigA in updated_union.__args__ + assert StubPluginConfigB in updated_union.__args__ + assert ConfigBase in updated_union.__args__ + assert DummyConfig in updated_union.__args__ diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py new file mode 100644 index 00000000..c00d78f9 --- /dev/null +++ b/tests/test_plugin_manager.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Generator +from contextlib import contextmanager +from enum import Enum +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from data_designer.plugin_manager import PluginManager + + +class MockPluginType(str, Enum): + """Mock PluginType enum for testing.""" + + COLUMN_GENERATOR = "column-generator" + + @property + def discriminator_field(self) -> str: + return "column_type" + + +def create_mock_plugin(name: str, plugin_type: MockPluginType, resources: list[str] | None = None) -> Mock: + """Create a mock plugin with specified name and resources. + + Args: + name: Plugin name. + plugin_type: Plugin type enum. + resources: List of required resources, or None if no resource requirements. + + Returns: + Mock plugin object. + """ + plugin = Mock() + plugin.name = name + plugin.plugin_type = plugin_type + plugin.config_cls = Mock(name=name) + + mock_task = Mock() + mock_task.metadata = Mock(return_value=Mock(required_resources=resources)) + plugin.task_cls = mock_task + + return plugin + + +@contextmanager +def mock_plugin_system(registry: MagicMock) -> Generator[None, None, None]: + """Context manager to mock the plugin system with a given registry. + + This works regardless of whether the actual environment has plugins available or not + by patching at the module level where PluginManager is instantiated. + """ + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=True): + with patch("data_designer.plugin_manager.PluginRegistry", return_value=registry, create=True): + with patch("data_designer.plugin_manager.PluginType", MockPluginType, create=True): + yield + + +@pytest.fixture +def mock_plugin_registry() -> MagicMock: + """Create a mock plugin registry.""" + return MagicMock() + + +@pytest.fixture +def mock_plugins() -> list[Mock]: + """Create mock plugins for testing.""" + return [ + create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), + create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), + create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), + ] + + +def test_get_column_generator_plugins_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column configs when plugins are available.""" + mock_plugin_registry.get_plugins.return_value = [mock_plugins[0], mock_plugins[1]] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_column_generator_plugins() + + assert len(result) == 2 + assert [p.name for p in result] == ["plugin-one", "plugin-two"] + mock_plugin_registry.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +@pytest.mark.parametrize("plugins_available", [True, False]) +def test_get_column_generator_plugins_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: + """Test getting plugin column configs when no plugins are registered or system is disabled.""" + if plugins_available: + mock_plugin_registry.get_plugins.return_value = [] + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_column_generator_plugins() + else: + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_column_generator_plugins() + + assert result == [] + + +def test_get_column_generator_plugin_if_exists_found(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting a specific plugin by name when it exists.""" + mock_plugin_registry.plugin_exists.return_value = True + mock_plugin_registry.get_plugin.return_value = mock_plugins[0] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_column_generator_plugin_if_exists("plugin-one") + + assert result is not None + assert result.name == "plugin-one" + mock_plugin_registry.plugin_exists.assert_called_once_with("plugin-one") + mock_plugin_registry.get_plugin.assert_called_once_with("plugin-one") + + +def test_get_column_generator_plugin_if_exists_not_found(mock_plugin_registry: MagicMock) -> None: + """Test getting a specific plugin by name when it doesn't exist.""" + mock_plugin_registry.plugin_exists.return_value = False + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_column_generator_plugin_if_exists("plugin-three") + + assert result is None + mock_plugin_registry.plugin_exists.assert_called_once_with("plugin-three") + mock_plugin_registry.get_plugin.assert_not_called() + + +def test_get_column_generator_plugin_if_exists_when_disabled() -> None: + """Test getting a specific plugin when plugin system is disabled.""" + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_column_generator_plugin_if_exists("plugin-one") + + assert result is None + + +def test_get_plugin_column_types_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column types when plugins are available.""" + TestEnum = Enum( + "TestEnum", {plugin.name.replace("-", "_").upper(): plugin.name for plugin in mock_plugins}, type=str + ) + mock_plugin_registry.get_plugins.return_value = mock_plugins + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + + assert len(result) == 3 + assert all(isinstance(item, TestEnum) for item in result) + mock_plugin_registry.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +def test_get_plugin_column_types_with_resource_filtering( + mock_plugin_registry: MagicMock, mock_plugins: list[Mock] +) -> None: + """Test filtering plugins by required resources.""" + TestEnum = Enum( + "TestEnum", {"PLUGIN_ONE": "plugin-one", "PLUGIN_TWO": "plugin-two", "PLUGIN_THREE": "plugin-three"}, type=str + ) + mock_plugin_registry.get_plugins.return_value = mock_plugins + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert len(result) == 2 + assert set(result) == {TestEnum.PLUGIN_ONE, TestEnum.PLUGIN_TWO} + + +def test_get_plugin_column_types_filters_none_resources(mock_plugin_registry: MagicMock) -> None: + """Test filtering when plugin has None for required_resources.""" + plugin = create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, None) + TestEnum = Enum("TestEnum", {"PLUGIN_ONE": "plugin-one"}, type=str) + mock_plugin_registry.get_plugins.return_value = [plugin] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert result == [] + + +@pytest.mark.parametrize("plugins_available", [True, False]) +def test_get_plugin_column_types_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: + """Test getting plugin column types when no plugins are registered or system is disabled.""" + TestEnum = Enum("TestEnum", {}, type=str) + + if plugins_available: + mock_plugin_registry.get_plugins.return_value = [] + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + else: + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + + assert result == [] + + +def test_inject_into_column_config_type_union_with_plugins(mock_plugin_registry: MagicMock) -> None: + """Test injecting plugins into column config type union.""" + + class BaseType: + pass + + mock_plugin_registry.add_plugin_types_to_union.return_value = str | int + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.inject_into_column_config_type_union(BaseType) + + assert result == str | int + mock_plugin_registry.add_plugin_types_to_union.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) + + +def test_inject_into_column_config_type_union_when_disabled() -> None: + """Test injecting plugins when plugin system is disabled.""" + + class BaseType: + pass + + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.inject_into_column_config_type_union(BaseType) + + assert result == BaseType diff --git a/uv.lock b/uv.lock index b9f73f9a..5243435f 100644 --- a/uv.lock +++ b/uv.lock @@ -766,6 +766,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-env" }, { name = "pytest-httpx" }, ] docs = [ @@ -823,6 +824,7 @@ dev = [ { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "pytest-env", specifier = ">=1.2.0" }, { name = "pytest-httpx", specifier = ">=0.35.0" }, ] docs = [ @@ -3398,6 +3400,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "pytest-env" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/12/9c87d0ca45d5992473208bcef2828169fa7d39b8d7fc6e3401f5c08b8bf7/pytest_env-1.2.0.tar.gz", hash = "sha256:475e2ebe8626cee01f491f304a74b12137742397d6c784ea4bc258f069232b80", size = 8973, upload-time = "2025-10-09T19:15:47.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/98/822b924a4a3eb58aacba84444c7439fce32680592f394de26af9c76e2569/pytest_env-1.2.0-py3-none-any.whl", hash = "sha256:d7e5b7198f9b83c795377c09feefa45d56083834e60d04767efd64819fc9da00", size = 6251, upload-time = "2025-10-09T19:15:46.077Z" }, +] + [[package]] name = "pytest-httpx" version = "0.35.0"