From 8d996045a96d243441a6efcbcdc2d1a476d58cec Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 7 Nov 2025 14:44:05 -0500 Subject: [PATCH 01/27] separate column configs and types --- .../config/analysis/column_statistics.py | 2 +- .../config/analysis/dataset_profiler.py | 2 +- .../config/analysis/utils/reporting.py | 2 +- src/data_designer/config/column_configs.py | 130 ++++++++++++++++ .../config/{columns.py => column_types.py} | 146 +++--------------- src/data_designer/config/config_builder.py | 4 +- .../config/data_designer_config.py | 2 +- src/data_designer/config/utils/validation.py | 2 +- .../config/utils/visualization.py | 2 +- .../engine/analysis/column_profilers/base.py | 3 +- .../column_profilers/judge_score_profiler.py | 2 +- .../engine/analysis/column_statistics.py | 2 +- .../engine/analysis/dataset_profiler.py | 4 +- .../utils/column_statistics_calculations.py | 2 +- .../analysis/utils/judge_score_processing.py | 2 +- .../generators/expression.py | 2 +- .../generators/llm_generators.py | 4 +- .../generators/validation.py | 2 +- .../engine/column_generators/registry.py | 6 +- .../utils/judge_score_factory.py | 2 +- .../utils/prompt_renderer.py | 3 +- .../dataset_builders/column_wise_builder.py | 2 +- .../dataset_builders/multi_column_configs.py | 9 +- .../dataset_builders/utils/config_compiler.py | 2 +- .../engine/dataset_builders/utils/dag.py | 2 +- .../engine/registry/data_designer_registry.py | 4 +- .../engine/sampling_gen/column.py | 2 +- .../engine/sampling_gen/schema_builder.py | 2 +- src/data_designer/essentials/__init__.py | 4 +- tests/config/analysis/utils/test_reporting.py | 2 +- tests/config/test_columns.py | 6 +- tests/config/test_config_builder.py | 5 +- tests/config/utils/test_validation.py | 2 +- tests/conftest.py | 2 +- .../analysis/column_profilers/test_base.py | 4 +- .../test_judge_score_profiler.py | 2 +- tests/engine/analysis/conftest.py | 3 +- .../test_column_statistics_calculator.py | 2 +- .../engine/analysis/test_dataset_profiler.py | 2 +- .../test_column_statistics_calculations.py | 2 +- .../utils/test_judge_score_processing.py | 2 +- .../generators/test_column_generator_base.py | 2 +- .../generators/test_expression.py | 2 +- .../generators/test_llm_generators.py | 2 +- .../generators/test_samplers.py | 2 +- .../generators/test_seed_dataset.py | 2 +- .../generators/test_validation.py | 2 +- .../engine/column_generators/test_registry.py | 6 +- .../utils/test_judge_score_factory.py | 2 +- .../utils/test_prompt_renderer.py | 2 +- .../test_column_wise_builder.py | 2 +- .../test_multi_column_configs.py | 7 +- .../utils/test_config_compiler.py | 8 +- .../engine/dataset_builders/utils/test_dag.py | 4 +- .../registry/test_data_designer_registry.py | 2 +- 55 files changed, 227 insertions(+), 205 deletions(-) create mode 100644 src/data_designer/config/column_configs.py rename src/data_designer/config/{columns.py => column_types.py} (55%) diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 991e41b9..be1d11a9 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from typing_extensions import Self, TypeAlias -from ..columns import DataDesignerColumnType +from ..column_types import DataDesignerColumnType from ..sampler_params import SamplerType from ..utils.constants import EPSILON from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index aa2b638f..058b58d7 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, field_validator -from ..columns import DataDesignerColumnType, get_column_display_order +from ..column_types import DataDesignerColumnType, get_column_display_order from ..utils.constants import EPSILON from ..utils.numerical_helpers import prepare_number_for_reporting from .column_profilers import ColumnProfilerResultsT diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index e4df4190..fb7d116e 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -15,7 +15,7 @@ from rich.text import Text from ...analysis.column_statistics import CategoricalHistogramData -from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order +from ...column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order from ...utils.visualization import ( ColorPalette, convert_to_row_element, diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py new file mode 100644 index 00000000..2840d13d --- /dev/null +++ b/src/data_designer/config/column_configs.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from typing import Literal, Optional, Type, Union + +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self + +from .base import ConfigBase +from .errors import InvalidConfigError +from .models import ImageContext +from .sampler_params import SamplerParamsT, SamplerType +from .utils.code_lang import CodeLang +from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX +from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords +from .validator_params import ValidatorParamsT, ValidatorType + + +class SingleColumnConfig(ConfigBase, ABC): + name: str + drop: bool = False + column_type: str + + @property + def required_columns(self) -> list[str]: + return [] + + @property + def side_effect_columns(self) -> list[str]: + return [] + + +class SamplerColumnConfig(SingleColumnConfig): + sampler_type: SamplerType + params: SamplerParamsT + conditional_params: dict[str, SamplerParamsT] = {} + convert_to: Optional[str] = None + column_type: Literal["sampler"] = "sampler" + + +class LLMTextColumnConfig(SingleColumnConfig): + prompt: str + model_alias: str + system_prompt: Optional[str] = None + multi_modal_context: Optional[list[ImageContext]] = None + column_type: Literal["llm-text"] = "llm-text" + + @property + def required_columns(self) -> list[str]: + required_cols = list(get_prompt_template_keywords(self.prompt)) + if self.system_prompt: + required_cols.extend(list(get_prompt_template_keywords(self.system_prompt))) + return list(set(required_cols)) + + @property + def side_effect_columns(self) -> list[str]: + return [f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"] + + @model_validator(mode="after") + def assert_prompt_valid_jinja(self) -> Self: + assert_valid_jinja2_template(self.prompt) + if self.system_prompt: + assert_valid_jinja2_template(self.system_prompt) + return self + + +class LLMCodeColumnConfig(LLMTextColumnConfig): + code_lang: CodeLang + column_type: Literal["llm-code"] = "llm-code" + + +class LLMStructuredColumnConfig(LLMTextColumnConfig): + output_format: Union[dict, Type[BaseModel]] + column_type: Literal["llm-structured"] = "llm-structured" + + @model_validator(mode="after") + def validate_output_format(self) -> Self: + if not isinstance(self.output_format, dict) and issubclass(self.output_format, BaseModel): + self.output_format = self.output_format.model_json_schema() + return self + + +class Score(ConfigBase): + name: str = Field(..., description="A clear name for this score.") + description: str = Field(..., description="An informative and detailed assessment guide for using this score.") + options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") + + +class LLMJudgeColumnConfig(LLMTextColumnConfig): + scores: list[Score] = Field(..., min_length=1) + column_type: Literal["llm-judge"] = "llm-judge" + + +class ExpressionColumnConfig(SingleColumnConfig): + name: str + expr: str + dtype: Literal["int", "float", "str", "bool"] = "str" + column_type: Literal["expression"] = "expression" + + @property + def required_columns(self) -> list[str]: + return list(get_prompt_template_keywords(self.expr)) + + @model_validator(mode="after") + def assert_expression_valid_jinja(self) -> Self: + if not self.expr.strip(): + raise InvalidConfigError( + f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " + f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " + "or remove this column if not needed." + ) + assert_valid_jinja2_template(self.expr) + return self + + +class ValidationColumnConfig(SingleColumnConfig): + target_columns: list[str] + validator_type: ValidatorType + validator_params: ValidatorParamsT + batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") + column_type: Literal["validation"] = "validation" + + @property + def required_columns(self) -> list[str]: + return self.target_columns + + +class SeedDatasetColumnConfig(SingleColumnConfig): + column_type: Literal["seed-dataset"] = "seed-dataset" diff --git a/src/data_designer/config/columns.py b/src/data_designer/config/column_types.py similarity index 55% rename from src/data_designer/config/columns.py rename to src/data_designer/config/column_types.py index 8886cb09..305c0304 100644 --- a/src/data_designer/config/columns.py +++ b/src/data_designer/config/column_types.py @@ -1,135 +1,27 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from abc import ABC -from typing import Literal, Optional, Type, Union +from typing import Union -from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self, TypeAlias +from typing_extensions import TypeAlias -from .base import ConfigBase +from .column_configs import ( + ExpressionColumnConfig, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + SeedDatasetColumnConfig, + ValidationColumnConfig, +) from .errors import InvalidColumnTypeError, InvalidConfigError -from .models import ImageContext -from .sampler_params import SamplerParamsT, SamplerType -from .utils.code_lang import CodeLang -from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX -from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords +from .sampler_params import SamplerType +from .utils.misc import can_run_data_designer_locally from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum -from .validator_params import ValidatorParamsT, ValidatorType - - -class SingleColumnConfig(ConfigBase, ABC): - name: str - drop: bool = False - column_type: str - - @property - def required_columns(self) -> list[str]: - return [] - - @property - def side_effect_columns(self) -> list[str]: - return [] - - -class SamplerColumnConfig(SingleColumnConfig): - sampler_type: SamplerType - params: SamplerParamsT - conditional_params: dict[str, SamplerParamsT] = {} - convert_to: Optional[str] = None - column_type: Literal["sampler"] = "sampler" - - -class LLMTextColumnConfig(SingleColumnConfig): - prompt: str - model_alias: str - system_prompt: Optional[str] = None - multi_modal_context: Optional[list[ImageContext]] = None - column_type: Literal["llm-text"] = "llm-text" - - @property - def required_columns(self) -> list[str]: - required_cols = list(get_prompt_template_keywords(self.prompt)) - if self.system_prompt: - required_cols.extend(list(get_prompt_template_keywords(self.system_prompt))) - return list(set(required_cols)) - - @property - def side_effect_columns(self) -> list[str]: - return [f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"] - - @model_validator(mode="after") - def assert_prompt_valid_jinja(self) -> Self: - assert_valid_jinja2_template(self.prompt) - if self.system_prompt: - assert_valid_jinja2_template(self.system_prompt) - return self - - -class LLMCodeColumnConfig(LLMTextColumnConfig): - code_lang: CodeLang - column_type: Literal["llm-code"] = "llm-code" - - -class LLMStructuredColumnConfig(LLMTextColumnConfig): - output_format: Union[dict, Type[BaseModel]] - column_type: Literal["llm-structured"] = "llm-structured" - - @model_validator(mode="after") - def validate_output_format(self) -> Self: - if not isinstance(self.output_format, dict) and issubclass(self.output_format, BaseModel): - self.output_format = self.output_format.model_json_schema() - return self - - -class Score(ConfigBase): - name: str = Field(..., description="A clear name for this score.") - description: str = Field(..., description="An informative and detailed assessment guide for using this score.") - options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") - - -class LLMJudgeColumnConfig(LLMTextColumnConfig): - scores: list[Score] = Field(..., min_length=1) - column_type: Literal["llm-judge"] = "llm-judge" - - -class ExpressionColumnConfig(SingleColumnConfig): - name: str - expr: str - dtype: Literal["int", "float", "str", "bool"] = "str" - column_type: Literal["expression"] = "expression" - - @property - def required_columns(self) -> list[str]: - return list(get_prompt_template_keywords(self.expr)) - - @model_validator(mode="after") - def assert_expression_valid_jinja(self) -> Self: - if not self.expr.strip(): - raise InvalidConfigError( - f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " - f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " - "or remove this column if not needed." - ) - assert_valid_jinja2_template(self.expr) - return self - - -class ValidationColumnConfig(SingleColumnConfig): - target_columns: list[str] - validator_type: ValidatorType - validator_params: ValidatorParamsT - batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") - column_type: Literal["validation"] = "validation" - - @property - def required_columns(self) -> list[str]: - return self.target_columns - - -class SeedDatasetColumnConfig(SingleColumnConfig): - column_type: Literal["seed-dataset"] = "seed-dataset" +if can_run_data_designer_locally(): + from data_designer.plugins.manager import PluginManager, PluginType ColumnConfigT: TypeAlias = Union[ ExpressionColumnConfig, @@ -143,6 +35,12 @@ class SeedDatasetColumnConfig(SingleColumnConfig): ] +if can_run_data_designer_locally(): + pm = PluginManager() + if pm.num_plugins(PluginType.COLUMN_GENERATOR) > 0: + ColumnConfigT = pm.update_type_union(ColumnConfigT, PluginType.COLUMN_GENERATOR) + + DataDesignerColumnType = create_str_enum_from_discriminated_type_union( enum_name="DataDesignerColumnType", type_union=ColumnConfigT, diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index eca394cf..db1b8691 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -15,10 +15,10 @@ from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase -from .columns import ( +from .column_configs import SeedDatasetColumnConfig +from .column_types import ( ColumnConfigT, DataDesignerColumnType, - SeedDatasetColumnConfig, column_type_is_llm_generated, get_column_config_from_kwargs, ) diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index 24d791b8..3575f1a7 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -9,7 +9,7 @@ from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase -from .columns import ColumnConfigT +from .column_types import ColumnConfigT from .models import ModelConfig from .processors import ProcessorConfig from .sampler_constraints import ColumnConstraintT diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index 428c575c..f1a0bef1 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -15,7 +15,7 @@ from rich.padding import Padding from rich.panel import Panel -from ..columns import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated +from ..column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated from ..processors import ProcessorConfig, ProcessorType from ..validator_params import ValidatorType from .constants import RICH_CONSOLE_THEME diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 4b245369..3d7420ad 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -21,7 +21,7 @@ from rich.text import Text from ..base import ConfigBase -from ..columns import DataDesignerColumnType +from ..column_types import DataDesignerColumnType from ..models import ModelConfig, ModelProvider, get_nvidia_api_key, get_openai_api_key from ..sampler_params import SamplerType from .code_lang import code_lang_to_syntax_lexer diff --git a/src/data_designer/engine/analysis/column_profilers/base.py b/src/data_designer/engine/analysis/column_profilers/base.py index 45d0d6ea..4bfcec9f 100644 --- a/src/data_designer/engine/analysis/column_profilers/base.py +++ b/src/data_designer/engine/analysis/column_profilers/base.py @@ -12,7 +12,8 @@ from typing_extensions import Self from data_designer.config.base import ConfigBase -from data_designer.config.columns import DataDesignerColumnType, SingleColumnConfig +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, TaskConfigT logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py index 37de5dbd..0c8e3569 100644 --- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py +++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py @@ -20,7 +20,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType from data_designer.engine.analysis.column_profilers.base import ( ColumnConfigWithDataFrame, ColumnProfiler, diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py index dd4fb1e9..d7441434 100644 --- a/src/data_designer/engine/analysis/column_statistics.py +++ b/src/data_designer/engine/analysis/column_statistics.py @@ -22,7 +22,7 @@ SeedDatasetColumnStatistics, ValidationColumnStatistics, ) -from data_designer.config.columns import ColumnConfigT, DataDesignerColumnType +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.sampler_params import SamplerType, is_numerical_sampler_type from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.utils.column_statistics_calculations import ( diff --git a/src/data_designer/engine/analysis/dataset_profiler.py b/src/data_designer/engine/analysis/dataset_profiler.py index 3dd39ed0..3f173bba 100644 --- a/src/data_designer/engine/analysis/dataset_profiler.py +++ b/src/data_designer/engine/analysis/dataset_profiler.py @@ -11,10 +11,10 @@ from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import ( COLUMN_TYPE_EMOJI_MAP, ColumnConfigT, - SingleColumnConfig, ) from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 67d72800..2b47c8cb 100644 --- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -20,7 +20,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMTextColumnConfig, SingleColumnConfig, ValidationColumnConfig, diff --git a/src/data_designer/engine/analysis/utils/judge_score_processing.py b/src/data_designer/engine/analysis/utils/judge_score_processing.py index 95b01ccc..d9bbc764 100644 --- a/src/data_designer/engine/analysis/utils/judge_score_processing.py +++ b/src/data_designer/engine/analysis/utils/judge_score_processing.py @@ -14,7 +14,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig +from data_designer.config.column_configs import LLMJudgeColumnConfig logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/column_generators/generators/expression.py b/src/data_designer/engine/column_generators/generators/expression.py index 7da80e66..00ce4771 100644 --- a/src/data_designer/engine/column_generators/generators/expression.py +++ b/src/data_designer/engine/column_generators/generators/expression.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, GenerationStrategy, diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_generators.py index 14324f36..ee0ab58a 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_generators.py @@ -4,13 +4,13 @@ import functools import logging -from data_designer.config.columns import ( - COLUMN_TYPE_EMOJI_MAP, +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig, ) +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( diff --git a/src/data_designer/engine/column_generators/generators/validation.py b/src/data_designer/engine/column_generators/generators/validation.py index 424165fc..f46296b9 100644 --- a/src/data_designer/engine/column_generators/generators/validation.py +++ b/src/data_designer/engine/column_generators/generators/validation.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ValidationColumnConfig +from data_designer.config.column_configs import ValidationColumnConfig from data_designer.config.errors import InvalidConfigError from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang from data_designer.config.validator_params import ( diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 95c6ee7a..6d099f1c 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -11,6 +10,7 @@ LLMTextColumnConfig, ValidationColumnConfig, ) +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.generators.llm_generators import ( @@ -32,7 +32,7 @@ class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]): ... -def create_default_column_generator_registry() -> ColumnGeneratorRegistry: +def create_builtin_column_generator_registry() -> ColumnGeneratorRegistry: registry = ColumnGeneratorRegistry() registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig, False) registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig, False) diff --git a/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/src/data_designer/engine/column_generators/utils/judge_score_factory.py index b4d458c6..a02a86ee 100644 --- a/src/data_designer/engine/column_generators/utils/judge_score_factory.py +++ b/src/data_designer/engine/column_generators/utils/judge_score_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, create_model -from data_designer.config.columns import Score +from data_designer.config.column_configs import Score SCORING_FORMAT = "* {score}: {description}" SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}" diff --git a/src/data_designer/engine/column_generators/utils/prompt_renderer.py b/src/data_designer/engine/column_generators/utils/prompt_renderer.py index d895fe9e..0c1cd708 100644 --- a/src/data_designer/engine/column_generators/utils/prompt_renderer.py +++ b/src/data_designer/engine/column_generators/utils/prompt_renderer.py @@ -5,7 +5,8 @@ import json import logging -from data_designer.config.columns import DataDesignerColumnType, SingleColumnConfig +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.models import ModelConfig from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.misc import get_prompt_template_keywords diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index 26e794ca..14d63dce 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -10,7 +10,7 @@ import pandas as pd -from data_designer.config.columns import ColumnConfigT, column_type_is_llm_generated +from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, diff --git a/src/data_designer/engine/dataset_builders/multi_column_configs.py b/src/data_designer/engine/dataset_builders/multi_column_configs.py index b7435706..fc2e74f5 100644 --- a/src/data_designer/engine/dataset_builders/multi_column_configs.py +++ b/src/data_designer/engine/dataset_builders/multi_column_configs.py @@ -7,13 +7,8 @@ from pydantic import Field, field_validator from data_designer.config.base import ConfigBase -from data_designer.config.columns import ( - ColumnConfigT, - DataDesignerColumnType, - SamplerColumnConfig, - SeedDatasetColumnConfig, - SingleColumnConfig, -) +from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig, SingleColumnConfig +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.sampler_constraints import ColumnConstraintT from data_designer.config.seed import SeedConfig diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py index d80ec37a..5359bfa1 100644 --- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.processors import ProcessorConfig from data_designer.engine.dataset_builders.multi_column_configs import ( diff --git a/src/data_designer/engine/dataset_builders/utils/dag.py b/src/data_designer/engine/dataset_builders/utils/dag.py index e6d9da14..6c056d11 100644 --- a/src/data_designer/engine/dataset_builders/utils/dag.py +++ b/src/data_designer/engine/dataset_builders/utils/dag.py @@ -5,7 +5,7 @@ import networkx as nx -from data_designer.config.columns import ColumnConfigT, column_type_used_in_execution_dag +from data_designer.config.column_types import ColumnConfigT, column_type_used_in_execution_dag from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError logger = logging.getLogger(__name__) diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 8ed2f0ba..60c6e1fc 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -7,7 +7,7 @@ ) from data_designer.engine.column_generators.registry import ( ColumnGeneratorRegistry, - create_default_column_generator_registry, + create_builtin_column_generator_registry, ) from data_designer.engine.processing.processors.registry import ( ProcessorRegistry, @@ -23,7 +23,7 @@ def __init__( column_profiler_registry: ColumnProfilerRegistry | None = None, processor_registry: ProcessorRegistry | None = None, ): - self._column_generator_registry = column_generator_registry or create_default_column_generator_registry() + self._column_generator_registry = column_generator_registry or create_builtin_column_generator_registry() self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry() self._processor_registry = processor_registry or create_default_processor_registry() diff --git a/src/data_designer/engine/sampling_gen/column.py b/src/data_designer/engine/sampling_gen/column.py index 12c910c7..26ada316 100644 --- a/src/data_designer/engine/sampling_gen/column.py +++ b/src/data_designer/engine/sampling_gen/column.py @@ -6,7 +6,7 @@ from pydantic import field_serializer, model_validator from typing_extensions import Self -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import SamplerParamsT, SamplerType from data_designer.engine.sampling_gen.data_sources.base import DataSource from data_designer.engine.sampling_gen.data_sources.sources import SamplerRegistry diff --git a/src/data_designer/engine/sampling_gen/schema_builder.py b/src/data_designer/engine/sampling_gen/schema_builder.py index c5ee010f..3c3e3eaf 100644 --- a/src/data_designer/engine/sampling_gen/schema_builder.py +++ b/src/data_designer/engine/sampling_gen/schema_builder.py @@ -3,7 +3,7 @@ from copy import deepcopy -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_constraints import ColumnConstraintT from data_designer.config.sampler_params import SamplerParamsT from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 2d6516c6..80dfa7b5 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from ..config.analysis.column_profilers import JudgeScoreProfilerConfig -from ..config.columns import ( - DataDesignerColumnType, +from ..config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -14,6 +13,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ) +from ..config.column_types import DataDesignerColumnType from ..config.config_builder import DataDesignerConfigBuilder from ..config.data_designer_config import DataDesignerConfig from ..config.dataset_builders import BuildStage diff --git a/tests/config/analysis/utils/test_reporting.py b/tests/config/analysis/utils/test_reporting.py index ec9a9912..82b5234c 100644 --- a/tests/config/analysis/utils/test_reporting.py +++ b/tests/config/analysis/utils/test_reporting.py @@ -11,7 +11,7 @@ from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults from data_designer.config.analysis.utils.errors import AnalysisReportError from data_designer.config.analysis.utils.reporting import ReportSection, generate_analysis_report -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType @pytest.fixture diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index df2d0668..2accdaa3 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -4,8 +4,7 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -15,6 +14,9 @@ Score, SeedDatasetColumnConfig, ValidationColumnConfig, +) +from data_designer.config.column_types import ( + DataDesignerColumnType, column_type_is_llm_generated, column_type_used_in_execution_dag, get_column_config_from_kwargs, diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 76ce7bc7..fef6cfb9 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -11,8 +11,7 @@ import yaml from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -21,8 +20,8 @@ SamplerColumnConfig, Score, ValidationColumnConfig, - get_column_config_from_kwargs, ) +from data_designer.config.column_types import DataDesignerColumnType, get_column_config_from_kwargs from data_designer.config.config_builder import BuilderConfig, DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings diff --git a/tests/config/utils/test_validation.py b/tests/config/utils/test_validation.py index 451afbbe..0b5c5684 100644 --- a/tests/config/utils/test_validation.py +++ b/tests/config/utils/test_validation.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, diff --git a/tests/conftest.py b/tests/conftest.py index f0e6546e..31dc0057 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from data_designer.config.analysis.column_statistics import GeneralColumnStatistics from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings diff --git a/tests/engine/analysis/column_profilers/test_base.py b/tests/engine/analysis/column_profilers/test_base.py index 966a78ab..be74f2b5 100644 --- a/tests/engine/analysis/column_profilers/test_base.py +++ b/tests/engine/analysis/column_profilers/test_base.py @@ -5,7 +5,9 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import DataDesignerColumnType, SamplerColumnConfig, SamplerType +from data_designer.config.column_configs import SamplerColumnConfig +from data_designer.config.column_types import DataDesignerColumnType +from data_designer.config.sampler_params import SamplerType from data_designer.engine.analysis.column_profilers.base import ( ColumnConfigWithDataFrame, ColumnProfilerMetadata, diff --git a/tests/engine/analysis/column_profilers/test_judge_score_profiler.py b/tests/engine/analysis/column_profilers/test_judge_score_profiler.py index a380b5c4..343f44a6 100644 --- a/tests/engine/analysis/column_profilers/test_judge_score_profiler.py +++ b/tests/engine/analysis/column_profilers/test_judge_score_profiler.py @@ -14,7 +14,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.column_profilers.judge_score_profiler import ( JudgeScoreProfiler, diff --git a/tests/engine/analysis/conftest.py b/tests/engine/analysis/conftest.py index 7d571d5f..ff64af3f 100644 --- a/tests/engine/analysis/conftest.py +++ b/tests/engine/analysis/conftest.py @@ -14,7 +14,8 @@ ColumnDistributionType, NumericalDistribution, ) -from data_designer.config.columns import ColumnConfigT, LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score +from data_designer.config.column_types import ColumnConfigT from data_designer.config.models import ModelConfig from data_designer.engine.analysis.dataset_profiler import ( DataDesignerDatasetProfiler, diff --git a/tests/engine/analysis/test_column_statistics_calculator.py b/tests/engine/analysis/test_column_statistics_calculator.py index b72c1b93..a9030d9e 100644 --- a/tests/engine/analysis/test_column_statistics_calculator.py +++ b/tests/engine/analysis/test_column_statistics_calculator.py @@ -4,7 +4,7 @@ import pandas as pd from data_designer.config.analysis.column_statistics import ColumnDistributionType -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_params import SamplerType from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator diff --git a/tests/engine/analysis/test_dataset_profiler.py b/tests/engine/analysis/test_dataset_profiler.py index 154f6360..0bd1e6cf 100644 --- a/tests/engine/analysis/test_dataset_profiler.py +++ b/tests/engine/analysis/test_dataset_profiler.py @@ -6,7 +6,7 @@ import pytest from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import CategorySamplerParams, SamplerType from data_designer.engine.analysis.column_profilers.judge_score_profiler import JudgeScoreProfilerConfig from data_designer.engine.analysis.dataset_profiler import ( diff --git a/tests/engine/analysis/utils/test_column_statistics_calculations.py b/tests/engine/analysis/utils/test_column_statistics_calculations.py index 94b20503..2f7316aa 100644 --- a/tests/engine/analysis/utils/test_column_statistics_calculations.py +++ b/tests/engine/analysis/utils/test_column_statistics_calculations.py @@ -16,7 +16,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMTextColumnConfig +from data_designer.config.column_configs import LLMTextColumnConfig from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting from data_designer.engine.analysis.utils.column_statistics_calculations import ( calculate_column_distribution, diff --git a/tests/engine/analysis/utils/test_judge_score_processing.py b/tests/engine/analysis/utils/test_judge_score_processing.py index 85663c75..e20e65f5 100644 --- a/tests/engine/analysis/utils/test_judge_score_processing.py +++ b/tests/engine/analysis/utils/test_judge_score_processing.py @@ -11,7 +11,7 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.columns import LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import LLMJudgeColumnConfig, Score from data_designer.engine.analysis.utils.judge_score_processing import ( JudgeScoreDistributions, JudgeScoreSample, diff --git a/tests/engine/column_generators/generators/test_column_generator_base.py b/tests/engine/column_generators/generators/test_column_generator_base.py index 433214f5..511b1c0f 100644 --- a/tests/engine/column_generators/generators/test_column_generator_base.py +++ b/tests/engine/column_generators/generators/test_column_generator_base.py @@ -5,7 +5,7 @@ import pandas as pd -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, FromScratchColumnGenerator, diff --git a/tests/engine/column_generators/generators/test_expression.py b/tests/engine/column_generators/generators/test_expression.py index e2db903d..ec42c99d 100644 --- a/tests/engine/column_generators/generators/test_expression.py +++ b/tests/engine/column_generators/generators/test_expression.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import ExpressionColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError from data_designer.engine.resources.resource_provider import ResourceProvider diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_generators.py index 8854aa7d..259f3a08 100644 --- a/tests/engine/column_generators/generators/test_llm_generators.py +++ b/tests/engine/column_generators/generators/test_llm_generators.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, diff --git a/tests/engine/column_generators/generators/test_samplers.py b/tests/engine/column_generators/generators/test_samplers.py index 0b30b969..23e68b8e 100644 --- a/tests/engine/column_generators/generators/test_samplers.py +++ b/tests/engine/column_generators/generators/test_samplers.py @@ -3,7 +3,7 @@ import pytest -from data_designer.config.columns import SamplerColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.sampler_params import ( BernoulliMixtureSamplerParams, BernoulliSamplerParams, diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index ebb9c72a..37e69a26 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -9,7 +9,7 @@ import pandas as pd import pytest -from data_designer.config.columns import SeedDatasetColumnConfig +from data_designer.config.column_configs import SeedDatasetColumnConfig from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.seed_dataset import ( diff --git a/tests/engine/column_generators/generators/test_validation.py b/tests/engine/column_generators/generators/test_validation.py index e84d0335..65f9bdd4 100644 --- a/tests/engine/column_generators/generators/test_validation.py +++ b/tests/engine/column_generators/generators/test_validation.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import ValidationColumnConfig +from data_designer.config.column_configs import ValidationColumnConfig from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import ( CodeValidatorParams, diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index e658283e..03e67486 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.config.columns import DataDesignerColumnType +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.generators.llm_generators import ( LLMCodeCellGenerator, @@ -14,12 +14,12 @@ from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator from data_designer.engine.column_generators.registry import ( ColumnGeneratorRegistry, - create_default_column_generator_registry, + create_builtin_column_generator_registry, ) def test_column_generator_registry_create_default_registry_with_generators(): - registry = create_default_column_generator_registry() + registry = create_builtin_column_generator_registry() assert isinstance(registry, ColumnGeneratorRegistry) diff --git a/tests/engine/column_generators/utils/test_judge_score_factory.py b/tests/engine/column_generators/utils/test_judge_score_factory.py index d692247a..afb830f3 100644 --- a/tests/engine/column_generators/utils/test_judge_score_factory.py +++ b/tests/engine/column_generators/utils/test_judge_score_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import pytest -from data_designer.config.columns import Score +from data_designer.config.column_configs import Score from data_designer.engine.column_generators.utils.judge_score_factory import ( SCORE_FIELD_DESCRIPTION_FORMAT, SCORING_FORMAT, diff --git a/tests/engine/column_generators/utils/test_prompt_renderer.py b/tests/engine/column_generators/utils/test_prompt_renderer.py index 09c70253..cd3697b9 100644 --- a/tests/engine/column_generators/utils/test_prompt_renderer.py +++ b/tests/engine/column_generators/utils/test_prompt_renderer.py @@ -5,7 +5,7 @@ import pytest -from data_designer.config.columns import ( +from data_designer.config.column_configs import ( LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, diff --git a/tests/engine/dataset_builders/test_column_wise_builder.py b/tests/engine/dataset_builders/test_column_wise_builder.py index 487efca7..5572af44 100644 --- a/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/tests/engine/dataset_builders/test_column_wise_builder.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from data_designer.config.columns import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.dataset_builders.column_wise_builder import ( diff --git a/tests/engine/dataset_builders/test_multi_column_configs.py b/tests/engine/dataset_builders/test_multi_column_configs.py index e030c1da..2754d5c5 100644 --- a/tests/engine/dataset_builders/test_multi_column_configs.py +++ b/tests/engine/dataset_builders/test_multi_column_configs.py @@ -4,11 +4,8 @@ from pydantic import ValidationError import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, - SamplerColumnConfig, - SeedDatasetColumnConfig, -) +from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_constraints import ( ColumnInequalityConstraint, InequalityOperator, diff --git a/tests/engine/dataset_builders/utils/test_config_compiler.py b/tests/engine/dataset_builders/utils/test_config_compiler.py index b17dab5a..00809e46 100644 --- a/tests/engine/dataset_builders/utils/test_config_compiler.py +++ b/tests/engine/dataset_builders/utils/test_config_compiler.py @@ -3,12 +3,8 @@ import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, - LLMTextColumnConfig, - SamplerColumnConfig, - SeedDatasetColumnConfig, -) +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig, SeedDatasetColumnConfig +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.seed import SamplingStrategy, SeedConfig from data_designer.engine.dataset_builders.utils.config_compiler import ( diff --git a/tests/engine/dataset_builders/utils/test_dag.py b/tests/engine/dataset_builders/utils/test_dag.py index dec06be8..b914ff85 100644 --- a/tests/engine/dataset_builders/utils/test_dag.py +++ b/tests/engine/dataset_builders/utils/test_dag.py @@ -3,8 +3,7 @@ import pytest -from data_designer.config.columns import ( - DataDesignerColumnType, +from data_designer.config.column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -13,6 +12,7 @@ Score, ValidationColumnConfig, ) +from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.sampler_params import SamplerType from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams diff --git a/tests/engine/registry/test_data_designer_registry.py b/tests/engine/registry/test_data_designer_registry.py index 5f98970c..56d7cae8 100644 --- a/tests/engine/registry/test_data_designer_registry.py +++ b/tests/engine/registry/test_data_designer_registry.py @@ -21,7 +21,7 @@ def stub_column_profiler_registry(): @pytest.fixture def stub_default_registries(): with patch( - "data_designer.engine.registry.data_designer_registry.create_default_column_generator_registry" + "data_designer.engine.registry.data_designer_registry.create_builtin_column_generator_registry" ) as mock_gen: with patch( "data_designer.engine.registry.data_designer_registry.create_default_column_profiler_registry" From 8604dc1a262752779d7ede810fe966c05fdfd15a Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 7 Nov 2025 14:48:43 -0500 Subject: [PATCH 02/27] create plugin object --- src/data_designer/plugins/plugin.py | 53 ++++++++ tests/plugins/test_plugin.py | 181 ++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 src/data_designer/plugins/plugin.py create mode 100644 tests/plugins/test_plugin.py diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py new file mode 100644 index 00000000..e7e4d29c --- /dev/null +++ b/src/data_designer/plugins/plugin.py @@ -0,0 +1,53 @@ +from enum import Enum +from typing import Literal, Type, get_origin + +from pydantic import BaseModel, model_validator +from typing_extensions import Self + +from data_designer.config.base import ConfigBase +from data_designer.engine.configurable_task import ConfigurableTask + + +class PluginType(str, Enum): + COLUMN_GENERATOR = "column_generator" + + @property + def discriminator_field(self) -> str: + if self == PluginType.COLUMN_GENERATOR: + return "column_type" + else: + raise ValueError(f"Invalid plugin type: {self}") + + +class Plugin(BaseModel): + task_cls: Type[ConfigurableTask] + config_cls: Type[ConfigBase] + plugin_type: PluginType + + @property + def name(self) -> str: + return self.config_cls.model_fields[self.discriminator_field].default + + @property + def discriminator_field(self) -> str: + return self.plugin_type.discriminator_field + + @model_validator(mode="after") + def validate_discriminator_field(self) -> Self: + cfg = self.config_cls.__name__ + field = self.plugin_type.discriminator_field + if field not in self.config_cls.model_fields: + raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}") + field_info = self.config_cls.model_fields[field] + if get_origin(field_info.annotation) is not Literal: + raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.") + if not isinstance(field_info.default, str): + raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.") + enum_key = field_info.default.replace("-", "_").upper() + if not enum_key.isidentifier(): + raise ValueError( + f"The default value '{field_info.default}' for discriminator field '{field}' " + f"cannot be converted to a valid enum key. The converted key '{enum_key}' " + f"must be a valid Python identifier." + ) + return self diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py new file mode 100644 index 00000000..c505f8bb --- /dev/null +++ b/tests/plugins/test_plugin.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from pydantic import ValidationError +import pytest + +from data_designer.config.base import ConfigBase +from data_designer.config.column_configs import SamplerColumnConfig, SingleColumnConfig +from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.plugin import Plugin, PluginType + + +class ValidTestConfig(SingleColumnConfig): + """Valid config for testing plugin creation.""" + + column_type: Literal["test-generator"] = "test-generator" + name: str + + +class ValidTestTask(ConfigurableTask[ValidTestConfig]): + """Valid task for testing plugin creation.""" + + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_generator", + description="Test generator", + required_resources=None, + ) + + +@pytest.fixture +def valid_plugin() -> Plugin: + """Fixture providing a valid plugin instance for testing.""" + return Plugin( + task_cls=ValidTestTask, + config_cls=ValidTestConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +# ============================================================================= +# PluginType Tests +# ============================================================================= + + +def test_plugin_type_discriminator_field_for_column_generator() -> None: + """Test that COLUMN_GENERATOR type returns 'column_type' as discriminator field.""" + assert PluginType.COLUMN_GENERATOR.discriminator_field == "column_type" + + +def test_plugin_type_all_types_have_discriminator_fields() -> None: + """Test that all plugin types have valid discriminator fields.""" + for plugin_type in PluginType: + assert isinstance(plugin_type.discriminator_field, str) + assert len(plugin_type.discriminator_field) > 0 + + +# ============================================================================= +# Plugin Creation and Properties Tests +# ============================================================================= + + +def test_create_plugin_with_valid_inputs(valid_plugin: Plugin) -> None: + """Test that Plugin can be created with valid task, config, and plugin type.""" + assert valid_plugin.task_cls == ValidTestTask + assert valid_plugin.config_cls == ValidTestConfig + assert valid_plugin.plugin_type == PluginType.COLUMN_GENERATOR + + +def test_plugin_name_derived_from_config_default(valid_plugin: Plugin) -> None: + """Test that plugin.name returns the discriminator field's default value.""" + assert valid_plugin.name == "test-generator" + + +def test_plugin_discriminator_field_from_type(valid_plugin: Plugin) -> None: + """Test that plugin.discriminator_field returns the correct field name.""" + assert valid_plugin.discriminator_field == "column_type" + + +def test_plugin_requires_all_fields() -> None: + """Test that Plugin creation fails without required fields.""" + with pytest.raises(ValidationError): + Plugin() # type: ignore + + with pytest.raises(ValidationError): + Plugin(task_cls=ValidTestTask) # type: ignore + + with pytest.raises(ValidationError): + Plugin(config_cls=ValidTestConfig) # type: ignore + + +# ============================================================================= +# Plugin Validation Tests +# ============================================================================= + + +def test_validation_fails_when_config_missing_discriminator_field() -> None: + """Test validation fails when config lacks the required discriminator field.""" + + class ConfigWithoutDiscriminator(ConfigBase): + some_field: str + + with pytest.raises(ValueError, match="Discriminator field 'column_type' not found in config class"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithoutDiscriminator, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_when_discriminator_field_not_literal_type() -> None: + """Test validation fails when discriminator field is not a Literal type.""" + + class ConfigWithStringField(ConfigBase): + column_type: str = "test-generator" + + with pytest.raises(ValueError, match="Field 'column_type' of .* must be a Literal type"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithStringField, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_when_discriminator_default_not_string() -> None: + """Test validation fails when discriminator field default is not a string.""" + + class ConfigWithNonStringDefault(ConfigBase): + column_type: Literal["test-generator"] = 123 # type: ignore + + with pytest.raises(ValueError, match="The default of 'column_type' must be a string"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithNonStringDefault, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_validation_fails_with_invalid_enum_key_conversion() -> None: + """Test validation fails when default value cannot be converted to valid Python identifier.""" + + class ConfigWithInvalidKey(ConfigBase): + column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#" + + with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): + Plugin( + task_cls=ValidTestTask, + config_cls=ConfigWithInvalidKey, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +def test_plugin_works_with_real_sampler_column_generator() -> None: + """Test that Plugin works with actual SamplerColumnGenerator from the codebase.""" + plugin = Plugin( + task_cls=SamplerColumnGenerator, + config_cls=SamplerColumnConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + assert plugin.name == "sampler" + assert plugin.discriminator_field == "column_type" + assert plugin.task_cls == SamplerColumnGenerator + assert plugin.config_cls == SamplerColumnConfig + + +def test_plugin_preserves_type_information(valid_plugin: Plugin) -> None: + """Test that Plugin correctly stores and provides access to type information.""" + assert isinstance(valid_plugin.task_cls, type) + assert isinstance(valid_plugin.config_cls, type) + assert issubclass(valid_plugin.task_cls, ConfigurableTask) + assert issubclass(valid_plugin.config_cls, ConfigBase) From 3223d1249f6289c308eda9a4d3fecc009ab12aa5 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sat, 8 Nov 2025 14:07:55 -0500 Subject: [PATCH 03/27] create plugin manager --- pyproject.toml | 4 + src/data_designer/config/config_builder.py | 13 + .../engine/registry/data_designer_registry.py | 17 +- src/data_designer/essentials/__init__.py | 7 +- src/data_designer/plugins/errors.py | 7 + src/data_designer/plugins/manager.py | 119 ++++ src/data_designer/plugins/plugin.py | 7 +- tests/plugins/test_manager.py | 527 ++++++++++++++++++ uv.lock | 15 + 9 files changed, 707 insertions(+), 9 deletions(-) create mode 100644 src/data_designer/plugins/errors.py create mode 100644 src/data_designer/plugins/manager.py create mode 100644 tests/plugins/test_manager.py diff --git a/pyproject.toml b/pyproject.toml index 362b5298..53df0082 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "pytest>=8.3.3", "pytest-asyncio>=0.24.0", "pytest-cov>=7.0.0", + "pytest-env>=1.2.0", "pytest-httpx>=0.35.0", ] docs = [ @@ -89,6 +90,9 @@ version-file = "src/data_designer/_version.py" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "session" +env = [ + "DATA_DESIGNER_PLUGIN_DIR=/tmp/pytest-no-plugins", +] [tool.uv] package = true diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index db1b8691..a3bad644 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -54,6 +54,11 @@ from .utils.type_helpers import resolve_string_enum from .utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config +if can_run_data_designer_locally(): + from data_designer.plugins.manager import PluginManager, PluginType + + plugin_manager = PluginManager() + logger = logging.getLogger(__name__) @@ -633,6 +638,14 @@ def __repr__(self) -> str: DataDesignerColumnType.LLM_JUDGE, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, + *( + [] + if not can_run_data_designer_locally() + else [ + DataDesignerColumnType(name) + for name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + ] + ), ]: columns = self.get_columns_of_type(column_type) if len(columns) > 0: diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 60c6e1fc..763b0d46 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.analysis.column_profilers.registry import ( ColumnProfilerRegistry, create_default_column_profiler_registry, @@ -9,10 +10,10 @@ ColumnGeneratorRegistry, create_builtin_column_generator_registry, ) -from data_designer.engine.processing.processors.registry import ( - ProcessorRegistry, - create_default_processor_registry, -) +from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry +from data_designer.plugins.manager import PluginManager, PluginType + +plugin_manager = PluginManager() class DataDesignerRegistry: @@ -27,6 +28,14 @@ def __init__( self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry() self._processor_registry = processor_registry or create_default_processor_registry() + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + self._column_generator_registry.register( + DataDesignerColumnType(plugin.name), + plugin.task_cls, + plugin.config_cls, + raise_on_collision=True, + ) + @property def column_generators(self) -> ColumnGeneratorRegistry: return self._column_generator_registry diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 80dfa7b5..70ec184d 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -1,5 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from ..logging import LoggingConfig, configure_logging + +configure_logging(LoggingConfig.default()) from ..config.analysis.column_profilers import JudgeScoreProfilerConfig from ..config.column_configs import ( @@ -58,7 +61,6 @@ RemoteValidatorParams, ValidatorType, ) -from ..logging import LoggingConfig, configure_logging local_library_imports = [] try: @@ -131,6 +133,3 @@ ] __all__.extend(local_library_imports) - - -configure_logging(LoggingConfig.default()) diff --git a/src/data_designer/plugins/errors.py b/src/data_designer/plugins/errors.py new file mode 100644 index 00000000..a57f7565 --- /dev/null +++ b/src/data_designer/plugins/errors.py @@ -0,0 +1,7 @@ +from data_designer.errors import DataDesignerError + + +class PluginRegistrationError(DataDesignerError): ... + + +class PluginNotFoundError(DataDesignerError): ... diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py new file mode 100644 index 00000000..2b3b6de6 --- /dev/null +++ b/src/data_designer/plugins/manager.py @@ -0,0 +1,119 @@ +import importlib.util +import inspect +import logging +import os +from pathlib import Path +import sys +import threading +from typing import Iterator, Optional, Type, TypeAlias + +from typing_extensions import Self + +from data_designer.plugins.errors import PluginNotFoundError, PluginRegistrationError +from data_designer.plugins.plugin import Plugin, PluginType + +logger = logging.getLogger(__name__) + + +def _get_default_plugin_directory() -> Path: + """Get the default plugin directory from environment or user's home directory. + + This function is called at runtime rather than at module import time, + allowing tests to override the plugin directory via environment variables. + """ + env_dir = os.getenv("DATA_DESIGNER_PLUGIN_DIR") + if env_dir: + return Path(env_dir) + return Path.home() / ".data_designer" / "plugins" + + +class PluginManager: + def __init__(self): + self.registry = _PluginRegistry() + + def get_plugin(self, plugin_name: str) -> Plugin: + return self.registry.get(plugin_name) + + def get_plugins(self, plugin_type: PluginType) -> list[Plugin]: + return [plugin for plugin in self.registry._plugins.values() if plugin.plugin_type == plugin_type] + + def get_plugin_names(self, plugin_type: PluginType) -> list[str]: + return [plugin.name for plugin in self.get_plugins(plugin_type)] + + def num_plugins(self, plugin_type: PluginType) -> int: + return len(self.get_plugins(plugin_type)) + + def update_type_union(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: + for plugin in self.get_plugins(plugin_type): + type_union |= plugin.config_cls + return type_union + + def discover(self, plugin_dir: Optional[Path] = None) -> Self: + plugin_dir = Path(plugin_dir or _get_default_plugin_directory()) + + if not plugin_dir.exists(): + return self + + for file_path in plugin_dir.rglob("*.py"): + if file_path.name.startswith("_"): + continue + + for plugin in self._iter_plugins_from_file(file_path, plugin_dir): + if isinstance(plugin, Plugin): + self.registry.register_plugin(plugin) + logger.info( + f"🔌 Plugin discovered ➜ {plugin.plugin_type.value.replace('-', ' ')} " + f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" + ) + + return self + + def _iter_plugins_from_file(self, file_path: Path, plugin_dir: Path) -> Optional[Iterator[Plugin]]: + label = str(file_path.relative_to(plugin_dir)).replace("/", "_").replace(".", "_") + module_name = f"_plugin_{label}" + + try: + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + return + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + for _, obj in inspect.getmembers(module): + if isinstance(obj, Plugin): + yield obj + + except Exception: + return + + +class _PluginRegistry: + _plugins: dict[str, Plugin] = {} + _instance = None + _lock = threading.Lock() + + def get(self, plugin_name: str) -> Plugin: + if plugin_name not in self._plugins: + raise PluginNotFoundError(f"Plugin '{plugin_name}' not found.") + return self._plugins[plugin_name] + + def register_plugin(self, plugin: Plugin) -> None: + with self._lock: + if plugin.name in self._plugins: + raise PluginRegistrationError(f"Plugin '{plugin.name}' already registered.") + self._plugins[plugin.name] = plugin + + def clear(self) -> None: + """Clear all registered plugins. Primarily for testing purposes.""" + with self._lock: + self._plugins.clear() + + def __new__(cls, *args, **kwargs): + """Plugin registry is a singleton.""" + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index e7e4d29c..6285e8d1 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -9,7 +9,7 @@ class PluginType(str, Enum): - COLUMN_GENERATOR = "column_generator" + COLUMN_GENERATOR = "column-generator" @property def discriminator_field(self) -> str: @@ -23,6 +23,11 @@ class Plugin(BaseModel): task_cls: Type[ConfigurableTask] config_cls: Type[ConfigBase] plugin_type: PluginType + emoji: str = "🔌" + + @property + def enum_key(self) -> str: + return self.name.replace("-", "_").upper() @property def name(self) -> str: diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py new file mode 100644 index 00000000..2098cd9d --- /dev/null +++ b/tests/plugins/test_manager.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Literal + +import pytest + +from data_designer.config.base import ConfigBase +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.errors import PluginNotFoundError, PluginRegistrationError +from data_designer.plugins.manager import PluginManager, _PluginRegistry +from data_designer.plugins.plugin import Plugin, PluginType + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +def plugin_manager() -> PluginManager: + """Create a PluginManager with a clean registry. + + This fixture ensures the singleton registry is cleared before and after each test, + preventing state leakage between tests and from any plugins in the default + plugin directory. + """ + manager = PluginManager() + # Clear any plugins that may have been auto-discovered (e.g., from ~/.data_designer/plugins/) + manager.registry.clear() + yield manager + # Cleanup: clear the singleton registry after the test + manager.registry.clear() + + +def create_plugin_file( + dir_path: Path, + filename: str, + plugin_name: str, + column_type: str, + task_name: str | None = None, +) -> Path: + """Helper to create test plugin files with less boilerplate. + + Args: + dir_path: Directory to create the plugin file in + filename: Name of the plugin file (e.g., "test_plugin.py") + plugin_name: Name of the plugin (e.g., "MyPlugin") + column_type: Column type literal value (e.g., "my-plugin") + task_name: Task metadata name (defaults to plugin_name lowercase with underscores) + + Returns: + Path to the created plugin file + """ + if task_name is None: + task_name = plugin_name.lower().replace("-", "_") + + plugin_var_name = plugin_name.lower().replace("-", "_") + + plugin_file = dir_path / filename + plugin_file.write_text( + f""" +from typing import Literal +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.plugin import Plugin, PluginType + + +class {plugin_name}Config(SingleColumnConfig): + column_type: Literal["{column_type}"] = "{column_type}" + name: str + + +class {plugin_name}Task(ConfigurableTask[{plugin_name}Config]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="{task_name}", + description="{plugin_name} task", + required_resources=None, + ) + + +{plugin_var_name} = Plugin( + task_cls={plugin_name}Task, + config_cls={plugin_name}Config, + plugin_type=PluginType.COLUMN_GENERATOR, +) +""" + ) + return plugin_file + + +@pytest.fixture +def temp_plugin_dir(tmp_path: Path) -> Path: + """Create a temporary directory with a test plugin file.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + create_plugin_file(plugin_dir, "test_plugin.py", "MyPlugin", "my-plugin") + return plugin_dir + + +@pytest.fixture +def invalid_plugin_dir(tmp_path: Path) -> Path: + """Create a directory with an invalid plugin file.""" + plugin_dir = tmp_path / "invalid_plugins" + plugin_dir.mkdir() + + invalid_file = plugin_dir / "invalid.py" + invalid_file.write_text("import syntax error here") + + return plugin_dir + + +# ============================================================================= +# Plugin Discovery Tests +# ============================================================================= + + +def test_discover_finds_plugin(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that discover() finds and registers plugins in the plugin directory.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert "my-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + + +@pytest.mark.parametrize( + "dir_setup", + [ + ("empty", lambda tmp_path: (tmp_path / "empty").mkdir() or (tmp_path / "empty")), + ("nonexistent", lambda tmp_path: tmp_path / "does_not_exist"), + ], + ids=["empty_directory", "nonexistent_directory"], +) +def test_discover_handles_missing_or_empty_directories( + plugin_manager: PluginManager, tmp_path: Path, dir_setup: tuple[str, callable] +) -> None: + """Test that discover() handles empty and nonexistent directories gracefully.""" + _, setup_func = dir_setup + plugin_dir = setup_func(tmp_path) + + plugin_manager.discover(plugin_dir=plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + +def test_discover_skips_private_files(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test that discover() skips files starting with underscore.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + create_plugin_file(plugin_dir, "_private_plugin.py", "PrivatePlugin", "private-plugin") + + plugin_manager.discover(plugin_dir=plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + +def test_discover_handles_invalid_files(plugin_manager: PluginManager, invalid_plugin_dir: Path) -> None: + """Test that discover() gracefully handles invalid Python files.""" + plugin_manager.discover(plugin_dir=invalid_plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + +def test_discover_finds_multiple_plugins_in_same_file(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test that discover() can find multiple Plugin instances in the same file.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + multi_plugin_file = plugin_dir / "multi.py" + multi_plugin_file.write_text( + """ +from typing import Literal +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.plugins.plugin import Plugin, PluginType + + +class Plugin1Config(SingleColumnConfig): + column_type: Literal["plugin-1"] = "plugin-1" + name: str + + +class Plugin1Task(ConfigurableTask[Plugin1Config]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="plugin_1", + description="Plugin 1", + required_resources=None, + ) + + +class Plugin2Config(SingleColumnConfig): + column_type: Literal["plugin-2"] = "plugin-2" + name: str + + +class Plugin2Task(ConfigurableTask[Plugin2Config]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="plugin_2", + description="Plugin 2", + required_resources=None, + ) + + +plugin1 = Plugin( + task_cls=Plugin1Task, + config_cls=Plugin1Config, + plugin_type=PluginType.COLUMN_GENERATOR, +) + +plugin2 = Plugin( + task_cls=Plugin2Task, + config_cls=Plugin2Config, + plugin_type=PluginType.COLUMN_GENERATOR, +) +""" + ) + + plugin_manager.discover(plugin_dir=plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 + plugin_names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + assert "plugin-1" in plugin_names + assert "plugin-2" in plugin_names + + +def test_discover_recursive_search(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test that discover() recursively searches subdirectories.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + subdir = plugin_dir / "subdir" + subdir.mkdir() + + create_plugin_file(subdir, "nested.py", "NestedPlugin", "nested-plugin") + + plugin_manager.discover(plugin_dir=plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert "nested-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + + +def test_discover_multiple_calls(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test that discover() can be called multiple times to discover plugins from different directories.""" + dir1 = tmp_path / "plugins1" + dir1.mkdir() + create_plugin_file(dir1, "plugin1.py", "Plugin1", "plugin-1") + + dir2 = tmp_path / "plugins2" + dir2.mkdir() + create_plugin_file(dir2, "plugin2.py", "Plugin2", "plugin-2") + + plugin_manager.discover(plugin_dir=dir1) + plugin_manager.discover(plugin_dir=dir2) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 + plugin_names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + assert "plugin-1" in plugin_names + assert "plugin-2" in plugin_names + + +# ============================================================================= +# Plugin Retrieval Tests +# ============================================================================= + + +def test_get_plugin_returns_correct_plugin(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that get_plugin() returns the correct plugin by name.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + plugin = plugin_manager.get_plugin("my-plugin") + + assert plugin.name == "my-plugin" + assert plugin.plugin_type == PluginType.COLUMN_GENERATOR + assert plugin.config_cls.__name__ == "MyPluginConfig" + assert plugin.task_cls.__name__ == "MyPluginTask" + + +def test_get_plugin_raises_not_found_error(plugin_manager: PluginManager) -> None: + """Test that get_plugin() raises PluginNotFoundError for nonexistent plugins.""" + with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): + plugin_manager.get_plugin("nonexistent") + + +def test_get_plugins_returns_plugins_by_type(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test that get_plugins() returns all plugins of a specific type.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + create_plugin_file(plugin_dir, "plugin1.py", "Plugin1", "plugin-1") + create_plugin_file(plugin_dir, "plugin2.py", "Plugin2", "plugin-2") + + plugin_manager.discover(plugin_dir=plugin_dir) + + plugins = plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR) + + assert len(plugins) == 2 + plugin_names = [p.name for p in plugins] + assert "plugin-1" in plugin_names + assert "plugin-2" in plugin_names + + +def test_get_plugin_names_returns_all_names(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that get_plugin_names() returns all plugin names for a given type.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + + assert names == ["my-plugin"] + + +def test_num_plugins_returns_count(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that num_plugins() returns the correct count.""" + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + plugin_manager.discover(plugin_dir=temp_plugin_dir) + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + + +# ============================================================================= +# Type Union Tests +# ============================================================================= + + +def test_update_type_union_adds_config_types(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that update_type_union() adds plugin config classes to the type union.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + # Start with a basic type + type_union = SingleColumnConfig + + updated_union = plugin_manager.update_type_union(type_union, PluginType.COLUMN_GENERATOR) + + # The union should now include the plugin's config class + plugin = plugin_manager.get_plugin("my-plugin") + assert plugin.config_cls in updated_union.__args__ + + +# ============================================================================= +# Error Handling Tests +# ============================================================================= + + +def test_register_duplicate_plugin_raises_error(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that registering a duplicate plugin raises PluginRegistrationError.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + # Try to discover the same plugin again + with pytest.raises(PluginRegistrationError, match="Plugin 'my-plugin' already registered"): + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + +# ============================================================================= +# Plugin Validation Tests +# ============================================================================= + + +def test_plugin_with_invalid_discriminator_field() -> None: + """Test that Plugin validation fails when discriminator field is missing.""" + + class InvalidConfig(ConfigBase): + name: str + + class InvalidTask(ConfigurableTask[InvalidConfig]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="invalid", + description="Invalid plugin", + required_resources=None, + ) + + with pytest.raises(ValueError, match="Discriminator field 'column_type' not found"): + Plugin( + task_cls=InvalidTask, + config_cls=InvalidConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_plugin_with_non_literal_discriminator() -> None: + """Test that Plugin validation fails when discriminator field is not a Literal type.""" + + class NonLiteralConfig(SingleColumnConfig): + column_type: str = "non-literal" # Should be Literal["non-literal"] + name: str + + class NonLiteralTask(ConfigurableTask[NonLiteralConfig]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="non_literal", + description="Non-literal plugin", + required_resources=None, + ) + + with pytest.raises(ValueError, match="Field 'column_type' .* must be a Literal type"): + Plugin( + task_cls=NonLiteralTask, + config_cls=NonLiteralConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_plugin_with_non_string_discriminator_default() -> None: + """Test that Plugin validation fails when discriminator default is not a string.""" + + class NonStringConfig(ConfigBase): + column_type: Literal[123] = 123 # Should be a string + name: str + + class NonStringTask(ConfigurableTask[NonStringConfig]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="non_string", + description="Non-string plugin", + required_resources=None, + ) + + with pytest.raises(ValueError, match="The default of 'column_type' must be a string"): + Plugin( + task_cls=NonStringTask, + config_cls=NonStringConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_plugin_with_invalid_enum_key() -> None: + """Test that Plugin validation fails when discriminator can't be converted to valid enum key.""" + + class InvalidEnumKeyConfig(SingleColumnConfig): + column_type: Literal["123-invalid"] = "123-invalid" # Starts with number + name: str + + class InvalidEnumKeyTask(ConfigurableTask[InvalidEnumKeyConfig]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="invalid_enum", + description="Invalid enum key plugin", + required_resources=None, + ) + + with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): + Plugin( + task_cls=InvalidEnumKeyTask, + config_cls=InvalidEnumKeyConfig, + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_plugin_name_property(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that plugin name property correctly extracts name from discriminator field.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + plugin = plugin_manager.get_plugin("my-plugin") + assert plugin.name == "my-plugin" + + +def test_plugin_enum_key_property(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that plugin enum_key property correctly converts name to enum format.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + plugin = plugin_manager.get_plugin("my-plugin") + assert plugin.enum_key == "MY_PLUGIN" + + +# ============================================================================= +# Registry Singleton Tests +# ============================================================================= + + +def test_registry_is_singleton(plugin_manager: PluginManager) -> None: + """Test that _PluginRegistry is a singleton.""" + registry1 = _PluginRegistry() + registry2 = _PluginRegistry() + + assert registry1 is registry2 + assert registry1 is plugin_manager.registry + + +def test_registry_clear_affects_all_instances(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: + """Test that clearing registry affects all manager instances.""" + plugin_manager.discover(plugin_dir=temp_plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + + manager2 = PluginManager() + assert manager2.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + + plugin_manager.registry.clear() + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + assert manager2.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +def test_full_plugin_workflow(plugin_manager: PluginManager, tmp_path: Path) -> None: + """Test complete workflow: discover → retrieve → validate plugin properties.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + create_plugin_file(plugin_dir, "workflow_plugin.py", "WorkflowPlugin", "workflow-plugin") + + plugin_manager.discover(plugin_dir=plugin_dir) + + assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert "workflow-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + + plugin = plugin_manager.get_plugin("workflow-plugin") + assert plugin.name == "workflow-plugin" + assert plugin.enum_key == "WORKFLOW_PLUGIN" + assert plugin.plugin_type == PluginType.COLUMN_GENERATOR + assert plugin.config_cls.__name__ == "WorkflowPluginConfig" + assert plugin.task_cls.__name__ == "WorkflowPluginTask" diff --git a/uv.lock b/uv.lock index b9f73f9a..5243435f 100644 --- a/uv.lock +++ b/uv.lock @@ -766,6 +766,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-env" }, { name = "pytest-httpx" }, ] docs = [ @@ -823,6 +824,7 @@ dev = [ { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "pytest-env", specifier = ">=1.2.0" }, { name = "pytest-httpx", specifier = ">=0.35.0" }, ] docs = [ @@ -3398,6 +3400,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "pytest-env" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/12/9c87d0ca45d5992473208bcef2828169fa7d39b8d7fc6e3401f5c08b8bf7/pytest_env-1.2.0.tar.gz", hash = "sha256:475e2ebe8626cee01f491f304a74b12137742397d6c784ea4bc258f069232b80", size = 8973, upload-time = "2025-10-09T19:15:47.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/98/822b924a4a3eb58aacba84444c7439fce32680592f394de26af9c76e2569/pytest_env-1.2.0-py3-none-any.whl", hash = "sha256:d7e5b7198f9b83c795377c09feefa45d56083834e60d04767efd64819fc9da00", size = 6251, upload-time = "2025-10-09T19:15:46.077Z" }, +] + [[package]] name = "pytest-httpx" version = "0.35.0" From ef5ece6d29f005fd9cc6ab08b42e764b70d90c47 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sat, 8 Nov 2025 14:33:40 -0500 Subject: [PATCH 04/27] fix config integration --- src/data_designer/config/column_types.py | 34 ++++++++++++++++++---- src/data_designer/config/config_builder.py | 22 ++------------ src/data_designer/plugins/manager.py | 16 ++-------- 3 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index 305c0304..68b0245c 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -23,6 +23,9 @@ if can_run_data_designer_locally(): from data_designer.plugins.manager import PluginManager, PluginType + plugin_manager = PluginManager().discover() + + ColumnConfigT: TypeAlias = Union[ ExpressionColumnConfig, LLMCodeColumnConfig, @@ -36,9 +39,8 @@ if can_run_data_designer_locally(): - pm = PluginManager() - if pm.num_plugins(PluginType.COLUMN_GENERATOR) > 0: - ColumnConfigT = pm.update_type_union(ColumnConfigT, PluginType.COLUMN_GENERATOR) + if plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) > 0: + ColumnConfigT = plugin_manager.update_type_union(ColumnConfigT, PluginType.COLUMN_GENERATOR) DataDesignerColumnType = create_str_enum_from_discriminated_type_union( @@ -59,12 +61,15 @@ DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "🔍", } +if can_run_data_designer_locally(): + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + COLUMN_TYPE_EMOJI_MAP[DataDesignerColumnType(plugin.name)] = plugin.emoji def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: """Return True if the column type is used in the workflow execution DAG.""" column_type = resolve_string_enum(column_type, DataDesignerColumnType) - return column_type in { + dag_column_types = { DataDesignerColumnType.EXPRESSION, DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_JUDGE, @@ -72,17 +77,26 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, } + if can_run_data_designer_locally(): + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + dag_column_types.add(DataDesignerColumnType(plugin.name)) + return column_type in dag_column_types def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: """Return True if the column type is an LLM-generated column.""" column_type = resolve_string_enum(column_type, DataDesignerColumnType) - return column_type in { + llm_generated_column_types = { DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, } + if can_run_data_designer_locally(): + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + if "model_registry" in (plugin.task_cls.metadata().required_resources or []): + llm_generated_column_types.add(DataDesignerColumnType(plugin.name)) + return column_type in llm_generated_column_types def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: @@ -113,12 +127,16 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) elif column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) + elif can_run_data_designer_locally() and column_type.value in plugin_manager.get_plugin_names( + PluginType.COLUMN_GENERATOR + ): + return plugin_manager.get_plugin(column_type.value).config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover def get_column_display_order() -> list[DataDesignerColumnType]: """Return the preferred display order of the column types.""" - return [ + display_order = [ DataDesignerColumnType.SEED_DATASET, DataDesignerColumnType.SAMPLER, DataDesignerColumnType.LLM_TEXT, @@ -128,6 +146,10 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] + if can_run_data_designer_locally(): + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + display_order.append(DataDesignerColumnType(plugin.name)) + return display_order def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index a3bad644..ad7b4e13 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -21,6 +21,7 @@ DataDesignerColumnType, column_type_is_llm_generated, get_column_config_from_kwargs, + get_column_display_order, ) from .data_designer_config import DataDesignerConfig from .dataset_builders import BuildStage @@ -55,7 +56,7 @@ from .utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config if can_run_data_designer_locally(): - from data_designer.plugins.manager import PluginManager, PluginType + from data_designer.plugins.manager import PluginManager plugin_manager = PluginManager() @@ -629,24 +630,7 @@ def __repr__(self) -> str: "seed_dataset": (None if self._seed_config is None else f"'{self._seed_config.dataset}'"), } - for column_type in [ - DataDesignerColumnType.SEED_DATASET, - DataDesignerColumnType.SAMPLER, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.VALIDATION, - DataDesignerColumnType.EXPRESSION, - *( - [] - if not can_run_data_designer_locally() - else [ - DataDesignerColumnType(name) - for name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - ] - ), - ]: + for column_type in get_column_display_order(): columns = self.get_columns_of_type(column_type) if len(columns) > 0: column_label = f"{kebab_to_snake(column_type.value)}_columns" diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 2b3b6de6..fdba3bb4 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -15,18 +15,6 @@ logger = logging.getLogger(__name__) -def _get_default_plugin_directory() -> Path: - """Get the default plugin directory from environment or user's home directory. - - This function is called at runtime rather than at module import time, - allowing tests to override the plugin directory via environment variables. - """ - env_dir = os.getenv("DATA_DESIGNER_PLUGIN_DIR") - if env_dir: - return Path(env_dir) - return Path.home() / ".data_designer" / "plugins" - - class PluginManager: def __init__(self): self.registry = _PluginRegistry() @@ -49,7 +37,9 @@ def update_type_union(self, type_union: Type[TypeAlias], plugin_type: PluginType return type_union def discover(self, plugin_dir: Optional[Path] = None) -> Self: - plugin_dir = Path(plugin_dir or _get_default_plugin_directory()) + plugin_dir = Path( + plugin_dir or os.getenv("DATA_DESIGNER_PLUGIN_DIR", Path.home() / ".data_designer" / "plugins") + ) if not plugin_dir.exists(): return self From e648aa9b4d2021c1a544be8ad2eb4613b3170f74 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sat, 8 Nov 2025 14:51:47 -0500 Subject: [PATCH 05/27] make base task registry raise on collision false by default --- .../engine/column_generators/registry.py | 31 +++++-------------- src/data_designer/engine/registry/base.py | 2 +- .../engine/registry/data_designer_registry.py | 1 - 3 files changed, 9 insertions(+), 25 deletions(-) diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 6d099f1c..502025b7 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -34,27 +34,12 @@ class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerat def create_builtin_column_generator_registry() -> ColumnGeneratorRegistry: registry = ColumnGeneratorRegistry() - registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig, False) - registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig, False) - registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig, False) - registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig, False) - registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig, False) - registry.register( - DataDesignerColumnType.SEED_DATASET, - SeedDatasetColumnGenerator, - SeedDatasetMultiColumnConfig, - False, - ) - registry.register( - DataDesignerColumnType.VALIDATION, - ValidationColumnGenerator, - ValidationColumnConfig, - False, - ) - registry.register( - DataDesignerColumnType.LLM_STRUCTURED, - LLMStructuredCellGenerator, - LLMStructuredColumnConfig, - False, - ) + registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig) + registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) + registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig) + registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig) + registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig) + registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) + registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) + registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) return registry diff --git a/src/data_designer/engine/registry/base.py b/src/data_designer/engine/registry/base.py index 5f780940..47e51b85 100644 --- a/src/data_designer/engine/registry/base.py +++ b/src/data_designer/engine/registry/base.py @@ -35,7 +35,7 @@ def register( name: EnumNameT, task: Type[TaskT], config: Type[TaskConfigT], - raise_on_collision: bool = True, + raise_on_collision: bool = False, ) -> None: if cls._has_been_registered(name): if not raise_on_collision: diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 763b0d46..3d6c5dc8 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -33,7 +33,6 @@ def __init__( DataDesignerColumnType(plugin.name), plugin.task_cls, plugin.config_cls, - raise_on_collision=True, ) @property From 15ed0f8c37eae5a1938d99768d77229723c9142f Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sat, 8 Nov 2025 15:36:10 -0500 Subject: [PATCH 06/27] update registry test after raise on collision default update --- tests/engine/registry/test_base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/engine/registry/test_base.py b/tests/engine/registry/test_base.py index 73d6a981..3e653de1 100644 --- a/tests/engine/registry/test_base.py +++ b/tests/engine/registry/test_base.py @@ -94,12 +94,13 @@ def test_register_task_scenarios( TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) with pytest.raises(expected_error, match="task_a has already been registered!"): - TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) + TaskRegistry.register( + stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class, raise_on_collision=True + ) elif test_case == "register_task_collision_no_raise": TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) - TaskRegistry.register( - stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class, raise_on_collision=False - ) + # Default behavior is raise_on_collision=False, so no need to pass it explicitly + TaskRegistry.register(stub_test_enum.TASK_A, stub_test_task_class, stub_test_config_class) @pytest.mark.parametrize( From feb98173797ab74ba67beb16c64db12a366abc10 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sat, 8 Nov 2025 21:43:33 -0500 Subject: [PATCH 07/27] make analysis work using general stats calculation --- .../config/analysis/column_statistics.py | 61 ++++++++++++++----- .../config/analysis/dataset_profiler.py | 4 +- .../config/analysis/utils/reporting.py | 1 - .../engine/analysis/column_statistics.py | 21 +------ 4 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index be1d11a9..874b386e 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -5,17 +5,23 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pandas import Series -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator from typing_extensions import Self, TypeAlias from ..column_types import DataDesignerColumnType from ..sampler_params import SamplerType from ..utils.constants import EPSILON +from ..utils.misc import can_run_data_designer_locally from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting +if can_run_data_designer_locally(): + from data_designer.plugins.manager import PluginManager, PluginType + + plugin_manager = PluginManager() + class MissingValue(str, Enum): CALCULATION_FAILED = "--" @@ -238,17 +244,42 @@ def from_series(cls, series: Series) -> Self: ) -ColumnStatisticsT: TypeAlias = Annotated[ - Union[ - GeneralColumnStatistics, - LLMTextColumnStatistics, - LLMCodeColumnStatistics, - LLMStructuredColumnStatistics, - LLMJudgedColumnStatistics, - SamplerColumnStatistics, - SeedDatasetColumnStatistics, - ValidationColumnStatistics, - ExpressionColumnStatistics, - ], - Field(discriminator="column_type"), +ColumnStatisticsT: TypeAlias = Union[ + GeneralColumnStatistics, + LLMTextColumnStatistics, + LLMCodeColumnStatistics, + LLMStructuredColumnStatistics, + LLMJudgedColumnStatistics, + SamplerColumnStatistics, + SeedDatasetColumnStatistics, + ValidationColumnStatistics, + ExpressionColumnStatistics, ] + + +DEFAULT_COLUMN_STATISTICS_MAP = { + DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics, + DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics, + DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics, + DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics, + DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics, + DataDesignerColumnType.SAMPLER: SamplerColumnStatistics, + DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics, + DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, +} + +if can_run_data_designer_locally() and plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) > 0: + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + # Dynamically create a statistics class for this plugin using Pydantic's create_model + plugin_stats_cls_name = f"{plugin.enum_key.title().replace('_', '')}ColumnStatistics" + + # Create the class with proper Pydantic field + plugin_stats_cls = create_model( + plugin_stats_cls_name, + __base__=GeneralColumnStatistics, + column_type=(Literal[plugin.name], plugin.name), + ) + + # Add the plugin statistics class to the union + ColumnStatisticsT |= plugin_stats_cls + DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index 058b58d7..72f4d528 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -3,7 +3,7 @@ from functools import cached_property from pathlib import Path -from typing import Optional, Union +from typing import Annotated, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -18,7 +18,7 @@ class DatasetProfilerResults(BaseModel): num_records: int target_num_records: int - column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1) + column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1) side_effect_column_names: Optional[list[str]] = None column_profiles: Optional[list[ColumnProfilerResultsT]] = None diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index fb7d116e..e1d2c2cf 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -27,7 +27,6 @@ if TYPE_CHECKING: from ...analysis.dataset_profiler import DatasetProfilerResults - HEADER_STYLE = "dim" RULE_STYLE = f"bold {ColorPalette.NVIDIA_GREEN.value}" ACCENT_STYLE = f"bold {ColorPalette.BLUE.value}" diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py index d7441434..4b3e4f0e 100644 --- a/src/data_designer/engine/analysis/column_statistics.py +++ b/src/data_designer/engine/analysis/column_statistics.py @@ -11,16 +11,9 @@ from typing_extensions import Self from data_designer.config.analysis.column_statistics import ( + DEFAULT_COLUMN_STATISTICS_MAP, ColumnStatisticsT, - ExpressionColumnStatistics, GeneralColumnStatistics, - LLMCodeColumnStatistics, - LLMJudgedColumnStatistics, - LLMStructuredColumnStatistics, - LLMTextColumnStatistics, - SamplerColumnStatistics, - SeedDatasetColumnStatistics, - ValidationColumnStatistics, ) from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.sampler_params import SamplerType, is_numerical_sampler_type @@ -134,18 +127,6 @@ def calculate_validation_column_info(self) -> dict[str, Any]: class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ... -DEFAULT_COLUMN_STATISTICS_MAP = { - DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics, - DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics, - DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics, - DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics, - DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics, - DataDesignerColumnType.SAMPLER: SamplerColumnStatistics, - DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics, - DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, -} - - ColumnStatisticsCalculatorT: TypeAlias = Union[ ExpressionColumnStatisticsCalculator, ValidationColumnStatisticsCalculator, From a00e858acd3d92ee8d850508c0c079bcfc79a699 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 19:38:03 -0500 Subject: [PATCH 08/27] default -> builtin --- .../engine/analysis/column_profilers/registry.py | 2 +- src/data_designer/engine/processing/processors/registry.py | 2 +- tests/engine/analysis/test_errors.py | 4 ++-- tests/engine/processing/processors/test_registry.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/data_designer/engine/analysis/column_profilers/registry.py b/src/data_designer/engine/analysis/column_profilers/registry.py index ce022c68..7d6f06ca 100644 --- a/src/data_designer/engine/analysis/column_profilers/registry.py +++ b/src/data_designer/engine/analysis/column_profilers/registry.py @@ -14,7 +14,7 @@ class ColumnProfilerRegistry(TaskRegistry[ColumnProfilerType, ColumnProfiler, ConfigBase]): ... -def create_default_column_profiler_registry() -> ColumnProfilerRegistry: +def create_builtin_column_profiler_registry() -> ColumnProfilerRegistry: registry = ColumnProfilerRegistry() registry.register(ColumnProfilerType.JUDGE_SCORE, JudgeScoreProfiler, JudgeScoreProfilerConfig, False) return registry diff --git a/src/data_designer/engine/processing/processors/registry.py b/src/data_designer/engine/processing/processors/registry.py index dadcbc33..41201551 100644 --- a/src/data_designer/engine/processing/processors/registry.py +++ b/src/data_designer/engine/processing/processors/registry.py @@ -14,7 +14,7 @@ class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ... -def create_default_processor_registry() -> ProcessorRegistry: +def create_builtin_processor_registry() -> ProcessorRegistry: registry = ProcessorRegistry() registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False) return registry diff --git a/tests/engine/analysis/test_errors.py b/tests/engine/analysis/test_errors.py index 6c6fb901..ca446596 100644 --- a/tests/engine/analysis/test_errors.py +++ b/tests/engine/analysis/test_errors.py @@ -10,7 +10,7 @@ ) from data_designer.engine.analysis.column_profilers.registry import ( ColumnProfilerRegistry, - create_default_column_profiler_registry, + create_builtin_column_profiler_registry, ) from data_designer.engine.registry.errors import NotFoundInRegistryError @@ -52,7 +52,7 @@ def test_get_nonexistent_profiler(): def test_create_default_registry(): - registry = create_default_column_profiler_registry() + registry = create_builtin_column_profiler_registry() assert isinstance(registry, ColumnProfilerRegistry) assert ColumnProfilerType.JUDGE_SCORE in ColumnProfilerRegistry._registry diff --git a/tests/engine/processing/processors/test_registry.py b/tests/engine/processing/processors/test_registry.py index 41ccf5a8..8fedd392 100644 --- a/tests/engine/processing/processors/test_registry.py +++ b/tests/engine/processing/processors/test_registry.py @@ -5,12 +5,12 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.processing.processors.registry import ( ProcessorRegistry, - create_default_processor_registry, + create_builtin_processor_registry, ) def test_create_default_processor_registry(): - registry = create_default_processor_registry() + registry = create_builtin_processor_registry() assert isinstance(registry, ProcessorRegistry) assert ProcessorType.DROP_COLUMNS in ProcessorRegistry._registry From 74d33087fcca250a63b33114282f9ff6193402c5 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 19:41:44 -0500 Subject: [PATCH 09/27] use entry point approach instead --- src/data_designer/plugins/manager.py | 83 ++++++++++------------------ src/data_designer/plugins/plugin.py | 4 ++ 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index fdba3bb4..101888ef 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -1,11 +1,7 @@ -import importlib.util -import inspect +from importlib.metadata import entry_points import logging -import os -from pathlib import Path -import sys import threading -from typing import Iterator, Optional, Type, TypeAlias +from typing import Type, TypeAlias from typing_extensions import Self @@ -16,8 +12,15 @@ class PluginManager: + _instance = None + _plugins_discovered = False + _lock = threading.Lock() + def __init__(self): self.registry = _PluginRegistry() + if not self._plugins_discovered: + self.discover() + self._plugins_discovered = True def get_plugin(self, plugin_name: str) -> Plugin: return self.registry.get(plugin_name) @@ -36,53 +39,33 @@ def update_type_union(self, type_union: Type[TypeAlias], plugin_type: PluginType type_union |= plugin.config_cls return type_union - def discover(self, plugin_dir: Optional[Path] = None) -> Self: - plugin_dir = Path( - plugin_dir or os.getenv("DATA_DESIGNER_PLUGIN_DIR", Path.home() / ".data_designer" / "plugins") - ) - - if not plugin_dir.exists(): - return self - - for file_path in plugin_dir.rglob("*.py"): - if file_path.name.startswith("_"): - continue - - for plugin in self._iter_plugins_from_file(file_path, plugin_dir): + def discover(self) -> Self: + for ep in entry_points(group="data_designer.plugins"): + try: + plugin = ep.load() if isinstance(plugin, Plugin): - self.registry.register_plugin(plugin) + with self._lock: + self.registry.register_plugin(plugin) logger.info( f"🔌 Plugin discovered ➜ {plugin.plugin_type.value.replace('-', ' ')} " f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" ) + except Exception as e: + logger.warning(f"Failed to load plugin from entry point '{ep.name}': {e}") return self - def _iter_plugins_from_file(self, file_path: Path, plugin_dir: Path) -> Optional[Iterator[Plugin]]: - label = str(file_path.relative_to(plugin_dir)).replace("/", "_").replace(".", "_") - module_name = f"_plugin_{label}" - - try: - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None or spec.loader is None: - return - - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - - for _, obj in inspect.getmembers(module): - if isinstance(obj, Plugin): - yield obj - - except Exception: - return + def __new__(cls, *args, **kwargs): + """Plugin manager is a singleton.""" + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance class _PluginRegistry: _plugins: dict[str, Plugin] = {} - _instance = None - _lock = threading.Lock() def get(self, plugin_name: str) -> Plugin: if plugin_name not in self._plugins: @@ -90,20 +73,10 @@ def get(self, plugin_name: str) -> Plugin: return self._plugins[plugin_name] def register_plugin(self, plugin: Plugin) -> None: - with self._lock: - if plugin.name in self._plugins: - raise PluginRegistrationError(f"Plugin '{plugin.name}' already registered.") - self._plugins[plugin.name] = plugin + if plugin.name in self._plugins: + raise PluginRegistrationError(f"Plugin '{plugin.name}' already registered.") + self._plugins[plugin.name] = plugin def clear(self) -> None: """Clear all registered plugins. Primarily for testing purposes.""" - with self._lock: - self._plugins.clear() - - def __new__(cls, *args, **kwargs): - """Plugin registry is a singleton.""" - if not cls._instance: - with cls._lock: - if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance + self._plugins.clear() diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index 6285e8d1..942b5bb3 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -25,6 +25,10 @@ class Plugin(BaseModel): plugin_type: PluginType emoji: str = "🔌" + @property + def config_type_as_class_name(self) -> str: + return self.enum_key.title().replace("_", "") + @property def enum_key(self) -> str: return self.name.replace("-", "_").upper() From 1ec27fd9345217cb4b207289a2c5a0af73ebff43 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 19:42:05 -0500 Subject: [PATCH 10/27] rewire using plugin helpers --- .../config/analysis/column_statistics.py | 32 +++----- src/data_designer/config/column_types.py | 64 ++++++--------- .../config/utils/plugin_helpers.py | 82 +++++++++++++++++++ .../engine/registry/data_designer_registry.py | 12 ++- src/data_designer/plugins/__init__.py | 3 + 5 files changed, 128 insertions(+), 65 deletions(-) create mode 100644 src/data_designer/config/utils/plugin_helpers.py create mode 100644 src/data_designer/plugins/__init__.py diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 874b386e..14574a4b 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -13,15 +13,10 @@ from ..column_types import DataDesignerColumnType from ..sampler_params import SamplerType +from ..utils import plugin_helpers from ..utils.constants import EPSILON -from ..utils.misc import can_run_data_designer_locally from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting -if can_run_data_designer_locally(): - from data_designer.plugins.manager import PluginManager, PluginType - - plugin_manager = PluginManager() - class MissingValue(str, Enum): CALCULATION_FAILED = "--" @@ -268,18 +263,17 @@ def from_series(cls, series: Series) -> Self: DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, } -if can_run_data_designer_locally() and plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) > 0: - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - # Dynamically create a statistics class for this plugin using Pydantic's create_model - plugin_stats_cls_name = f"{plugin.enum_key.title().replace('_', '')}ColumnStatistics" +for plugin in plugin_helpers.get_plugin_column_configs(): + # Dynamically create a statistics class for this plugin using Pydantic's create_model + plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics" - # Create the class with proper Pydantic field - plugin_stats_cls = create_model( - plugin_stats_cls_name, - __base__=GeneralColumnStatistics, - column_type=(Literal[plugin.name], plugin.name), - ) + # Create the class with proper Pydantic field + plugin_stats_cls = create_model( + plugin_stats_cls_name, + __base__=GeneralColumnStatistics, + column_type=(Literal[plugin.name], plugin.name), + ) - # Add the plugin statistics class to the union - ColumnStatisticsT |= plugin_stats_cls - DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls + # Add the plugin statistics class to the union + ColumnStatisticsT |= plugin_stats_cls + DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index 68b0245c..b8bc26f8 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -17,15 +17,9 @@ ) from .errors import InvalidColumnTypeError, InvalidConfigError from .sampler_params import SamplerType -from .utils.misc import can_run_data_designer_locally +from .utils import plugin_helpers from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum -if can_run_data_designer_locally(): - from data_designer.plugins.manager import PluginManager, PluginType - - plugin_manager = PluginManager().discover() - - ColumnConfigT: TypeAlias = Union[ ExpressionColumnConfig, LLMCodeColumnConfig, @@ -36,12 +30,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ] - - -if can_run_data_designer_locally(): - if plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) > 0: - ColumnConfigT = plugin_manager.update_type_union(ColumnConfigT, PluginType.COLUMN_GENERATOR) - +ColumnConfigT = plugin_helpers.inject_into_column_config_type_union(ColumnConfigT) DataDesignerColumnType = create_str_enum_from_discriminated_type_union( enum_name="DataDesignerColumnType", @@ -49,7 +38,6 @@ discriminator_field_name="column_type", ) - COLUMN_TYPE_EMOJI_MAP = { "general": "⚛️", # possible analysis column type DataDesignerColumnType.EXPRESSION: "🧩", @@ -61,9 +49,9 @@ DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "🔍", } -if can_run_data_designer_locally(): - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - COLUMN_TYPE_EMOJI_MAP[DataDesignerColumnType(plugin.name)] = plugin.emoji +COLUMN_TYPE_EMOJI_MAP.update( + {DataDesignerColumnType(p.name): p.emoji for p in plugin_helpers.get_plugin_column_configs()} +) def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: @@ -77,9 +65,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, } - if can_run_data_designer_locally(): - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - dag_column_types.add(DataDesignerColumnType(plugin.name)) + dag_column_types.update(plugin_helpers.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -92,10 +78,12 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType] DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, } - if can_run_data_designer_locally(): - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - if "model_registry" in (plugin.task_cls.metadata().required_resources or []): - llm_generated_column_types.add(DataDesignerColumnType(plugin.name)) + llm_generated_column_types.update( + plugin_helpers.get_plugin_column_types( + DataDesignerColumnType, + required_resources=["model_registry"], + ) + ) return column_type in llm_generated_column_types @@ -113,24 +101,22 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType column_type = resolve_string_enum(column_type, DataDesignerColumnType) if column_type == DataDesignerColumnType.LLM_TEXT: return LLMTextColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_CODE: + if column_type == DataDesignerColumnType.LLM_CODE: return LLMCodeColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_STRUCTURED: + if column_type == DataDesignerColumnType.LLM_STRUCTURED: return LLMStructuredColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.LLM_JUDGE: + if column_type == DataDesignerColumnType.LLM_JUDGE: return LLMJudgeColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.VALIDATION: + if column_type == DataDesignerColumnType.VALIDATION: return ValidationColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.EXPRESSION: + if column_type == DataDesignerColumnType.EXPRESSION: return ExpressionColumnConfig(name=name, **kwargs) - elif column_type == DataDesignerColumnType.SAMPLER: + if column_type == DataDesignerColumnType.SAMPLER: return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) - elif column_type == DataDesignerColumnType.SEED_DATASET: + if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) - elif can_run_data_designer_locally() and column_type.value in plugin_manager.get_plugin_names( - PluginType.COLUMN_GENERATOR - ): - return plugin_manager.get_plugin(column_type.value).config_cls(name=name, **kwargs) + if plugin := plugin_helpers.get_plugin_column_config_if_available(column_type.value): + return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover @@ -146,9 +132,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] - if can_run_data_designer_locally(): - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - display_order.append(DataDesignerColumnType(plugin.name)) + display_order.extend(plugin_helpers.get_plugin_column_types(DataDesignerColumnType)) return display_order @@ -170,7 +154,9 @@ def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: else: # params is neither dict nor expected type raise InvalidConfigError( - f"🛑 Invalid params for sampler column '{name}'. Expected a dictionary or an instance of {expected_params_class.__name__}." + f"🛑 Invalid params for sampler column '{name}'. " + f"Expected a dictionary or an instance of {expected_params_class.__name__}. " + f"You provided {params_value=}." ) return { diff --git a/src/data_designer/config/utils/plugin_helpers.py b/src/data_designer/config/utils/plugin_helpers.py new file mode 100644 index 00000000..2b22c1e2 --- /dev/null +++ b/src/data_designer/config/utils/plugin_helpers.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Type, TypeAlias + +from .misc import can_run_data_designer_locally + +if TYPE_CHECKING: + from data_designer.plugins.manager import PluginManager + from data_designer.plugins.plugin import Plugin + + +plugin_manager = None +if can_run_data_designer_locally(): + from data_designer.plugins.manager import PluginManager, PluginType + + plugin_manager = PluginManager() + + +def get_plugin_column_configs() -> list[Plugin]: + """Get all plugin column configs. + + Returns: + A list of all plugin column configs. + """ + if plugin_manager: + return [ + plugin_manager.get_plugin(plugin_name) + for plugin_name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + ] + return [] + + +def get_plugin_column_config_if_available(plugin_name: str) -> Plugin | None: + """Get a plugin column config by name if available. + + Args: + plugin_name: The name of the plugin to retrieve. + + Returns: + The plugin if found, otherwise None. + """ + if plugin_manager: + for name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR): + if plugin_name == name: + return plugin_manager.get_plugin(plugin_name) + return None + + +def get_plugin_column_types(enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: + """Get a list of plugin column types. + + Args: + enum_type: The enum type to use for plugin entries. + required_resources: If provided, only return plugins with the required resources. + + Returns: + A list of plugin column types. + """ + type_list = [] + if plugin_manager: + for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + if required_resources: + task_required_resources = plugin.task_cls.metadata().required_resources or [] + if not all(resource in task_required_resources for resource in required_resources): + continue + type_list.append(enum_type(plugin.name)) + return type_list + + +def inject_into_column_config_type_union(column_config_type: Type[TypeAlias]) -> Type[TypeAlias]: + """Inject plugins into the column config type. + + Args: + column_config_type: The column config type to inject plugins into. + + Returns: + The column config type with plugins injected. + """ + if plugin_manager: + column_config_type = plugin_manager.update_type_union(column_config_type, PluginType.COLUMN_GENERATOR) + return column_config_type diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 3d6c5dc8..6b1ae1c3 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -4,17 +4,15 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.analysis.column_profilers.registry import ( ColumnProfilerRegistry, - create_default_column_profiler_registry, + create_builtin_column_profiler_registry, ) from data_designer.engine.column_generators.registry import ( ColumnGeneratorRegistry, create_builtin_column_generator_registry, ) -from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry +from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_builtin_processor_registry from data_designer.plugins.manager import PluginManager, PluginType -plugin_manager = PluginManager() - class DataDesignerRegistry: def __init__( @@ -25,10 +23,10 @@ def __init__( processor_registry: ProcessorRegistry | None = None, ): self._column_generator_registry = column_generator_registry or create_builtin_column_generator_registry() - self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry() - self._processor_registry = processor_registry or create_default_processor_registry() + self._column_profiler_registry = column_profiler_registry or create_builtin_column_profiler_registry() + self._processor_registry = processor_registry or create_builtin_processor_registry() - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): + for plugin in PluginManager().get_plugins(PluginType.COLUMN_GENERATOR): self._column_generator_registry.register( DataDesignerColumnType(plugin.name), plugin.task_cls, diff --git a/src/data_designer/plugins/__init__.py b/src/data_designer/plugins/__init__.py new file mode 100644 index 00000000..75e3343a --- /dev/null +++ b/src/data_designer/plugins/__init__.py @@ -0,0 +1,3 @@ +from data_designer.plugins.plugin import Plugin, PluginType + +__all__ = ["Plugin", "PluginType"] From e648b0f0e29b18970b86a259f2da1d6f80e2394c Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:13:30 -0500 Subject: [PATCH 11/27] add env var to disable plugins --- pyproject.toml | 2 +- src/data_designer/plugins/manager.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 53df0082..c4fb25b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ version-file = "src/data_designer/_version.py" testpaths = ["tests"] asyncio_default_fixture_loop_scope = "session" env = [ - "DATA_DESIGNER_PLUGIN_DIR=/tmp/pytest-no-plugins", + "DISABLE_DATA_DESIGNER_PLUGINS=true", ] [tool.uv] diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 101888ef..ded6d1c1 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -1,5 +1,6 @@ from importlib.metadata import entry_points import logging +import os import threading from typing import Type, TypeAlias @@ -11,6 +12,9 @@ logger = logging.getLogger(__name__) +PLUGINS_DISABLED = os.getenv("DISABLE_DATA_DESIGNER_PLUGINS", "false").lower() == "true" + + class PluginManager: _instance = None _plugins_discovered = False @@ -40,6 +44,8 @@ def update_type_union(self, type_union: Type[TypeAlias], plugin_type: PluginType return type_union def discover(self) -> Self: + if PLUGINS_DISABLED: + return self for ep in entry_points(group="data_designer.plugins"): try: plugin = ep.load() From f7e708a9d2b8d9dbca0d6e48b4679a19594fd9cb Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:13:35 -0500 Subject: [PATCH 12/27] fix tests --- tests/engine/registry/test_data_designer_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engine/registry/test_data_designer_registry.py b/tests/engine/registry/test_data_designer_registry.py index 56d7cae8..6f131bb6 100644 --- a/tests/engine/registry/test_data_designer_registry.py +++ b/tests/engine/registry/test_data_designer_registry.py @@ -24,7 +24,7 @@ def stub_default_registries(): "data_designer.engine.registry.data_designer_registry.create_builtin_column_generator_registry" ) as mock_gen: with patch( - "data_designer.engine.registry.data_designer_registry.create_default_column_profiler_registry" + "data_designer.engine.registry.data_designer_registry.create_builtin_column_profiler_registry" ) as mock_prof: mock_gen_registry = Mock() mock_prof_registry = Mock() From f3e392e08247544a20d0b1ba18a848a93eeffe66 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:43:07 -0500 Subject: [PATCH 13/27] update plugin manager tests --- tests/plugins/test_manager.py | 595 ++++++++++++---------------------- 1 file changed, 200 insertions(+), 395 deletions(-) diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 2098cd9d..94c58a2e 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -1,8 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path +from contextlib import contextmanager +from importlib.metadata import EntryPoint +import threading from typing import Literal +from unittest.mock import MagicMock, patch import pytest @@ -14,514 +17,316 @@ from data_designer.plugins.plugin import Plugin, PluginType # ============================================================================= -# Test Fixtures +# Test Stubs # ============================================================================= -@pytest.fixture -def plugin_manager() -> PluginManager: - """Create a PluginManager with a clean registry. - - This fixture ensures the singleton registry is cleared before and after each test, - preventing state leakage between tests and from any plugins in the default - plugin directory. - """ - manager = PluginManager() - # Clear any plugins that may have been auto-discovered (e.g., from ~/.data_designer/plugins/) - manager.registry.clear() - yield manager - # Cleanup: clear the singleton registry after the test - manager.registry.clear() - - -def create_plugin_file( - dir_path: Path, - filename: str, - plugin_name: str, - column_type: str, - task_name: str | None = None, -) -> Path: - """Helper to create test plugin files with less boilerplate. - - Args: - dir_path: Directory to create the plugin file in - filename: Name of the plugin file (e.g., "test_plugin.py") - plugin_name: Name of the plugin (e.g., "MyPlugin") - column_type: Column type literal value (e.g., "my-plugin") - task_name: Task metadata name (defaults to plugin_name lowercase with underscores) - - Returns: - Path to the created plugin file - """ - if task_name is None: - task_name = plugin_name.lower().replace("-", "_") - - plugin_var_name = plugin_name.lower().replace("-", "_") - - plugin_file = dir_path / filename - plugin_file.write_text( - f""" -from typing import Literal -from data_designer.config.column_configs import SingleColumnConfig -from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata -from data_designer.plugins.plugin import Plugin, PluginType +class StubPluginConfigA(SingleColumnConfig): + column_type: Literal["test-plugin-a"] = "test-plugin-a" -class {plugin_name}Config(SingleColumnConfig): - column_type: Literal["{column_type}"] = "{column_type}" - name: str +class StubPluginConfigB(SingleColumnConfig): + column_type: Literal["test-plugin-b"] = "test-plugin-b" -class {plugin_name}Task(ConfigurableTask[{plugin_name}Config]): +class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]): @staticmethod def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( - name="{task_name}", - description="{plugin_name} task", + name="test_plugin_a", + description="Test plugin A", required_resources=None, ) -{plugin_var_name} = Plugin( - task_cls={plugin_name}Task, - config_cls={plugin_name}Config, - plugin_type=PluginType.COLUMN_GENERATOR, -) -""" - ) - return plugin_file - - -@pytest.fixture -def temp_plugin_dir(tmp_path: Path) -> Path: - """Create a temporary directory with a test plugin file.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() - create_plugin_file(plugin_dir, "test_plugin.py", "MyPlugin", "my-plugin") - return plugin_dir - - -@pytest.fixture -def invalid_plugin_dir(tmp_path: Path) -> Path: - """Create a directory with an invalid plugin file.""" - plugin_dir = tmp_path / "invalid_plugins" - plugin_dir.mkdir() - - invalid_file = plugin_dir / "invalid.py" - invalid_file.write_text("import syntax error here") - - return plugin_dir - - -# ============================================================================= -# Plugin Discovery Tests -# ============================================================================= - - -def test_discover_finds_plugin(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that discover() finds and registers plugins in the plugin directory.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) - - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 - assert "my-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - - -@pytest.mark.parametrize( - "dir_setup", - [ - ("empty", lambda tmp_path: (tmp_path / "empty").mkdir() or (tmp_path / "empty")), - ("nonexistent", lambda tmp_path: tmp_path / "does_not_exist"), - ], - ids=["empty_directory", "nonexistent_directory"], -) -def test_discover_handles_missing_or_empty_directories( - plugin_manager: PluginManager, tmp_path: Path, dir_setup: tuple[str, callable] -) -> None: - """Test that discover() handles empty and nonexistent directories gracefully.""" - _, setup_func = dir_setup - plugin_dir = setup_func(tmp_path) - - plugin_manager.discover(plugin_dir=plugin_dir) - - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 - - -def test_discover_skips_private_files(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test that discover() skips files starting with underscore.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() - - create_plugin_file(plugin_dir, "_private_plugin.py", "PrivatePlugin", "private-plugin") - - plugin_manager.discover(plugin_dir=plugin_dir) - - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 - - -def test_discover_handles_invalid_files(plugin_manager: PluginManager, invalid_plugin_dir: Path) -> None: - """Test that discover() gracefully handles invalid Python files.""" - plugin_manager.discover(plugin_dir=invalid_plugin_dir) - - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 - - -def test_discover_finds_multiple_plugins_in_same_file(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test that discover() can find multiple Plugin instances in the same file.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() - - multi_plugin_file = plugin_dir / "multi.py" - multi_plugin_file.write_text( - """ -from typing import Literal -from data_designer.config.column_configs import SingleColumnConfig -from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata -from data_designer.plugins.plugin import Plugin, PluginType - - -class Plugin1Config(SingleColumnConfig): - column_type: Literal["plugin-1"] = "plugin-1" - name: str - - -class Plugin1Task(ConfigurableTask[Plugin1Config]): +class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]): @staticmethod def metadata() -> ConfigurableTaskMetadata: return ConfigurableTaskMetadata( - name="plugin_1", - description="Plugin 1", + name="test_plugin_b", + description="Test plugin B", required_resources=None, ) -class Plugin2Config(SingleColumnConfig): - column_type: Literal["plugin-2"] = "plugin-2" - name: str +# ============================================================================= +# Test Fixtures +# ============================================================================= -class Plugin2Task(ConfigurableTask[Plugin2Config]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="plugin_2", - description="Plugin 2", - required_resources=None, - ) - +@pytest.fixture +def plugin_a() -> Plugin: + return Plugin( + task_cls=StubPluginTaskA, + config_cls=StubPluginConfigA, + plugin_type=PluginType.COLUMN_GENERATOR, + ) -plugin1 = Plugin( - task_cls=Plugin1Task, - config_cls=Plugin1Config, - plugin_type=PluginType.COLUMN_GENERATOR, -) -plugin2 = Plugin( - task_cls=Plugin2Task, - config_cls=Plugin2Config, - plugin_type=PluginType.COLUMN_GENERATOR, -) -""" +@pytest.fixture +def plugin_b() -> Plugin: + return Plugin( + task_cls=StubPluginTaskB, + config_cls=StubPluginConfigB, + plugin_type=PluginType.COLUMN_GENERATOR, ) - plugin_manager.discover(plugin_dir=plugin_dir) - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 - plugin_names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - assert "plugin-1" in plugin_names - assert "plugin-2" in plugin_names +@pytest.fixture(autouse=True) +def clean_plugin_manager() -> None: + """Reset PluginManager singleton state before and after each test.""" + original_instance = PluginManager._instance + original_discovered = PluginManager._plugins_discovered + original_plugins = _PluginRegistry._plugins.copy() + PluginManager._instance = None + PluginManager._plugins_discovered = False + _PluginRegistry._plugins = {} -def test_discover_recursive_search(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test that discover() recursively searches subdirectories.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() + yield - subdir = plugin_dir / "subdir" - subdir.mkdir() + PluginManager._instance = original_instance + PluginManager._plugins_discovered = original_discovered + _PluginRegistry._plugins = original_plugins - create_plugin_file(subdir, "nested.py", "NestedPlugin", "nested-plugin") - plugin_manager.discover(plugin_dir=plugin_dir) +@pytest.fixture +def mock_plugin_discovery(): + """Mock plugin discovery to test with specific entry points.""" - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 - assert "nested-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + @contextmanager + def _mock_discovery(entry_points_list): + with patch("data_designer.plugins.manager.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.manager.entry_points", return_value=entry_points_list): + yield + return _mock_discovery -def test_discover_multiple_calls(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test that discover() can be called multiple times to discover plugins from different directories.""" - dir1 = tmp_path / "plugins1" - dir1.mkdir() - create_plugin_file(dir1, "plugin1.py", "Plugin1", "plugin-1") - dir2 = tmp_path / "plugins2" - dir2.mkdir() - create_plugin_file(dir2, "plugin2.py", "Plugin2", "plugin-2") +@pytest.fixture +def mock_entry_points(plugin_a: Plugin, plugin_b: Plugin) -> list[MagicMock]: + """Create mock entry points for plugin_a and plugin_b.""" + mock_ep_a = MagicMock(spec=EntryPoint) + mock_ep_a.name = "test-plugin-a" + mock_ep_a.load.return_value = plugin_a - plugin_manager.discover(plugin_dir=dir1) - plugin_manager.discover(plugin_dir=dir2) + mock_ep_b = MagicMock(spec=EntryPoint) + mock_ep_b.name = "test-plugin-b" + mock_ep_b.load.return_value = plugin_b - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 - plugin_names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - assert "plugin-1" in plugin_names - assert "plugin-2" in plugin_names + return [mock_ep_a, mock_ep_b] # ============================================================================= -# Plugin Retrieval Tests +# _PluginRegistry Tests # ============================================================================= -def test_get_plugin_returns_correct_plugin(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that get_plugin() returns the correct plugin by name.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) - - plugin = plugin_manager.get_plugin("my-plugin") - - assert plugin.name == "my-plugin" - assert plugin.plugin_type == PluginType.COLUMN_GENERATOR - assert plugin.config_cls.__name__ == "MyPluginConfig" - assert plugin.task_cls.__name__ == "MyPluginTask" - +def test_plugin_registry_register_and_get(plugin_a: Plugin) -> None: + """Test plugin registration and retrieval.""" + registry = _PluginRegistry() -def test_get_plugin_raises_not_found_error(plugin_manager: PluginManager) -> None: - """Test that get_plugin() raises PluginNotFoundError for nonexistent plugins.""" - with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): - plugin_manager.get_plugin("nonexistent") - - -def test_get_plugins_returns_plugins_by_type(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test that get_plugins() returns all plugins of a specific type.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() + registry.register_plugin(plugin_a) - create_plugin_file(plugin_dir, "plugin1.py", "Plugin1", "plugin-1") - create_plugin_file(plugin_dir, "plugin2.py", "Plugin2", "plugin-2") + assert registry.get("test-plugin-a") == plugin_a - plugin_manager.discover(plugin_dir=plugin_dir) - plugins = plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR) +def test_plugin_registry_duplicate_raises_error(plugin_a: Plugin) -> None: + """Test duplicate registration raises PluginRegistrationError.""" + registry = _PluginRegistry() + registry.register_plugin(plugin_a) - assert len(plugins) == 2 - plugin_names = [p.name for p in plugins] - assert "plugin-1" in plugin_names - assert "plugin-2" in plugin_names + with pytest.raises(PluginRegistrationError, match="Plugin 'test-plugin-a' already registered"): + registry.register_plugin(plugin_a) -def test_get_plugin_names_returns_all_names(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that get_plugin_names() returns all plugin names for a given type.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) +def test_plugin_registry_get_nonexistent_raises_error() -> None: + """Test nonexistent plugin raises PluginNotFoundError.""" + registry = _PluginRegistry() - names = plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): + registry.get("nonexistent") - assert names == ["my-plugin"] +def test_plugin_registry_clear(plugin_a: Plugin, plugin_b: Plugin) -> None: + """Test clear() removes all plugins.""" + registry = _PluginRegistry() + registry.register_plugin(plugin_a) + registry.register_plugin(plugin_b) -def test_num_plugins_returns_count(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that num_plugins() returns the correct count.""" - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + registry.clear() - plugin_manager.discover(plugin_dir=temp_plugin_dir) - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + with pytest.raises(PluginNotFoundError): + registry.get("test-plugin-a") + with pytest.raises(PluginNotFoundError): + registry.get("test-plugin-b") # ============================================================================= -# Type Union Tests +# PluginManager Singleton Tests # ============================================================================= -def test_update_type_union_adds_config_types(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that update_type_union() adds plugin config classes to the type union.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) +def test_plugin_manager_is_singleton(mock_plugin_discovery) -> None: + """Test PluginManager returns same instance.""" + with mock_plugin_discovery([]): + manager1 = PluginManager() + manager2 = PluginManager() - # Start with a basic type - type_union = SingleColumnConfig + assert manager1 is manager2 - updated_union = plugin_manager.update_type_union(type_union, PluginType.COLUMN_GENERATOR) - # The union should now include the plugin's config class - plugin = plugin_manager.get_plugin("my-plugin") - assert plugin.config_cls in updated_union.__args__ +def test_plugin_manager_singleton_thread_safety(mock_plugin_discovery) -> None: + """Test PluginManager singleton creation is thread-safe.""" + instances: list[PluginManager] = [] + with mock_plugin_discovery([]): -# ============================================================================= -# Error Handling Tests -# ============================================================================= + def create_manager() -> None: + instances.append(PluginManager()) + threads = [threading.Thread(target=create_manager) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() -def test_register_duplicate_plugin_raises_error(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that registering a duplicate plugin raises PluginRegistrationError.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) - - # Try to discover the same plugin again - with pytest.raises(PluginRegistrationError, match="Plugin 'my-plugin' already registered"): - plugin_manager.discover(plugin_dir=temp_plugin_dir) + assert all(instance is instances[0] for instance in instances) # ============================================================================= -# Plugin Validation Tests +# PluginManager Discovery Tests # ============================================================================= -def test_plugin_with_invalid_discriminator_field() -> None: - """Test that Plugin validation fails when discriminator field is missing.""" +def test_plugin_manager_discovers_plugins( + mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin +) -> None: + """Test PluginManager discovers and loads plugins from entry points.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginManager() - class InvalidConfig(ConfigBase): - name: str + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 + assert manager.get_plugin("test-plugin-a") == plugin_a + assert manager.get_plugin("test-plugin-b") == plugin_b - class InvalidTask(ConfigurableTask[InvalidConfig]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="invalid", - description="Invalid plugin", - required_resources=None, - ) - with pytest.raises(ValueError, match="Discriminator field 'column_type' not found"): - Plugin( - task_cls=InvalidTask, - config_cls=InvalidConfig, - plugin_type=PluginType.COLUMN_GENERATOR, - ) +def test_plugin_manager_skips_invalid_plugins(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginManager skips non-Plugin objects during discovery.""" + mock_ep_valid = MagicMock(spec=EntryPoint) + mock_ep_valid.name = "test-plugin-a" + mock_ep_valid.load.return_value = plugin_a + mock_ep_invalid = MagicMock(spec=EntryPoint) + mock_ep_invalid.name = "invalid-plugin" + mock_ep_invalid.load.return_value = "not a plugin" -def test_plugin_with_non_literal_discriminator() -> None: - """Test that Plugin validation fails when discriminator field is not a Literal type.""" + with mock_plugin_discovery([mock_ep_valid, mock_ep_invalid]): + manager = PluginManager() - class NonLiteralConfig(SingleColumnConfig): - column_type: str = "non-literal" # Should be Literal["non-literal"] - name: str + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert manager.get_plugin("test-plugin-a") == plugin_a - class NonLiteralTask(ConfigurableTask[NonLiteralConfig]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="non_literal", - description="Non-literal plugin", - required_resources=None, - ) - with pytest.raises(ValueError, match="Field 'column_type' .* must be a Literal type"): - Plugin( - task_cls=NonLiteralTask, - config_cls=NonLiteralConfig, - plugin_type=PluginType.COLUMN_GENERATOR, - ) +def test_plugin_manager_handles_loading_errors(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginManager gracefully handles plugin loading errors.""" + mock_ep_valid = MagicMock(spec=EntryPoint) + mock_ep_valid.name = "test-plugin-a" + mock_ep_valid.load.return_value = plugin_a + mock_ep_error = MagicMock(spec=EntryPoint) + mock_ep_error.name = "error-plugin" + mock_ep_error.load.side_effect = Exception("Loading failed") -def test_plugin_with_non_string_discriminator_default() -> None: - """Test that Plugin validation fails when discriminator default is not a string.""" + with mock_plugin_discovery([mock_ep_valid, mock_ep_error]): + manager = PluginManager() - class NonStringConfig(ConfigBase): - column_type: Literal[123] = 123 # Should be a string - name: str + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert manager.get_plugin("test-plugin-a") == plugin_a - class NonStringTask(ConfigurableTask[NonStringConfig]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="non_string", - description="Non-string plugin", - required_resources=None, - ) - with pytest.raises(ValueError, match="The default of 'column_type' must be a string"): - Plugin( - task_cls=NonStringTask, - config_cls=NonStringConfig, - plugin_type=PluginType.COLUMN_GENERATOR, - ) +def test_plugin_manager_discovery_runs_once() -> None: + """Test discovery runs once even with multiple PluginManager instances.""" + mock_entry_points = MagicMock(return_value=[]) + with patch("data_designer.plugins.manager.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.manager.entry_points", mock_entry_points): + PluginManager() + PluginManager() + PluginManager() -def test_plugin_with_invalid_enum_key() -> None: - """Test that Plugin validation fails when discriminator can't be converted to valid enum key.""" + assert mock_entry_points.call_count == 1 - class InvalidEnumKeyConfig(SingleColumnConfig): - column_type: Literal["123-invalid"] = "123-invalid" # Starts with number - name: str - class InvalidEnumKeyTask(ConfigurableTask[InvalidEnumKeyConfig]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="invalid_enum", - description="Invalid enum key plugin", - required_resources=None, - ) +def test_plugin_manager_respects_disabled_flag() -> None: + """Test PluginManager respects DISABLE_DATA_DESIGNER_PLUGINS flag.""" + mock_entry_points = MagicMock(return_value=[]) - with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): - Plugin( - task_cls=InvalidEnumKeyTask, - config_cls=InvalidEnumKeyConfig, - plugin_type=PluginType.COLUMN_GENERATOR, - ) - - -def test_plugin_name_property(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that plugin name property correctly extracts name from discriminator field.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) + with patch("data_designer.plugins.manager.PLUGINS_DISABLED", True): + with patch("data_designer.plugins.manager.entry_points", mock_entry_points): + manager = PluginManager() - plugin = plugin_manager.get_plugin("my-plugin") - assert plugin.name == "my-plugin" + assert mock_entry_points.call_count == 0 + assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 -def test_plugin_enum_key_property(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that plugin enum_key property correctly converts name to enum format.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) +# ============================================================================= +# PluginManager Query Methods Tests +# ============================================================================= - plugin = plugin_manager.get_plugin("my-plugin") - assert plugin.enum_key == "MY_PLUGIN" +def test_plugin_manager_get_plugin_raises_error(mock_plugin_discovery) -> None: + """Test get_plugin() raises error for nonexistent plugin.""" + with mock_plugin_discovery([]): + manager = PluginManager() -# ============================================================================= -# Registry Singleton Tests -# ============================================================================= + with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): + manager.get_plugin("nonexistent") -def test_registry_is_singleton(plugin_manager: PluginManager) -> None: - """Test that _PluginRegistry is a singleton.""" - registry1 = _PluginRegistry() - registry2 = _PluginRegistry() +def test_plugin_manager_get_plugins_by_type( + mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin +) -> None: + """Test get_plugins() filters by plugin type.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginManager() + plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) - assert registry1 is registry2 - assert registry1 is plugin_manager.registry + assert len(plugins) == 2 + assert plugin_a in plugins + assert plugin_b in plugins -def test_registry_clear_affects_all_instances(plugin_manager: PluginManager, temp_plugin_dir: Path) -> None: - """Test that clearing registry affects all manager instances.""" - plugin_manager.discover(plugin_dir=temp_plugin_dir) +def test_plugin_manager_get_plugins_empty(mock_plugin_discovery) -> None: + """Test get_plugins() returns empty list when no plugins match.""" + with mock_plugin_discovery([]): + manager = PluginManager() + plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 + assert plugins == [] - manager2 = PluginManager() - assert manager2.num_plugins(PluginType.COLUMN_GENERATOR) == 1 - plugin_manager.registry.clear() +def test_plugin_manager_get_plugin_names(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: + """Test get_plugin_names() returns plugin names by type.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginManager() + names = manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 - assert manager2.num_plugins(PluginType.COLUMN_GENERATOR) == 0 + assert set(names) == {"test-plugin-a", "test-plugin-b"} # ============================================================================= -# Integration Tests +# PluginManager Type Union Tests # ============================================================================= -def test_full_plugin_workflow(plugin_manager: PluginManager, tmp_path: Path) -> None: - """Test complete workflow: discover → retrieve → validate plugin properties.""" - plugin_dir = tmp_path / "plugins" - plugin_dir.mkdir() - - create_plugin_file(plugin_dir, "workflow_plugin.py", "WorkflowPlugin", "workflow-plugin") - - plugin_manager.discover(plugin_dir=plugin_dir) +def test_plugin_manager_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: + """Test update_type_union() adds plugin config types to union.""" + with mock_plugin_discovery(mock_entry_points): + manager = PluginManager() - assert plugin_manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 - assert "workflow-plugin" in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) + type_union: type = ConfigBase + updated_union = manager.update_type_union(type_union, PluginType.COLUMN_GENERATOR) - plugin = plugin_manager.get_plugin("workflow-plugin") - assert plugin.name == "workflow-plugin" - assert plugin.enum_key == "WORKFLOW_PLUGIN" - assert plugin.plugin_type == PluginType.COLUMN_GENERATOR - assert plugin.config_cls.__name__ == "WorkflowPluginConfig" - assert plugin.task_cls.__name__ == "WorkflowPluginTask" + assert StubPluginConfigA in updated_union.__args__ + assert StubPluginConfigB in updated_union.__args__ From 6a9b011edfc1ba1a6a09defc182713aa8b324838 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:56:14 -0500 Subject: [PATCH 14/27] add tests for plugin helpers --- tests/config/utils/test_plugin_helpers.py | 235 ++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 tests/config/utils/test_plugin_helpers.py diff --git a/tests/config/utils/test_plugin_helpers.py b/tests/config/utils/test_plugin_helpers.py new file mode 100644 index 00000000..4111d6e8 --- /dev/null +++ b/tests/config/utils/test_plugin_helpers.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from data_designer.config.utils.plugin_helpers import ( + get_plugin_column_config_if_available, + get_plugin_column_configs, + get_plugin_column_types, + inject_into_column_config_type_union, +) + + +class MockPluginType(str, Enum): + """Mock PluginType enum for testing.""" + + COLUMN_GENERATOR = "column-generator" + + @property + def discriminator_field(self) -> str: + return "column_type" + + +def create_mock_plugin(name: str, plugin_type: MockPluginType, resources: list[str] | None = None) -> Mock: + """Create a mock plugin with specified name and resources. + + Args: + name: Plugin name. + plugin_type: Plugin type enum. + resources: List of required resources, or None if no resource requirements. + + Returns: + Mock plugin object. + """ + plugin = Mock() + plugin.name = name + plugin.plugin_type = plugin_type + plugin.config_cls = Mock(name=name) + + mock_task = Mock() + mock_task.metadata = Mock(return_value=Mock(required_resources=resources)) + plugin.task_cls = mock_task + + return plugin + + +@pytest.fixture +def mock_plugin_manager() -> MagicMock: + """Create a mock plugin manager.""" + return MagicMock() + + +@pytest.fixture +def mock_plugins() -> list[Mock]: + """Create mock plugins for testing.""" + return [ + create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), + create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), + create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), + ] + + +def test_get_plugin_column_configs_with_plugins(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column configs when plugins are available.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_manager.get_plugin.side_effect = lambda name: next(p for p in mock_plugins if p.name == name) + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_configs() + + assert len(result) == 2 + assert result[0].name == "plugin-one" + assert result[1].name == "plugin-two" + mock_plugin_manager.get_plugin_names.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +def test_get_plugin_column_configs_no_plugins(mock_plugin_manager: MagicMock) -> None: + """Test getting plugin column configs when no plugins are available.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugin_names.return_value = [] + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_configs() + + assert result == [] + + +def test_get_plugin_column_configs_plugin_manager_disabled() -> None: + """Test getting plugin column configs when plugin_manager is None.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): + result = get_plugin_column_configs() + + assert result == [] + + +def test_get_plugin_column_config_if_available_found(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting a specific plugin by name when it exists.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_manager.get_plugin.return_value = mock_plugins[0] + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_config_if_available("plugin-one") + + assert result is not None + assert result.name == "plugin-one" + mock_plugin_manager.get_plugin.assert_called_once_with("plugin-one") + + +def test_get_plugin_column_config_if_available_not_found(mock_plugin_manager: MagicMock) -> None: + """Test getting a specific plugin by name when it doesn't exist.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_config_if_available("plugin-three") + + assert result is None + mock_plugin_manager.get_plugin.assert_not_called() + + +def test_get_plugin_column_config_if_available_plugin_manager_disabled() -> None: + """Test getting a specific plugin when plugin_manager is None.""" + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): + result = get_plugin_column_config_if_available("plugin-one") + + assert result is None + + +def test_get_plugin_column_types_with_plugins(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column types when plugins are available.""" + TestEnum = Enum( + "TestEnum", {plugin.name.replace("-", "_").upper(): plugin.name for plugin in mock_plugins}, type=str + ) + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugins.return_value = mock_plugins + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_types(TestEnum) + + assert len(result) == 3 + assert all(isinstance(item, TestEnum) for item in result) + mock_plugin_manager.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +def test_get_plugin_column_types_with_resource_filtering(mock_plugin_manager: MagicMock) -> None: + """Test filtering plugins by required resources.""" + plugins = [ + create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), + create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), + create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), + ] + TestEnum = Enum( + "TestEnum", {"PLUGIN_ONE": "plugin-one", "PLUGIN_TWO": "plugin-two", "PLUGIN_THREE": "plugin-three"}, type=str + ) + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugins.return_value = plugins + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert len(result) == 2 + assert TestEnum.PLUGIN_ONE in result + assert TestEnum.PLUGIN_TWO in result + + +def test_get_plugin_column_types_plugin_without_required_resources(mock_plugin_manager: MagicMock) -> None: + """Test filtering when plugin has None for required_resources.""" + plugin = create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, None) + TestEnum = Enum("TestEnum", {"PLUGIN_ONE": "plugin-one"}, type=str) + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugins.return_value = [plugin] + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert result == [] + + +def test_get_plugin_column_types_no_plugins(mock_plugin_manager: MagicMock) -> None: + """Test getting plugin column types when no plugins are available.""" + TestEnum = Enum("TestEnum", {}, type=str) + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.get_plugins.return_value = [] + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = get_plugin_column_types(TestEnum) + + assert result == [] + + +def test_get_plugin_column_types_plugin_manager_disabled() -> None: + """Test getting plugin column types when plugin_manager is None.""" + TestEnum = Enum("TestEnum", {}, type=str) + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): + result = get_plugin_column_types(TestEnum) + + assert result == [] + + +def test_inject_into_column_config_type_union_with_plugins(mock_plugin_manager: MagicMock) -> None: + """Test injecting plugins into column config type union.""" + + class BaseType: + pass + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): + mock_plugin_manager.update_type_union.return_value = str | int + + with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): + result = inject_into_column_config_type_union(BaseType) + + assert result == str | int + mock_plugin_manager.update_type_union.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) + + +def test_inject_into_column_config_type_union_plugin_manager_disabled() -> None: + """Test injecting plugins when plugin_manager is None.""" + + class BaseType: + pass + + with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): + result = inject_into_column_config_type_union(BaseType) + + assert result == BaseType From be273b0aefd924aee3026b62ab990876c211ce9b Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:58:15 -0500 Subject: [PATCH 15/27] update license headers --- src/data_designer/config/utils/plugin_helpers.py | 3 +++ src/data_designer/plugins/__init__.py | 3 +++ src/data_designer/plugins/errors.py | 3 +++ src/data_designer/plugins/manager.py | 3 +++ src/data_designer/plugins/plugin.py | 3 +++ 5 files changed, 15 insertions(+) diff --git a/src/data_designer/config/utils/plugin_helpers.py b/src/data_designer/config/utils/plugin_helpers.py index 2b22c1e2..88e43171 100644 --- a/src/data_designer/config/utils/plugin_helpers.py +++ b/src/data_designer/config/utils/plugin_helpers.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from __future__ import annotations from enum import Enum diff --git a/src/data_designer/plugins/__init__.py b/src/data_designer/plugins/__init__.py index 75e3343a..b7acb81e 100644 --- a/src/data_designer/plugins/__init__.py +++ b/src/data_designer/plugins/__init__.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from data_designer.plugins.plugin import Plugin, PluginType __all__ = ["Plugin", "PluginType"] diff --git a/src/data_designer/plugins/errors.py b/src/data_designer/plugins/errors.py index a57f7565..de6e4435 100644 --- a/src/data_designer/plugins/errors.py +++ b/src/data_designer/plugins/errors.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from data_designer.errors import DataDesignerError diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index ded6d1c1..a85c72e9 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from importlib.metadata import entry_points import logging import os diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index 942b5bb3..680ae966 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum from typing import Literal, Type, get_origin From 808fd0cfd8deaf4ebddc0bc0620eaa2470f9e9e8 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 20:59:41 -0500 Subject: [PATCH 16/27] add emoji --- src/data_designer/plugins/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index a85c72e9..aafaac82 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -60,7 +60,7 @@ def discover(self) -> Self: f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" ) except Exception as e: - logger.warning(f"Failed to load plugin from entry point '{ep.name}': {e}") + logger.warning(f"🛑 Failed to load plugin from entry point '{ep.name}': {e}") return self From c98750974f7899a88c99ab55de9950aa95ac20eb Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Sun, 9 Nov 2025 22:05:01 -0500 Subject: [PATCH 17/27] not using the pm in the builder code --- src/data_designer/config/config_builder.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index ad7b4e13..9f1eee86 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -47,19 +47,10 @@ from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE from .utils.info import ConfigBuilderInfo from .utils.io_helpers import serialize_data, smart_load_yaml -from .utils.misc import ( - can_run_data_designer_locally, - json_indent_list_of_strings, - kebab_to_snake, -) +from .utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake from .utils.type_helpers import resolve_string_enum from .utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config -if can_run_data_designer_locally(): - from data_designer.plugins.manager import PluginManager - - plugin_manager = PluginManager() - logger = logging.getLogger(__name__) From dcc6ee84cfae5dd3cc4bc5997d392d09ed8bfb3f Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 15:43:26 -0500 Subject: [PATCH 18/27] Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati --- src/data_designer/plugins/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index aafaac82..41d1bcb2 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -78,7 +78,7 @@ class _PluginRegistry: def get(self, plugin_name: str) -> Plugin: if plugin_name not in self._plugins: - raise PluginNotFoundError(f"Plugin '{plugin_name}' not found.") + raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.") return self._plugins[plugin_name] def register_plugin(self, plugin: Plugin) -> None: From 43c8f5d56368b02880d194205128bbe0aeabaabd Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 15:43:35 -0500 Subject: [PATCH 19/27] Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati --- src/data_designer/plugins/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 41d1bcb2..6c33b50e 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -60,7 +60,7 @@ def discover(self) -> Self: f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" ) except Exception as e: - logger.warning(f"🛑 Failed to load plugin from entry point '{ep.name}': {e}") + logger.warning(f"🛑 Failed to load plugin from entry point {ep.name!r}: {e}") return self From 34084adfa0f7ee93055e691a0df5ce5c39c51070 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 16:01:12 -0500 Subject: [PATCH 20/27] Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati --- src/data_designer/plugins/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 6c33b50e..414125d8 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -83,7 +83,7 @@ def get(self, plugin_name: str) -> Plugin: def register_plugin(self, plugin: Plugin) -> None: if plugin.name in self._plugins: - raise PluginRegistrationError(f"Plugin '{plugin.name}' already registered.") + raise PluginRegistrationError(f"Plugin {plugin.name!r} already registered.") self._plugins[plugin.name] = plugin def clear(self) -> None: From ba06651226f39b12ff067c0c2cf7c1347b05a7b2 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 16:26:26 -0500 Subject: [PATCH 21/27] merge plugin registry into the manager --- .../config/utils/plugin_helpers.py | 2 +- src/data_designer/plugins/manager.py | 58 +++++++++-------- tests/config/utils/test_plugin_helpers.py | 4 +- tests/plugins/test_manager.py | 63 ++----------------- 4 files changed, 40 insertions(+), 87 deletions(-) diff --git a/src/data_designer/config/utils/plugin_helpers.py b/src/data_designer/config/utils/plugin_helpers.py index 88e43171..d14bb708 100644 --- a/src/data_designer/config/utils/plugin_helpers.py +++ b/src/data_designer/config/utils/plugin_helpers.py @@ -81,5 +81,5 @@ def inject_into_column_config_type_union(column_config_type: Type[TypeAlias]) -> The column config type with plugins injected. """ if plugin_manager: - column_config_type = plugin_manager.update_type_union(column_config_type, PluginType.COLUMN_GENERATOR) + column_config_type = plugin_manager.add_plugin_types(column_config_type, PluginType.COLUMN_GENERATOR) return column_config_type diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 414125d8..10debe37 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from copy import deepcopy from importlib.metadata import entry_points import logging import os @@ -23,17 +24,42 @@ class PluginManager: _plugins_discovered = False _lock = threading.Lock() + _plugins: dict[str, Plugin] = {} + def __init__(self): - self.registry = _PluginRegistry() if not self._plugins_discovered: self.discover() self._plugins_discovered = True + @classmethod + def reset(cls) -> None: + cls._instance = None + cls._plugins_discovered = False + cls._plugins = {} + + def add_plugin(self, plugin: Plugin) -> None: + if plugin.name in self._plugins: + raise PluginRegistrationError(f"Plugin {plugin.name!r} already added.") + self._plugins[plugin.name] = plugin + + def add_plugin_types(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: + for plugin in self.get_plugins(plugin_type): + type_union |= plugin.config_cls + return type_union + + def clear_plugins(self) -> None: + self._plugins.clear() + + def copy_plugins(self) -> dict[str, Plugin]: + return deepcopy(self._plugins) + def get_plugin(self, plugin_name: str) -> Plugin: - return self.registry.get(plugin_name) + if plugin_name not in self._plugins: + raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.") + return self._plugins[plugin_name] def get_plugins(self, plugin_type: PluginType) -> list[Plugin]: - return [plugin for plugin in self.registry._plugins.values() if plugin.plugin_type == plugin_type] + return [plugin for plugin in self._plugins.values() if plugin.plugin_type == plugin_type] def get_plugin_names(self, plugin_type: PluginType) -> list[str]: return [plugin.name for plugin in self.get_plugins(plugin_type)] @@ -41,10 +67,8 @@ def get_plugin_names(self, plugin_type: PluginType) -> list[str]: def num_plugins(self, plugin_type: PluginType) -> int: return len(self.get_plugins(plugin_type)) - def update_type_union(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: - for plugin in self.get_plugins(plugin_type): - type_union |= plugin.config_cls - return type_union + def set_plugins(self, plugins: dict[str, Plugin]) -> None: + self._plugins = plugins def discover(self) -> Self: if PLUGINS_DISABLED: @@ -54,7 +78,7 @@ def discover(self) -> Self: plugin = ep.load() if isinstance(plugin, Plugin): with self._lock: - self.registry.register_plugin(plugin) + self.add_plugin(plugin) logger.info( f"🔌 Plugin discovered ➜ {plugin.plugin_type.value.replace('-', ' ')} " f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" @@ -71,21 +95,3 @@ def __new__(cls, *args, **kwargs): if not cls._instance: cls._instance = super().__new__(cls) return cls._instance - - -class _PluginRegistry: - _plugins: dict[str, Plugin] = {} - - def get(self, plugin_name: str) -> Plugin: - if plugin_name not in self._plugins: - raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.") - return self._plugins[plugin_name] - - def register_plugin(self, plugin: Plugin) -> None: - if plugin.name in self._plugins: - raise PluginRegistrationError(f"Plugin {plugin.name!r} already registered.") - self._plugins[plugin.name] = plugin - - def clear(self) -> None: - """Clear all registered plugins. Primarily for testing purposes.""" - self._plugins.clear() diff --git a/tests/config/utils/test_plugin_helpers.py b/tests/config/utils/test_plugin_helpers.py index 4111d6e8..ab4d760f 100644 --- a/tests/config/utils/test_plugin_helpers.py +++ b/tests/config/utils/test_plugin_helpers.py @@ -214,13 +214,13 @@ class BaseType: pass with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.update_type_union.return_value = str | int + mock_plugin_manager.add_plugin_types.return_value = str | int with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): result = inject_into_column_config_type_union(BaseType) assert result == str | int - mock_plugin_manager.update_type_union.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) + mock_plugin_manager.add_plugin_types.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) def test_inject_into_column_config_type_union_plugin_manager_disabled() -> None: diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 94c58a2e..e5bd8eb9 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -12,8 +12,8 @@ from data_designer.config.base import ConfigBase from data_designer.config.column_configs import SingleColumnConfig from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata -from data_designer.plugins.errors import PluginNotFoundError, PluginRegistrationError -from data_designer.plugins.manager import PluginManager, _PluginRegistry +from data_designer.plugins.errors import PluginNotFoundError +from data_designer.plugins.manager import PluginManager from data_designer.plugins.plugin import Plugin, PluginType # ============================================================================= @@ -75,19 +75,11 @@ def plugin_b() -> Plugin: @pytest.fixture(autouse=True) def clean_plugin_manager() -> None: """Reset PluginManager singleton state before and after each test.""" - original_instance = PluginManager._instance - original_discovered = PluginManager._plugins_discovered - original_plugins = _PluginRegistry._plugins.copy() - - PluginManager._instance = None - PluginManager._plugins_discovered = False - _PluginRegistry._plugins = {} + PluginManager.reset() yield - PluginManager._instance = original_instance - PluginManager._plugins_discovered = original_discovered - _PluginRegistry._plugins = original_plugins + PluginManager.reset() @pytest.fixture @@ -117,51 +109,6 @@ def mock_entry_points(plugin_a: Plugin, plugin_b: Plugin) -> list[MagicMock]: return [mock_ep_a, mock_ep_b] -# ============================================================================= -# _PluginRegistry Tests -# ============================================================================= - - -def test_plugin_registry_register_and_get(plugin_a: Plugin) -> None: - """Test plugin registration and retrieval.""" - registry = _PluginRegistry() - - registry.register_plugin(plugin_a) - - assert registry.get("test-plugin-a") == plugin_a - - -def test_plugin_registry_duplicate_raises_error(plugin_a: Plugin) -> None: - """Test duplicate registration raises PluginRegistrationError.""" - registry = _PluginRegistry() - registry.register_plugin(plugin_a) - - with pytest.raises(PluginRegistrationError, match="Plugin 'test-plugin-a' already registered"): - registry.register_plugin(plugin_a) - - -def test_plugin_registry_get_nonexistent_raises_error() -> None: - """Test nonexistent plugin raises PluginNotFoundError.""" - registry = _PluginRegistry() - - with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): - registry.get("nonexistent") - - -def test_plugin_registry_clear(plugin_a: Plugin, plugin_b: Plugin) -> None: - """Test clear() removes all plugins.""" - registry = _PluginRegistry() - registry.register_plugin(plugin_a) - registry.register_plugin(plugin_b) - - registry.clear() - - with pytest.raises(PluginNotFoundError): - registry.get("test-plugin-a") - with pytest.raises(PluginNotFoundError): - registry.get("test-plugin-b") - - # ============================================================================= # PluginManager Singleton Tests # ============================================================================= @@ -326,7 +273,7 @@ def test_plugin_manager_update_type_union(mock_plugin_discovery, mock_entry_poin manager = PluginManager() type_union: type = ConfigBase - updated_union = manager.update_type_union(type_union, PluginType.COLUMN_GENERATOR) + updated_union = manager.add_plugin_types(type_union, PluginType.COLUMN_GENERATOR) assert StubPluginConfigA in updated_union.__args__ assert StubPluginConfigB in updated_union.__args__ From 31a1d9b40487b70a0ac131396365e7596e29992d Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 20:46:52 -0500 Subject: [PATCH 22/27] small pr feedback --- src/data_designer/plugins/manager.py | 2 +- src/data_designer/plugins/plugin.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/manager.py index 10debe37..3a734f56 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/manager.py @@ -81,7 +81,7 @@ def discover(self) -> Self: self.add_plugin(plugin) logger.info( f"🔌 Plugin discovered ➜ {plugin.plugin_type.value.replace('-', ' ')} " - f"{plugin.name.upper().replace('-', '_')} is now available ⚡️" + f"{plugin.enum_key_name} is now available ⚡️" ) except Exception as e: logger.warning(f"🛑 Failed to load plugin from entry point {ep.name!r}: {e}") diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index 680ae966..aa831846 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -19,7 +19,7 @@ def discriminator_field(self) -> str: if self == PluginType.COLUMN_GENERATOR: return "column_type" else: - raise ValueError(f"Invalid plugin type: {self}") + raise ValueError(f"Invalid plugin type: {self.value}") class Plugin(BaseModel): @@ -30,10 +30,10 @@ class Plugin(BaseModel): @property def config_type_as_class_name(self) -> str: - return self.enum_key.title().replace("_", "") + return self.enum_key_name.title().replace("_", "") @property - def enum_key(self) -> str: + def enum_key_name(self) -> str: return self.name.replace("-", "_").upper() @property From cd4183b2b07661e3cde0e585a1a4543fd8f9ef08 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 21:21:57 -0500 Subject: [PATCH 23/27] client side plugin manager --- .../config/analysis/column_statistics.py | 4 +- src/data_designer/config/column_types.py | 16 +- .../config/utils/plugin_helpers.py | 85 ------- .../engine/registry/data_designer_registry.py | 5 +- src/data_designer/plugin_manager.py | 88 +++++++ .../plugins/{manager.py => registry.py} | 2 +- tests/config/utils/test_plugin_helpers.py | 235 ------------------ ...est_manager.py => test_plugin_registry.py} | 100 ++++---- tests/test_plugin_manager.py | 231 +++++++++++++++++ 9 files changed, 384 insertions(+), 382 deletions(-) delete mode 100644 src/data_designer/config/utils/plugin_helpers.py create mode 100644 src/data_designer/plugin_manager.py rename src/data_designer/plugins/{manager.py => registry.py} (99%) delete mode 100644 tests/config/utils/test_plugin_helpers.py rename tests/plugins/{test_manager.py => test_plugin_registry.py} (71%) create mode 100644 tests/test_plugin_manager.py diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 14574a4b..bf00234c 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -11,9 +11,9 @@ from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator from typing_extensions import Self, TypeAlias +from ...plugin_manager import PluginManager from ..column_types import DataDesignerColumnType from ..sampler_params import SamplerType -from ..utils import plugin_helpers from ..utils.constants import EPSILON from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting @@ -263,7 +263,7 @@ def from_series(cls, series: Series) -> Self: DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, } -for plugin in plugin_helpers.get_plugin_column_configs(): +for plugin in PluginManager().get_plugin_column_configs(): # Dynamically create a statistics class for this plugin using Pydantic's create_model plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics" diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index b8bc26f8..ab1b3d52 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -5,6 +5,7 @@ from typing_extensions import TypeAlias +from ..plugin_manager import PluginManager from .column_configs import ( ExpressionColumnConfig, LLMCodeColumnConfig, @@ -17,9 +18,10 @@ ) from .errors import InvalidColumnTypeError, InvalidConfigError from .sampler_params import SamplerType -from .utils import plugin_helpers from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum +plugin_manager = PluginManager() + ColumnConfigT: TypeAlias = Union[ ExpressionColumnConfig, LLMCodeColumnConfig, @@ -30,7 +32,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ] -ColumnConfigT = plugin_helpers.inject_into_column_config_type_union(ColumnConfigT) +ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) DataDesignerColumnType = create_str_enum_from_discriminated_type_union( enum_name="DataDesignerColumnType", @@ -50,7 +52,7 @@ DataDesignerColumnType.VALIDATION: "🔍", } COLUMN_TYPE_EMOJI_MAP.update( - {DataDesignerColumnType(p.name): p.emoji for p in plugin_helpers.get_plugin_column_configs()} + {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_plugin_column_configs()} ) @@ -65,7 +67,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, } - dag_column_types.update(plugin_helpers.get_plugin_column_types(DataDesignerColumnType)) + dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -79,7 +81,7 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType] DataDesignerColumnType.LLM_JUDGE, } llm_generated_column_types.update( - plugin_helpers.get_plugin_column_types( + plugin_manager.get_plugin_column_types( DataDesignerColumnType, required_resources=["model_registry"], ) @@ -115,7 +117,7 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) - if plugin := plugin_helpers.get_plugin_column_config_if_available(column_type.value): + if plugin := plugin_manager.get_plugin_column_config_if_available(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover @@ -132,7 +134,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] - display_order.extend(plugin_helpers.get_plugin_column_types(DataDesignerColumnType)) + display_order.extend(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return display_order diff --git a/src/data_designer/config/utils/plugin_helpers.py b/src/data_designer/config/utils/plugin_helpers.py deleted file mode 100644 index d14bb708..00000000 --- a/src/data_designer/config/utils/plugin_helpers.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from enum import Enum -from typing import TYPE_CHECKING, Type, TypeAlias - -from .misc import can_run_data_designer_locally - -if TYPE_CHECKING: - from data_designer.plugins.manager import PluginManager - from data_designer.plugins.plugin import Plugin - - -plugin_manager = None -if can_run_data_designer_locally(): - from data_designer.plugins.manager import PluginManager, PluginType - - plugin_manager = PluginManager() - - -def get_plugin_column_configs() -> list[Plugin]: - """Get all plugin column configs. - - Returns: - A list of all plugin column configs. - """ - if plugin_manager: - return [ - plugin_manager.get_plugin(plugin_name) - for plugin_name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR) - ] - return [] - - -def get_plugin_column_config_if_available(plugin_name: str) -> Plugin | None: - """Get a plugin column config by name if available. - - Args: - plugin_name: The name of the plugin to retrieve. - - Returns: - The plugin if found, otherwise None. - """ - if plugin_manager: - for name in plugin_manager.get_plugin_names(PluginType.COLUMN_GENERATOR): - if plugin_name == name: - return plugin_manager.get_plugin(plugin_name) - return None - - -def get_plugin_column_types(enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: - """Get a list of plugin column types. - - Args: - enum_type: The enum type to use for plugin entries. - required_resources: If provided, only return plugins with the required resources. - - Returns: - A list of plugin column types. - """ - type_list = [] - if plugin_manager: - for plugin in plugin_manager.get_plugins(PluginType.COLUMN_GENERATOR): - if required_resources: - task_required_resources = plugin.task_cls.metadata().required_resources or [] - if not all(resource in task_required_resources for resource in required_resources): - continue - type_list.append(enum_type(plugin.name)) - return type_list - - -def inject_into_column_config_type_union(column_config_type: Type[TypeAlias]) -> Type[TypeAlias]: - """Inject plugins into the column config type. - - Args: - column_config_type: The column config type to inject plugins into. - - Returns: - The column config type with plugins injected. - """ - if plugin_manager: - column_config_type = plugin_manager.add_plugin_types(column_config_type, PluginType.COLUMN_GENERATOR) - return column_config_type diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 6b1ae1c3..3ecb3e2c 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -11,7 +11,8 @@ create_builtin_column_generator_registry, ) from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_builtin_processor_registry -from data_designer.plugins.manager import PluginManager, PluginType +from data_designer.plugins.plugin import PluginType +from data_designer.plugins.registry import PluginRegistry class DataDesignerRegistry: @@ -26,7 +27,7 @@ def __init__( self._column_profiler_registry = column_profiler_registry or create_builtin_column_profiler_registry() self._processor_registry = processor_registry or create_builtin_processor_registry() - for plugin in PluginManager().get_plugins(PluginType.COLUMN_GENERATOR): + for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): self._column_generator_registry.register( DataDesignerColumnType(plugin.name), plugin.task_cls, diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py new file mode 100644 index 00000000..79bd5eee --- /dev/null +++ b/src/data_designer/plugin_manager.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Type, TypeAlias + +from .config.utils.misc import can_run_data_designer_locally + +if TYPE_CHECKING: + from data_designer.plugins.plugin import Plugin + + +if can_run_data_designer_locally(): + from data_designer.plugins.plugin import PluginType + from data_designer.plugins.registry import PluginRegistry + + +class PluginManager: + def __init__(self): + if can_run_data_designer_locally(): + self._plugins_available = True + self._plugin_registry = PluginRegistry() + else: + self._plugins_available = False + self._plugin_registry = None + + def get_plugin_column_configs(self) -> list[Plugin]: + """Get all plugin column configs. + + Returns: + A list of all plugin column configs. + """ + if self._plugins_available: + return [ + self._plugin_registry.get_plugin(plugin_name) + for plugin_name in self._plugin_registry.get_plugin_names(PluginType.COLUMN_GENERATOR) + ] + return [] + + def get_plugin_column_config_if_available(self, plugin_name: str) -> Plugin | None: + """Get a plugin column config by name if available. + + Args: + plugin_name: The name of the plugin to retrieve. + + Returns: + The plugin if found, otherwise None. + """ + if self._plugins_available: + for name in self._plugin_registry.get_plugin_names(PluginType.COLUMN_GENERATOR): + if plugin_name == name: + return self._plugin_registry.get_plugin(plugin_name) + return None + + def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: + """Get a list of plugin column types. + + Args: + enum_type: The enum type to use for plugin entries. + required_resources: If provided, only return plugins with the required resources. + + Returns: + A list of plugin column types. + """ + type_list = [] + if self._plugins_available: + for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR): + if required_resources: + task_required_resources = plugin.task_cls.metadata().required_resources or [] + if not all(resource in task_required_resources for resource in required_resources): + continue + type_list.append(enum_type(plugin.name)) + return type_list + + def inject_into_column_config_type_union(self, column_config_type: Type[TypeAlias]) -> Type[TypeAlias]: + """Inject plugins into the column config type. + + Args: + column_config_type: The column config type to inject plugins into. + + Returns: + The column config type with plugins injected. + """ + if self._plugins_available: + column_config_type = self._plugin_registry.add_plugin_types(column_config_type, PluginType.COLUMN_GENERATOR) + return column_config_type diff --git a/src/data_designer/plugins/manager.py b/src/data_designer/plugins/registry.py similarity index 99% rename from src/data_designer/plugins/manager.py rename to src/data_designer/plugins/registry.py index 3a734f56..a0fd86bc 100644 --- a/src/data_designer/plugins/manager.py +++ b/src/data_designer/plugins/registry.py @@ -19,7 +19,7 @@ PLUGINS_DISABLED = os.getenv("DISABLE_DATA_DESIGNER_PLUGINS", "false").lower() == "true" -class PluginManager: +class PluginRegistry: _instance = None _plugins_discovered = False _lock = threading.Lock() diff --git a/tests/config/utils/test_plugin_helpers.py b/tests/config/utils/test_plugin_helpers.py deleted file mode 100644 index ab4d760f..00000000 --- a/tests/config/utils/test_plugin_helpers.py +++ /dev/null @@ -1,235 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from enum import Enum -from unittest.mock import MagicMock, Mock, patch - -import pytest - -from data_designer.config.utils.plugin_helpers import ( - get_plugin_column_config_if_available, - get_plugin_column_configs, - get_plugin_column_types, - inject_into_column_config_type_union, -) - - -class MockPluginType(str, Enum): - """Mock PluginType enum for testing.""" - - COLUMN_GENERATOR = "column-generator" - - @property - def discriminator_field(self) -> str: - return "column_type" - - -def create_mock_plugin(name: str, plugin_type: MockPluginType, resources: list[str] | None = None) -> Mock: - """Create a mock plugin with specified name and resources. - - Args: - name: Plugin name. - plugin_type: Plugin type enum. - resources: List of required resources, or None if no resource requirements. - - Returns: - Mock plugin object. - """ - plugin = Mock() - plugin.name = name - plugin.plugin_type = plugin_type - plugin.config_cls = Mock(name=name) - - mock_task = Mock() - mock_task.metadata = Mock(return_value=Mock(required_resources=resources)) - plugin.task_cls = mock_task - - return plugin - - -@pytest.fixture -def mock_plugin_manager() -> MagicMock: - """Create a mock plugin manager.""" - return MagicMock() - - -@pytest.fixture -def mock_plugins() -> list[Mock]: - """Create mock plugins for testing.""" - return [ - create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), - create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), - create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), - ] - - -def test_get_plugin_column_configs_with_plugins(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: - """Test getting plugin column configs when plugins are available.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] - mock_plugin_manager.get_plugin.side_effect = lambda name: next(p for p in mock_plugins if p.name == name) - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_configs() - - assert len(result) == 2 - assert result[0].name == "plugin-one" - assert result[1].name == "plugin-two" - mock_plugin_manager.get_plugin_names.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) - - -def test_get_plugin_column_configs_no_plugins(mock_plugin_manager: MagicMock) -> None: - """Test getting plugin column configs when no plugins are available.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugin_names.return_value = [] - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_configs() - - assert result == [] - - -def test_get_plugin_column_configs_plugin_manager_disabled() -> None: - """Test getting plugin column configs when plugin_manager is None.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): - result = get_plugin_column_configs() - - assert result == [] - - -def test_get_plugin_column_config_if_available_found(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: - """Test getting a specific plugin by name when it exists.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] - mock_plugin_manager.get_plugin.return_value = mock_plugins[0] - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_config_if_available("plugin-one") - - assert result is not None - assert result.name == "plugin-one" - mock_plugin_manager.get_plugin.assert_called_once_with("plugin-one") - - -def test_get_plugin_column_config_if_available_not_found(mock_plugin_manager: MagicMock) -> None: - """Test getting a specific plugin by name when it doesn't exist.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugin_names.return_value = ["plugin-one", "plugin-two"] - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_config_if_available("plugin-three") - - assert result is None - mock_plugin_manager.get_plugin.assert_not_called() - - -def test_get_plugin_column_config_if_available_plugin_manager_disabled() -> None: - """Test getting a specific plugin when plugin_manager is None.""" - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): - result = get_plugin_column_config_if_available("plugin-one") - - assert result is None - - -def test_get_plugin_column_types_with_plugins(mock_plugin_manager: MagicMock, mock_plugins: list[Mock]) -> None: - """Test getting plugin column types when plugins are available.""" - TestEnum = Enum( - "TestEnum", {plugin.name.replace("-", "_").upper(): plugin.name for plugin in mock_plugins}, type=str - ) - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugins.return_value = mock_plugins - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_types(TestEnum) - - assert len(result) == 3 - assert all(isinstance(item, TestEnum) for item in result) - mock_plugin_manager.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) - - -def test_get_plugin_column_types_with_resource_filtering(mock_plugin_manager: MagicMock) -> None: - """Test filtering plugins by required resources.""" - plugins = [ - create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), - create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), - create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), - ] - TestEnum = Enum( - "TestEnum", {"PLUGIN_ONE": "plugin-one", "PLUGIN_TWO": "plugin-two", "PLUGIN_THREE": "plugin-three"}, type=str - ) - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugins.return_value = plugins - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_types(TestEnum, required_resources=["resource1"]) - - assert len(result) == 2 - assert TestEnum.PLUGIN_ONE in result - assert TestEnum.PLUGIN_TWO in result - - -def test_get_plugin_column_types_plugin_without_required_resources(mock_plugin_manager: MagicMock) -> None: - """Test filtering when plugin has None for required_resources.""" - plugin = create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, None) - TestEnum = Enum("TestEnum", {"PLUGIN_ONE": "plugin-one"}, type=str) - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugins.return_value = [plugin] - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_types(TestEnum, required_resources=["resource1"]) - - assert result == [] - - -def test_get_plugin_column_types_no_plugins(mock_plugin_manager: MagicMock) -> None: - """Test getting plugin column types when no plugins are available.""" - TestEnum = Enum("TestEnum", {}, type=str) - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.get_plugins.return_value = [] - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = get_plugin_column_types(TestEnum) - - assert result == [] - - -def test_get_plugin_column_types_plugin_manager_disabled() -> None: - """Test getting plugin column types when plugin_manager is None.""" - TestEnum = Enum("TestEnum", {}, type=str) - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): - result = get_plugin_column_types(TestEnum) - - assert result == [] - - -def test_inject_into_column_config_type_union_with_plugins(mock_plugin_manager: MagicMock) -> None: - """Test injecting plugins into column config type union.""" - - class BaseType: - pass - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", mock_plugin_manager): - mock_plugin_manager.add_plugin_types.return_value = str | int - - with patch("data_designer.config.utils.plugin_helpers.PluginType", MockPluginType): - result = inject_into_column_config_type_union(BaseType) - - assert result == str | int - mock_plugin_manager.add_plugin_types.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) - - -def test_inject_into_column_config_type_union_plugin_manager_disabled() -> None: - """Test injecting plugins when plugin_manager is None.""" - - class BaseType: - pass - - with patch("data_designer.config.utils.plugin_helpers.plugin_manager", None): - result = inject_into_column_config_type_union(BaseType) - - assert result == BaseType diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_plugin_registry.py similarity index 71% rename from tests/plugins/test_manager.py rename to tests/plugins/test_plugin_registry.py index e5bd8eb9..4947601d 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_plugin_registry.py @@ -13,8 +13,8 @@ from data_designer.config.column_configs import SingleColumnConfig from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata from data_designer.plugins.errors import PluginNotFoundError -from data_designer.plugins.manager import PluginManager from data_designer.plugins.plugin import Plugin, PluginType +from data_designer.plugins.registry import PluginRegistry # ============================================================================= # Test Stubs @@ -73,13 +73,13 @@ def plugin_b() -> Plugin: @pytest.fixture(autouse=True) -def clean_plugin_manager() -> None: - """Reset PluginManager singleton state before and after each test.""" - PluginManager.reset() +def clean_plugin_registry() -> None: + """Reset PluginRegistry singleton state before and after each test.""" + PluginRegistry.reset() yield - PluginManager.reset() + PluginRegistry.reset() @pytest.fixture @@ -88,8 +88,8 @@ def mock_plugin_discovery(): @contextmanager def _mock_discovery(entry_points_list): - with patch("data_designer.plugins.manager.PLUGINS_DISABLED", False): - with patch("data_designer.plugins.manager.entry_points", return_value=entry_points_list): + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.registry.entry_points", return_value=entry_points_list): yield return _mock_discovery @@ -110,27 +110,27 @@ def mock_entry_points(plugin_a: Plugin, plugin_b: Plugin) -> list[MagicMock]: # ============================================================================= -# PluginManager Singleton Tests +# PluginRegistry Singleton Tests # ============================================================================= -def test_plugin_manager_is_singleton(mock_plugin_discovery) -> None: - """Test PluginManager returns same instance.""" +def test_plugin_registry_is_singleton(mock_plugin_discovery) -> None: + """Test PluginRegistry returns same instance.""" with mock_plugin_discovery([]): - manager1 = PluginManager() - manager2 = PluginManager() + manager1 = PluginRegistry() + manager2 = PluginRegistry() assert manager1 is manager2 -def test_plugin_manager_singleton_thread_safety(mock_plugin_discovery) -> None: - """Test PluginManager singleton creation is thread-safe.""" - instances: list[PluginManager] = [] +def test_plugin_registry_singleton_thread_safety(mock_plugin_discovery) -> None: + """Test PluginRegistry singleton creation is thread-safe.""" + instances: list[PluginRegistry] = [] with mock_plugin_discovery([]): def create_manager() -> None: - instances.append(PluginManager()) + instances.append(PluginRegistry()) threads = [threading.Thread(target=create_manager) for _ in range(10)] for thread in threads: @@ -142,24 +142,24 @@ def create_manager() -> None: # ============================================================================= -# PluginManager Discovery Tests +# PluginRegistry Discovery Tests # ============================================================================= -def test_plugin_manager_discovers_plugins( +def test_plugin_registry_discovers_plugins( mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin ) -> None: - """Test PluginManager discovers and loads plugins from entry points.""" + """Test PluginRegistry discovers and loads plugins from entry points.""" with mock_plugin_discovery(mock_entry_points): - manager = PluginManager() + manager = PluginRegistry() assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2 assert manager.get_plugin("test-plugin-a") == plugin_a assert manager.get_plugin("test-plugin-b") == plugin_b -def test_plugin_manager_skips_invalid_plugins(mock_plugin_discovery, plugin_a: Plugin) -> None: - """Test PluginManager skips non-Plugin objects during discovery.""" +def test_plugin_registry_skips_invalid_plugins(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginRegistry skips non-Plugin objects during discovery.""" mock_ep_valid = MagicMock(spec=EntryPoint) mock_ep_valid.name = "test-plugin-a" mock_ep_valid.load.return_value = plugin_a @@ -169,14 +169,14 @@ def test_plugin_manager_skips_invalid_plugins(mock_plugin_discovery, plugin_a: P mock_ep_invalid.load.return_value = "not a plugin" with mock_plugin_discovery([mock_ep_valid, mock_ep_invalid]): - manager = PluginManager() + manager = PluginRegistry() assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 assert manager.get_plugin("test-plugin-a") == plugin_a -def test_plugin_manager_handles_loading_errors(mock_plugin_discovery, plugin_a: Plugin) -> None: - """Test PluginManager gracefully handles plugin loading errors.""" +def test_plugin_registry_handles_loading_errors(mock_plugin_discovery, plugin_a: Plugin) -> None: + """Test PluginRegistry gracefully handles plugin loading errors.""" mock_ep_valid = MagicMock(spec=EntryPoint) mock_ep_valid.name = "test-plugin-a" mock_ep_valid.load.return_value = plugin_a @@ -186,57 +186,57 @@ def test_plugin_manager_handles_loading_errors(mock_plugin_discovery, plugin_a: mock_ep_error.load.side_effect = Exception("Loading failed") with mock_plugin_discovery([mock_ep_valid, mock_ep_error]): - manager = PluginManager() + manager = PluginRegistry() assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1 assert manager.get_plugin("test-plugin-a") == plugin_a -def test_plugin_manager_discovery_runs_once() -> None: - """Test discovery runs once even with multiple PluginManager instances.""" +def test_plugin_registry_discovery_runs_once() -> None: + """Test discovery runs once even with multiple PluginRegistry instances.""" mock_entry_points = MagicMock(return_value=[]) - with patch("data_designer.plugins.manager.PLUGINS_DISABLED", False): - with patch("data_designer.plugins.manager.entry_points", mock_entry_points): - PluginManager() - PluginManager() - PluginManager() + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False): + with patch("data_designer.plugins.registry.entry_points", mock_entry_points): + PluginRegistry() + PluginRegistry() + PluginRegistry() assert mock_entry_points.call_count == 1 -def test_plugin_manager_respects_disabled_flag() -> None: - """Test PluginManager respects DISABLE_DATA_DESIGNER_PLUGINS flag.""" +def test_plugin_registry_respects_disabled_flag() -> None: + """Test PluginRegistry respects DISABLE_DATA_DESIGNER_PLUGINS flag.""" mock_entry_points = MagicMock(return_value=[]) - with patch("data_designer.plugins.manager.PLUGINS_DISABLED", True): - with patch("data_designer.plugins.manager.entry_points", mock_entry_points): - manager = PluginManager() + with patch("data_designer.plugins.registry.PLUGINS_DISABLED", True): + with patch("data_designer.plugins.registry.entry_points", mock_entry_points): + manager = PluginRegistry() assert mock_entry_points.call_count == 0 assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0 # ============================================================================= -# PluginManager Query Methods Tests +# PluginRegistry Query Methods Tests # ============================================================================= -def test_plugin_manager_get_plugin_raises_error(mock_plugin_discovery) -> None: +def test_plugin_registry_get_plugin_raises_error(mock_plugin_discovery) -> None: """Test get_plugin() raises error for nonexistent plugin.""" with mock_plugin_discovery([]): - manager = PluginManager() + manager = PluginRegistry() with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"): manager.get_plugin("nonexistent") -def test_plugin_manager_get_plugins_by_type( +def test_plugin_registry_get_plugins_by_type( mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin ) -> None: """Test get_plugins() filters by plugin type.""" with mock_plugin_discovery(mock_entry_points): - manager = PluginManager() + manager = PluginRegistry() plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) assert len(plugins) == 2 @@ -244,33 +244,33 @@ def test_plugin_manager_get_plugins_by_type( assert plugin_b in plugins -def test_plugin_manager_get_plugins_empty(mock_plugin_discovery) -> None: +def test_plugin_registry_get_plugins_empty(mock_plugin_discovery) -> None: """Test get_plugins() returns empty list when no plugins match.""" with mock_plugin_discovery([]): - manager = PluginManager() + manager = PluginRegistry() plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR) assert plugins == [] -def test_plugin_manager_get_plugin_names(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: +def test_plugin_registry_get_plugin_names(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: """Test get_plugin_names() returns plugin names by type.""" with mock_plugin_discovery(mock_entry_points): - manager = PluginManager() + manager = PluginRegistry() names = manager.get_plugin_names(PluginType.COLUMN_GENERATOR) assert set(names) == {"test-plugin-a", "test-plugin-b"} # ============================================================================= -# PluginManager Type Union Tests +# PluginRegistry Type Union Tests # ============================================================================= -def test_plugin_manager_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: +def test_plugin_registry_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: """Test update_type_union() adds plugin config types to union.""" with mock_plugin_discovery(mock_entry_points): - manager = PluginManager() + manager = PluginRegistry() type_union: type = ConfigBase updated_union = manager.add_plugin_types(type_union, PluginType.COLUMN_GENERATOR) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py new file mode 100644 index 00000000..4e5e1b0e --- /dev/null +++ b/tests/test_plugin_manager.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Generator +from contextlib import contextmanager +from enum import Enum +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from data_designer.plugin_manager import PluginManager + + +class MockPluginType(str, Enum): + """Mock PluginType enum for testing.""" + + COLUMN_GENERATOR = "column-generator" + + @property + def discriminator_field(self) -> str: + return "column_type" + + +def create_mock_plugin(name: str, plugin_type: MockPluginType, resources: list[str] | None = None) -> Mock: + """Create a mock plugin with specified name and resources. + + Args: + name: Plugin name. + plugin_type: Plugin type enum. + resources: List of required resources, or None if no resource requirements. + + Returns: + Mock plugin object. + """ + plugin = Mock() + plugin.name = name + plugin.plugin_type = plugin_type + plugin.config_cls = Mock(name=name) + + mock_task = Mock() + mock_task.metadata = Mock(return_value=Mock(required_resources=resources)) + plugin.task_cls = mock_task + + return plugin + + +@contextmanager +def mock_plugin_system(registry: MagicMock) -> Generator[None, None, None]: + """Context manager to mock the plugin system with a given registry. + + This works regardless of whether the actual environment has plugins available or not + by patching at the module level where PluginManager is instantiated. + """ + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=True): + with patch("data_designer.plugin_manager.PluginRegistry", return_value=registry, create=True): + with patch("data_designer.plugin_manager.PluginType", MockPluginType, create=True): + yield + + +@pytest.fixture +def mock_plugin_registry() -> MagicMock: + """Create a mock plugin registry.""" + return MagicMock() + + +@pytest.fixture +def mock_plugins() -> list[Mock]: + """Create mock plugins for testing.""" + return [ + create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, ["resource1", "resource2"]), + create_mock_plugin("plugin-two", MockPluginType.COLUMN_GENERATOR, ["resource1"]), + create_mock_plugin("plugin-three", MockPluginType.COLUMN_GENERATOR, ["resource2", "resource3"]), + ] + + +def test_get_plugin_column_configs_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column configs when plugins are available.""" + mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_registry.get_plugin.side_effect = lambda name: next(p for p in mock_plugins if p.name == name) + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_configs() + + assert len(result) == 2 + assert [p.name for p in result] == ["plugin-one", "plugin-two"] + mock_plugin_registry.get_plugin_names.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +@pytest.mark.parametrize("plugins_available", [True, False]) +def test_get_plugin_column_configs_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: + """Test getting plugin column configs when no plugins are registered or system is disabled.""" + if plugins_available: + mock_plugin_registry.get_plugin_names.return_value = [] + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_configs() + else: + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_plugin_column_configs() + + assert result == [] + + +def test_get_plugin_column_config_if_available_found(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting a specific plugin by name when it exists.""" + mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_registry.get_plugin.return_value = mock_plugins[0] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_config_if_available("plugin-one") + + assert result is not None + assert result.name == "plugin-one" + mock_plugin_registry.get_plugin.assert_called_once_with("plugin-one") + + +def test_get_plugin_column_config_if_available_not_found(mock_plugin_registry: MagicMock) -> None: + """Test getting a specific plugin by name when it doesn't exist.""" + mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_config_if_available("plugin-three") + + assert result is None + mock_plugin_registry.get_plugin.assert_not_called() + + +def test_get_plugin_column_config_if_available_when_disabled() -> None: + """Test getting a specific plugin when plugin system is disabled.""" + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_plugin_column_config_if_available("plugin-one") + + assert result is None + + +def test_get_plugin_column_types_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: + """Test getting plugin column types when plugins are available.""" + TestEnum = Enum( + "TestEnum", {plugin.name.replace("-", "_").upper(): plugin.name for plugin in mock_plugins}, type=str + ) + mock_plugin_registry.get_plugins.return_value = mock_plugins + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + + assert len(result) == 3 + assert all(isinstance(item, TestEnum) for item in result) + mock_plugin_registry.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + + +def test_get_plugin_column_types_with_resource_filtering( + mock_plugin_registry: MagicMock, mock_plugins: list[Mock] +) -> None: + """Test filtering plugins by required resources.""" + TestEnum = Enum( + "TestEnum", {"PLUGIN_ONE": "plugin-one", "PLUGIN_TWO": "plugin-two", "PLUGIN_THREE": "plugin-three"}, type=str + ) + mock_plugin_registry.get_plugins.return_value = mock_plugins + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert len(result) == 2 + assert set(result) == {TestEnum.PLUGIN_ONE, TestEnum.PLUGIN_TWO} + + +def test_get_plugin_column_types_filters_none_resources(mock_plugin_registry: MagicMock) -> None: + """Test filtering when plugin has None for required_resources.""" + plugin = create_mock_plugin("plugin-one", MockPluginType.COLUMN_GENERATOR, None) + TestEnum = Enum("TestEnum", {"PLUGIN_ONE": "plugin-one"}, type=str) + mock_plugin_registry.get_plugins.return_value = [plugin] + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum, required_resources=["resource1"]) + + assert result == [] + + +@pytest.mark.parametrize("plugins_available", [True, False]) +def test_get_plugin_column_types_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: + """Test getting plugin column types when no plugins are registered or system is disabled.""" + TestEnum = Enum("TestEnum", {}, type=str) + + if plugins_available: + mock_plugin_registry.get_plugins.return_value = [] + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + else: + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.get_plugin_column_types(TestEnum) + + assert result == [] + + +def test_inject_into_column_config_type_union_with_plugins(mock_plugin_registry: MagicMock) -> None: + """Test injecting plugins into column config type union.""" + + class BaseType: + pass + + mock_plugin_registry.add_plugin_types.return_value = str | int + + with mock_plugin_system(mock_plugin_registry): + manager = PluginManager() + result = manager.inject_into_column_config_type_union(BaseType) + + assert result == str | int + mock_plugin_registry.add_plugin_types.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) + + +def test_inject_into_column_config_type_union_when_disabled() -> None: + """Test injecting plugins when plugin system is disabled.""" + + class BaseType: + pass + + with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): + manager = PluginManager() + result = manager.inject_into_column_config_type_union(BaseType) + + assert result == BaseType From 4bccea71388e934d2ec02c86ee498f77015fd688 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 21:43:49 -0500 Subject: [PATCH 24/27] builtin -> default; move adding plugins to registry --- .../analysis/column_profilers/registry.py | 2 +- .../engine/column_generators/registry.py | 13 ++++++++++- .../engine/processing/processors/registry.py | 2 +- .../engine/registry/data_designer_registry.py | 22 +++++-------------- tests/engine/analysis/test_errors.py | 4 ++-- .../engine/column_generators/test_registry.py | 4 ++-- .../processing/processors/test_registry.py | 4 ++-- .../registry/test_data_designer_registry.py | 4 ++-- 8 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/data_designer/engine/analysis/column_profilers/registry.py b/src/data_designer/engine/analysis/column_profilers/registry.py index 7d6f06ca..ce022c68 100644 --- a/src/data_designer/engine/analysis/column_profilers/registry.py +++ b/src/data_designer/engine/analysis/column_profilers/registry.py @@ -14,7 +14,7 @@ class ColumnProfilerRegistry(TaskRegistry[ColumnProfilerType, ColumnProfiler, ConfigBase]): ... -def create_builtin_column_profiler_registry() -> ColumnProfilerRegistry: +def create_default_column_profiler_registry() -> ColumnProfilerRegistry: registry = ColumnProfilerRegistry() registry.register(ColumnProfilerType.JUDGE_SCORE, JudgeScoreProfiler, JudgeScoreProfilerConfig, False) return registry diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 502025b7..61b43753 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -27,12 +27,14 @@ SeedDatasetMultiColumnConfig, ) from data_designer.engine.registry.base import TaskRegistry +from data_designer.plugins.plugin import PluginType +from data_designer.plugins.registry import PluginRegistry class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]): ... -def create_builtin_column_generator_registry() -> ColumnGeneratorRegistry: +def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry: registry = ColumnGeneratorRegistry() registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig) registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) @@ -42,4 +44,13 @@ def create_builtin_column_generator_registry() -> ColumnGeneratorRegistry: registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) + + if with_plugins: + for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): + registry.register( + DataDesignerColumnType(plugin.name), + plugin.task_cls, + plugin.config_cls, + ) + return registry diff --git a/src/data_designer/engine/processing/processors/registry.py b/src/data_designer/engine/processing/processors/registry.py index 41201551..dadcbc33 100644 --- a/src/data_designer/engine/processing/processors/registry.py +++ b/src/data_designer/engine/processing/processors/registry.py @@ -14,7 +14,7 @@ class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ... -def create_builtin_processor_registry() -> ProcessorRegistry: +def create_default_processor_registry() -> ProcessorRegistry: registry = ProcessorRegistry() registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False) return registry diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index 3ecb3e2c..407029c3 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -1,18 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.analysis.column_profilers.registry import ( ColumnProfilerRegistry, - create_builtin_column_profiler_registry, + create_default_column_profiler_registry, ) from data_designer.engine.column_generators.registry import ( ColumnGeneratorRegistry, - create_builtin_column_generator_registry, + create_default_column_generator_registry, ) -from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_builtin_processor_registry -from data_designer.plugins.plugin import PluginType -from data_designer.plugins.registry import PluginRegistry +from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry class DataDesignerRegistry: @@ -23,16 +20,9 @@ def __init__( column_profiler_registry: ColumnProfilerRegistry | None = None, processor_registry: ProcessorRegistry | None = None, ): - self._column_generator_registry = column_generator_registry or create_builtin_column_generator_registry() - self._column_profiler_registry = column_profiler_registry or create_builtin_column_profiler_registry() - self._processor_registry = processor_registry or create_builtin_processor_registry() - - for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): - self._column_generator_registry.register( - DataDesignerColumnType(plugin.name), - plugin.task_cls, - plugin.config_cls, - ) + self._column_generator_registry = column_generator_registry or create_default_column_generator_registry() + self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry() + self._processor_registry = processor_registry or create_default_processor_registry() @property def column_generators(self) -> ColumnGeneratorRegistry: diff --git a/tests/engine/analysis/test_errors.py b/tests/engine/analysis/test_errors.py index ca446596..6c6fb901 100644 --- a/tests/engine/analysis/test_errors.py +++ b/tests/engine/analysis/test_errors.py @@ -10,7 +10,7 @@ ) from data_designer.engine.analysis.column_profilers.registry import ( ColumnProfilerRegistry, - create_builtin_column_profiler_registry, + create_default_column_profiler_registry, ) from data_designer.engine.registry.errors import NotFoundInRegistryError @@ -52,7 +52,7 @@ def test_get_nonexistent_profiler(): def test_create_default_registry(): - registry = create_builtin_column_profiler_registry() + registry = create_default_column_profiler_registry() assert isinstance(registry, ColumnProfilerRegistry) assert ColumnProfilerType.JUDGE_SCORE in ColumnProfilerRegistry._registry diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index 03e67486..f70b0d90 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -14,12 +14,12 @@ from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator from data_designer.engine.column_generators.registry import ( ColumnGeneratorRegistry, - create_builtin_column_generator_registry, + create_default_column_generator_registry, ) def test_column_generator_registry_create_default_registry_with_generators(): - registry = create_builtin_column_generator_registry() + registry = create_default_column_generator_registry() assert isinstance(registry, ColumnGeneratorRegistry) diff --git a/tests/engine/processing/processors/test_registry.py b/tests/engine/processing/processors/test_registry.py index 8fedd392..41ccf5a8 100644 --- a/tests/engine/processing/processors/test_registry.py +++ b/tests/engine/processing/processors/test_registry.py @@ -5,12 +5,12 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.processing.processors.registry import ( ProcessorRegistry, - create_builtin_processor_registry, + create_default_processor_registry, ) def test_create_default_processor_registry(): - registry = create_builtin_processor_registry() + registry = create_default_processor_registry() assert isinstance(registry, ProcessorRegistry) assert ProcessorType.DROP_COLUMNS in ProcessorRegistry._registry diff --git a/tests/engine/registry/test_data_designer_registry.py b/tests/engine/registry/test_data_designer_registry.py index 6f131bb6..5f98970c 100644 --- a/tests/engine/registry/test_data_designer_registry.py +++ b/tests/engine/registry/test_data_designer_registry.py @@ -21,10 +21,10 @@ def stub_column_profiler_registry(): @pytest.fixture def stub_default_registries(): with patch( - "data_designer.engine.registry.data_designer_registry.create_builtin_column_generator_registry" + "data_designer.engine.registry.data_designer_registry.create_default_column_generator_registry" ) as mock_gen: with patch( - "data_designer.engine.registry.data_designer_registry.create_builtin_column_profiler_registry" + "data_designer.engine.registry.data_designer_registry.create_default_column_profiler_registry" ) as mock_prof: mock_gen_registry = Mock() mock_prof_registry = Mock() From c752bd4f71928198da50b4f68682d53f2dfdb96a Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Mon, 10 Nov 2025 22:06:52 -0500 Subject: [PATCH 25/27] update method names to better match what they do --- .../config/analysis/column_statistics.py | 2 +- src/data_designer/config/column_types.py | 4 +-- src/data_designer/plugin_manager.py | 23 +++++------- src/data_designer/plugins/registry.py | 3 ++ tests/test_plugin_manager.py | 35 ++++++++++--------- 5 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index bf00234c..c39dedfb 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -263,7 +263,7 @@ def from_series(cls, series: Series) -> Self: DataDesignerColumnType.VALIDATION: ValidationColumnStatistics, } -for plugin in PluginManager().get_plugin_column_configs(): +for plugin in PluginManager().get_column_generator_plugins(): # Dynamically create a statistics class for this plugin using Pydantic's create_model plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics" diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index ab1b3d52..50ba498d 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -52,7 +52,7 @@ DataDesignerColumnType.VALIDATION: "🔍", } COLUMN_TYPE_EMOJI_MAP.update( - {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_plugin_column_configs()} + {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} ) @@ -117,7 +117,7 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) - if plugin := plugin_manager.get_plugin_column_config_if_available(column_type.value): + if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py index 79bd5eee..d6101a95 100644 --- a/src/data_designer/plugin_manager.py +++ b/src/data_designer/plugin_manager.py @@ -26,21 +26,16 @@ def __init__(self): self._plugins_available = False self._plugin_registry = None - def get_plugin_column_configs(self) -> list[Plugin]: - """Get all plugin column configs. + def get_column_generator_plugins(self) -> list[Plugin]: + """Get all column generator plugins. Returns: - A list of all plugin column configs. + A list of all column generator plugins. """ - if self._plugins_available: - return [ - self._plugin_registry.get_plugin(plugin_name) - for plugin_name in self._plugin_registry.get_plugin_names(PluginType.COLUMN_GENERATOR) - ] - return [] + return self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR) if self._plugins_available else [] - def get_plugin_column_config_if_available(self, plugin_name: str) -> Plugin | None: - """Get a plugin column config by name if available. + def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | None: + """Get a column generator plugin by name if it exists. Args: plugin_name: The name of the plugin to retrieve. @@ -48,10 +43,8 @@ def get_plugin_column_config_if_available(self, plugin_name: str) -> Plugin | No Returns: The plugin if found, otherwise None. """ - if self._plugins_available: - for name in self._plugin_registry.get_plugin_names(PluginType.COLUMN_GENERATOR): - if plugin_name == name: - return self._plugin_registry.get_plugin(plugin_name) + if self._plugins_available and self._plugin_registry.plugin_exists(plugin_name): + return self._plugin_registry.get_plugin(plugin_name) return None def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: diff --git a/src/data_designer/plugins/registry.py b/src/data_designer/plugins/registry.py index a0fd86bc..94cc18aa 100644 --- a/src/data_designer/plugins/registry.py +++ b/src/data_designer/plugins/registry.py @@ -67,6 +67,9 @@ def get_plugin_names(self, plugin_type: PluginType) -> list[str]: def num_plugins(self, plugin_type: PluginType) -> int: return len(self.get_plugins(plugin_type)) + def plugin_exists(self, plugin_name: str) -> bool: + return plugin_name in self._plugins + def set_plugins(self, plugins: dict[str, Plugin]) -> None: self._plugins = plugins diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 4e5e1b0e..9f5be903 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -73,67 +73,68 @@ def mock_plugins() -> list[Mock]: ] -def test_get_plugin_column_configs_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: +def test_get_column_generator_plugins_with_plugins(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: """Test getting plugin column configs when plugins are available.""" - mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] - mock_plugin_registry.get_plugin.side_effect = lambda name: next(p for p in mock_plugins if p.name == name) + mock_plugin_registry.get_plugins.return_value = [mock_plugins[0], mock_plugins[1]] with mock_plugin_system(mock_plugin_registry): manager = PluginManager() - result = manager.get_plugin_column_configs() + result = manager.get_column_generator_plugins() assert len(result) == 2 assert [p.name for p in result] == ["plugin-one", "plugin-two"] - mock_plugin_registry.get_plugin_names.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) + mock_plugin_registry.get_plugins.assert_called_once_with(MockPluginType.COLUMN_GENERATOR) @pytest.mark.parametrize("plugins_available", [True, False]) -def test_get_plugin_column_configs_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: +def test_get_column_generator_plugins_empty(mock_plugin_registry: MagicMock, plugins_available: bool) -> None: """Test getting plugin column configs when no plugins are registered or system is disabled.""" if plugins_available: - mock_plugin_registry.get_plugin_names.return_value = [] + mock_plugin_registry.get_plugins.return_value = [] with mock_plugin_system(mock_plugin_registry): manager = PluginManager() - result = manager.get_plugin_column_configs() + result = manager.get_column_generator_plugins() else: with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): manager = PluginManager() - result = manager.get_plugin_column_configs() + result = manager.get_column_generator_plugins() assert result == [] -def test_get_plugin_column_config_if_available_found(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: +def test_get_column_generator_plugin_if_exists_found(mock_plugin_registry: MagicMock, mock_plugins: list[Mock]) -> None: """Test getting a specific plugin by name when it exists.""" - mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_registry.plugin_exists.return_value = True mock_plugin_registry.get_plugin.return_value = mock_plugins[0] with mock_plugin_system(mock_plugin_registry): manager = PluginManager() - result = manager.get_plugin_column_config_if_available("plugin-one") + result = manager.get_column_generator_plugin_if_exists("plugin-one") assert result is not None assert result.name == "plugin-one" + mock_plugin_registry.plugin_exists.assert_called_once_with("plugin-one") mock_plugin_registry.get_plugin.assert_called_once_with("plugin-one") -def test_get_plugin_column_config_if_available_not_found(mock_plugin_registry: MagicMock) -> None: +def test_get_column_generator_plugin_if_exists_not_found(mock_plugin_registry: MagicMock) -> None: """Test getting a specific plugin by name when it doesn't exist.""" - mock_plugin_registry.get_plugin_names.return_value = ["plugin-one", "plugin-two"] + mock_plugin_registry.plugin_exists.return_value = False with mock_plugin_system(mock_plugin_registry): manager = PluginManager() - result = manager.get_plugin_column_config_if_available("plugin-three") + result = manager.get_column_generator_plugin_if_exists("plugin-three") assert result is None + mock_plugin_registry.plugin_exists.assert_called_once_with("plugin-three") mock_plugin_registry.get_plugin.assert_not_called() -def test_get_plugin_column_config_if_available_when_disabled() -> None: +def test_get_column_generator_plugin_if_exists_when_disabled() -> None: """Test getting a specific plugin when plugin system is disabled.""" with patch("data_designer.plugin_manager.can_run_data_designer_locally", return_value=False): manager = PluginManager() - result = manager.get_plugin_column_config_if_available("plugin-one") + result = manager.get_column_generator_plugin_if_exists("plugin-one") assert result is None From cea9209555076fda443663ceaf31345210a92df7 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Tue, 11 Nov 2025 11:23:31 -0500 Subject: [PATCH 26/27] use register verb for consistency with other registries --- src/data_designer/plugins/plugin.py | 4 ++++ src/data_designer/plugins/registry.py | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index aa831846..886a2252 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -21,6 +21,10 @@ def discriminator_field(self) -> str: else: raise ValueError(f"Invalid plugin type: {self.value}") + @property + def display_name(self) -> str: + return self.value.replace("-", " ") + class Plugin(BaseModel): task_cls: Type[ConfigurableTask] diff --git a/src/data_designer/plugins/registry.py b/src/data_designer/plugins/registry.py index 94cc18aa..a897a050 100644 --- a/src/data_designer/plugins/registry.py +++ b/src/data_designer/plugins/registry.py @@ -37,11 +37,6 @@ def reset(cls) -> None: cls._plugins_discovered = False cls._plugins = {} - def add_plugin(self, plugin: Plugin) -> None: - if plugin.name in self._plugins: - raise PluginRegistrationError(f"Plugin {plugin.name!r} already added.") - self._plugins[plugin.name] = plugin - def add_plugin_types(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: for plugin in self.get_plugins(plugin_type): type_union |= plugin.config_cls @@ -81,9 +76,9 @@ def discover(self) -> Self: plugin = ep.load() if isinstance(plugin, Plugin): with self._lock: - self.add_plugin(plugin) + self.register(plugin) logger.info( - f"🔌 Plugin discovered ➜ {plugin.plugin_type.value.replace('-', ' ')} " + f"🔌 Plugin discovered ➜ {plugin.plugin_type.display_name} " f"{plugin.enum_key_name} is now available ⚡️" ) except Exception as e: @@ -91,6 +86,11 @@ def discover(self) -> Self: return self + def register(self, plugin: Plugin) -> None: + if plugin.name in self._plugins: + raise PluginRegistrationError(f"Plugin {plugin.name!r} already registered.") + self._plugins[plugin.name] = plugin + def __new__(cls, *args, **kwargs): """Plugin manager is a singleton.""" if not cls._instance: From 2b4468637ef5e518b59f681e822dfd471a61bb0c Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Tue, 11 Nov 2025 11:58:50 -0500 Subject: [PATCH 27/27] thread safety updates; make discover private --- src/data_designer/plugin_manager.py | 16 +++++----- src/data_designer/plugins/registry.py | 42 +++++++++------------------ tests/plugins/test_plugin_registry.py | 14 +++++++-- tests/test_plugin_manager.py | 4 +-- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py index d6101a95..923138ea 100644 --- a/src/data_designer/plugin_manager.py +++ b/src/data_designer/plugin_manager.py @@ -20,10 +20,10 @@ class PluginManager: def __init__(self): if can_run_data_designer_locally(): - self._plugins_available = True + self._plugins_supported = True self._plugin_registry = PluginRegistry() else: - self._plugins_available = False + self._plugins_supported = False self._plugin_registry = None def get_column_generator_plugins(self) -> list[Plugin]: @@ -32,7 +32,7 @@ def get_column_generator_plugins(self) -> list[Plugin]: Returns: A list of all column generator plugins. """ - return self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR) if self._plugins_available else [] + return self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR) if self._plugins_supported else [] def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | None: """Get a column generator plugin by name if it exists. @@ -43,7 +43,7 @@ def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | No Returns: The plugin if found, otherwise None. """ - if self._plugins_available and self._plugin_registry.plugin_exists(plugin_name): + if self._plugins_supported and self._plugin_registry.plugin_exists(plugin_name): return self._plugin_registry.get_plugin(plugin_name) return None @@ -58,7 +58,7 @@ def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: lis A list of plugin column types. """ type_list = [] - if self._plugins_available: + if self._plugins_supported: for plugin in self._plugin_registry.get_plugins(PluginType.COLUMN_GENERATOR): if required_resources: task_required_resources = plugin.task_cls.metadata().required_resources or [] @@ -76,6 +76,8 @@ def inject_into_column_config_type_union(self, column_config_type: Type[TypeAlia Returns: The column config type with plugins injected. """ - if self._plugins_available: - column_config_type = self._plugin_registry.add_plugin_types(column_config_type, PluginType.COLUMN_GENERATOR) + if self._plugins_supported: + column_config_type = self._plugin_registry.add_plugin_types_to_union( + column_config_type, PluginType.COLUMN_GENERATOR + ) return column_config_type diff --git a/src/data_designer/plugins/registry.py b/src/data_designer/plugins/registry.py index a897a050..6ef465e0 100644 --- a/src/data_designer/plugins/registry.py +++ b/src/data_designer/plugins/registry.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from copy import deepcopy from importlib.metadata import entry_points import logging import os @@ -10,7 +9,7 @@ from typing_extensions import Self -from data_designer.plugins.errors import PluginNotFoundError, PluginRegistrationError +from data_designer.plugins.errors import PluginNotFoundError from data_designer.plugins.plugin import Plugin, PluginType logger = logging.getLogger(__name__) @@ -27,27 +26,23 @@ class PluginRegistry: _plugins: dict[str, Plugin] = {} def __init__(self): - if not self._plugins_discovered: - self.discover() - self._plugins_discovered = True + with self._lock: + if not self._plugins_discovered: + self._discover() @classmethod def reset(cls) -> None: - cls._instance = None - cls._plugins_discovered = False - cls._plugins = {} + with cls._lock: + cls._instance = None + cls._plugins_discovered = False + cls._plugins = {} - def add_plugin_types(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: + def add_plugin_types_to_union(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: for plugin in self.get_plugins(plugin_type): - type_union |= plugin.config_cls + if plugin.config_cls not in type_union.__args__: + type_union |= plugin.config_cls return type_union - def clear_plugins(self) -> None: - self._plugins.clear() - - def copy_plugins(self) -> dict[str, Plugin]: - return deepcopy(self._plugins) - def get_plugin(self, plugin_name: str) -> Plugin: if plugin_name not in self._plugins: raise PluginNotFoundError(f"Plugin {plugin_name!r} not found.") @@ -65,32 +60,23 @@ def num_plugins(self, plugin_type: PluginType) -> int: def plugin_exists(self, plugin_name: str) -> bool: return plugin_name in self._plugins - def set_plugins(self, plugins: dict[str, Plugin]) -> None: - self._plugins = plugins - - def discover(self) -> Self: + def _discover(self) -> Self: if PLUGINS_DISABLED: return self for ep in entry_points(group="data_designer.plugins"): try: plugin = ep.load() if isinstance(plugin, Plugin): - with self._lock: - self.register(plugin) logger.info( f"🔌 Plugin discovered ➜ {plugin.plugin_type.display_name} " f"{plugin.enum_key_name} is now available ⚡️" ) + self._plugins[plugin.name] = plugin except Exception as e: logger.warning(f"🛑 Failed to load plugin from entry point {ep.name!r}: {e}") - + self._plugins_discovered = True return self - def register(self, plugin: Plugin) -> None: - if plugin.name in self._plugins: - raise PluginRegistrationError(f"Plugin {plugin.name!r} already registered.") - self._plugins[plugin.name] = plugin - def __new__(cls, *args, **kwargs): """Plugin manager is a singleton.""" if not cls._instance: diff --git a/tests/plugins/test_plugin_registry.py b/tests/plugins/test_plugin_registry.py index 4947601d..b3956f44 100644 --- a/tests/plugins/test_plugin_registry.py +++ b/tests/plugins/test_plugin_registry.py @@ -269,11 +269,21 @@ def test_plugin_registry_get_plugin_names(mock_plugin_discovery, mock_entry_poin def test_plugin_registry_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: """Test update_type_union() adds plugin config types to union.""" + from typing import Union + + from typing_extensions import TypeAlias + + class DummyConfig(ConfigBase): + pass + with mock_plugin_discovery(mock_entry_points): manager = PluginRegistry() - type_union: type = ConfigBase - updated_union = manager.add_plugin_types(type_union, PluginType.COLUMN_GENERATOR) + # Create a Union with at least 2 types so it has __args__ + type_union: TypeAlias = Union[ConfigBase, DummyConfig] + updated_union = manager.add_plugin_types_to_union(type_union, PluginType.COLUMN_GENERATOR) assert StubPluginConfigA in updated_union.__args__ assert StubPluginConfigB in updated_union.__args__ + assert ConfigBase in updated_union.__args__ + assert DummyConfig in updated_union.__args__ diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 9f5be903..c00d78f9 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -209,14 +209,14 @@ def test_inject_into_column_config_type_union_with_plugins(mock_plugin_registry: class BaseType: pass - mock_plugin_registry.add_plugin_types.return_value = str | int + mock_plugin_registry.add_plugin_types_to_union.return_value = str | int with mock_plugin_system(mock_plugin_registry): manager = PluginManager() result = manager.inject_into_column_config_type_union(BaseType) assert result == str | int - mock_plugin_registry.add_plugin_types.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) + mock_plugin_registry.add_plugin_types_to_union.assert_called_once_with(BaseType, MockPluginType.COLUMN_GENERATOR) def test_inject_into_column_config_type_union_when_disabled() -> None: