diff --git a/docs/plugins/example.md b/docs/plugins/example.md index 9c6929d1..b288456b 100644 --- a/docs/plugins/example.md +++ b/docs/plugins/example.md @@ -131,8 +131,8 @@ from data_designer.plugins import Plugin, PluginType # Plugin instance - this is what gets loaded via entry point plugin = Plugin( - task_cls=IndexMultiplierColumnGenerator, - config_cls=IndexMultiplierColumnConfig, + task_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnGenerator", + config_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnConfig", plugin_type=PluginType.COLUMN_GENERATOR, emoji="🔌", ) @@ -204,8 +204,8 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig # Plugin instance - this is what gets loaded via entry point plugin = Plugin( - task_cls=IndexMultiplierColumnGenerator, - config_cls=IndexMultiplierColumnConfig, + task_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnGenerator", + config_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnConfig", plugin_type=PluginType.COLUMN_GENERATOR, emoji="🔌", ) diff --git a/src/data_designer/plugins/errors.py b/src/data_designer/plugins/errors.py index de6e4435..7be1bcf4 100644 --- a/src/data_designer/plugins/errors.py +++ b/src/data_designer/plugins/errors.py @@ -4,6 +4,9 @@ from data_designer.errors import DataDesignerError +class PluginLoadError(DataDesignerError): ... + + class PluginRegistrationError(DataDesignerError): ... diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index 6553e45e..007f0ae4 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -1,14 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import ast +import importlib +import importlib.util from enum import Enum -from typing import Literal, get_origin +from functools import cached_property +from typing import TYPE_CHECKING, Literal, get_origin -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from data_designer.config.base import ConfigBase -from data_designer.engine.configurable_task import ConfigurableTask +from data_designer.plugins.errors import PluginLoadError + +if TYPE_CHECKING: + from data_designer.config.base import ConfigBase + from data_designer.engine.configurable_task import ConfigurableTask class PluginType(str, Enum): @@ -26,11 +35,42 @@ def display_name(self) -> str: return self.value.replace("-", " ") +def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]: + try: + module_name, object_name = fully_qualified_object.rsplit(".", 1) + except ValueError: + # If fully_qualified_object does not have any periods, the rsplit call will return + # a list of length 1 and the variable assignment above will raise ValueError + raise PluginLoadError("Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig'") + + return module_name, object_name + + +def _check_class_exists_in_file(filepath: str, class_name: str) -> None: + try: + with open(filepath, "r") as file: + source = file.read() + except FileNotFoundError: + raise PluginLoadError(f"Could not read source code at {filepath!r}") + + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + return None + + raise PluginLoadError(f"Could not find class named {class_name!r} in {filepath!r}") + + class Plugin(BaseModel): - task_cls: type[ConfigurableTask] - config_cls: type[ConfigBase] - plugin_type: PluginType - emoji: str = "🔌" + task_qualified_name: str = Field( + ..., + description="The fully-qualified name of the task class object, e.g. 'my_plugin.generator.MyColumnGenerator'", + ) + config_qualified_name: str = Field( + ..., description="The fully-qualified name o the config class object, e.g. 'my_plugin.config.MyConfig'" + ) + plugin_type: PluginType = Field(..., description="The type of plugin") + emoji: str = Field(default="🔌", description="The emoji to use in logs related to the plugin") @property def config_type_as_class_name(self) -> str: @@ -48,22 +88,55 @@ def name(self) -> str: def discriminator_field(self) -> str: return self.plugin_type.discriminator_field + @field_validator("task_qualified_name", "config_qualified_name", mode="after") + @classmethod + def validate_class_name(cls, value: str) -> str: + module_name, object_name = _get_module_and_object_names(value) + try: + spec = importlib.util.find_spec(module_name) + except: + raise PluginLoadError(f"Could not find module {module_name!r}") + + if spec is None or spec.origin is None: + raise PluginLoadError(f"Error finding source for module {module_name!r}") + + _check_class_exists_in_file(spec.origin, object_name) + + return value + @model_validator(mode="after") def validate_discriminator_field(self) -> Self: - cfg = self.config_cls.__name__ + _, cfg = _get_module_and_object_names(self.config_qualified_name) field = self.plugin_type.discriminator_field if field not in self.config_cls.model_fields: - raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}") + raise ValueError(f"Discriminator field {field!r} not found in config class {cfg!r}") field_info = self.config_cls.model_fields[field] if get_origin(field_info.annotation) is not Literal: - raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.") + raise ValueError(f"Field {field!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.") if not isinstance(field_info.default, str): - raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.") + raise ValueError(f"The default of {field!r} must be a string, not {type(field_info.default)!r}.") enum_key = field_info.default.replace("-", "_").upper() if not enum_key.isidentifier(): raise ValueError( - f"The default value '{field_info.default}' for discriminator field '{field}' " - f"cannot be converted to a valid enum key. The converted key '{enum_key}' " + f"The default value {field_info.default!r} for discriminator field {field!r} " + f"cannot be converted to a valid enum key. The converted key {enum_key!r} " f"must be a valid Python identifier." ) return self + + @cached_property + def config_cls(self) -> type[ConfigBase]: + return self._load(self.config_qualified_name) + + @cached_property + def task_cls(self) -> type[ConfigurableTask]: + return self._load(self.task_qualified_name) + + @staticmethod + def _load(fully_qualified_object: str) -> type: + module_name, object_name = _get_module_and_object_names(fully_qualified_object) + module = importlib.import_module(module_name) + try: + return getattr(module, object_name) + except AttributeError: + raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}") diff --git a/src/data_designer/plugins/testing/__init__.py b/src/data_designer/plugins/testing/__init__.py new file mode 100644 index 00000000..61ee469b --- /dev/null +++ b/src/data_designer/plugins/testing/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.testing.utils import assert_valid_plugin + +__all__ = [ + assert_valid_plugin.__name__, +] diff --git a/src/data_designer/plugins/testing/stubs.py b/src/data_designer/plugins/testing/stubs.py new file mode 100644 index 00000000..342a641a --- /dev/null +++ b/src/data_designer/plugins/testing/stubs.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from data_designer.config.base import ConfigBase +from data_designer.config.column_configs import SingleColumnConfig +from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata + +MODULE_NAME = __name__ + + +class ValidTestConfig(SingleColumnConfig): + """Valid config for testing plugin creation.""" + + column_type: Literal["test-generator"] = "test-generator" + name: str + + +class ValidTestTask(ConfigurableTask[ValidTestConfig]): + """Valid task for testing plugin creation.""" + + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_generator", + description="Test generator", + required_resources=None, + ) + + +class ConfigWithoutDiscriminator(ConfigBase): + some_field: str + + +class ConfigWithStringField(ConfigBase): + column_type: str = "test-generator" + + +class ConfigWithNonStringDefault(ConfigBase): + column_type: Literal["test-generator"] = 123 # type: ignore + + +class ConfigWithInvalidKey(ConfigBase): + column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#" + + +class StubPluginConfigA(SingleColumnConfig): + column_type: Literal["test-plugin-a"] = "test-plugin-a" + + +class StubPluginConfigB(SingleColumnConfig): + column_type: Literal["test-plugin-b"] = "test-plugin-b" + + +class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_plugin_a", + description="Test plugin A", + required_resources=None, + ) + + +class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]): + @staticmethod + def metadata() -> ConfigurableTaskMetadata: + return ConfigurableTaskMetadata( + name="test_plugin_b", + description="Test plugin B", + required_resources=None, + ) diff --git a/src/data_designer/plugins/testing/utils.py b/src/data_designer/plugins/testing/utils.py new file mode 100644 index 00000000..ff96a038 --- /dev/null +++ b/src/data_designer/plugins/testing/utils.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.config.base import ConfigBase +from data_designer.engine.configurable_task import ConfigurableTask +from data_designer.plugins.plugin import Plugin + + +def assert_valid_plugin(plugin: Plugin) -> None: + assert issubclass(plugin.config_cls, ConfigBase), "Plugin config class is not a subclass of ConfigBase" + assert issubclass(plugin.task_cls, ConfigurableTask), "Plugin task class is not a subclass of ConfigurableTask" diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index ac5b17cc..afddcca2 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -1,43 +1,24 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Literal - import pytest -from pydantic import ValidationError from data_designer.config.base import ConfigBase -from data_designer.config.column_configs import SamplerColumnConfig, SingleColumnConfig +from data_designer.config.column_configs import SamplerColumnConfig from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator -from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata +from data_designer.engine.configurable_task import ConfigurableTask +from data_designer.plugins.errors import PluginLoadError from data_designer.plugins.plugin import Plugin, PluginType - - -class ValidTestConfig(SingleColumnConfig): - """Valid config for testing plugin creation.""" - - column_type: Literal["test-generator"] = "test-generator" - name: str - - -class ValidTestTask(ConfigurableTask[ValidTestConfig]): - """Valid task for testing plugin creation.""" - - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="test_generator", - description="Test generator", - required_resources=None, - ) +from data_designer.plugins.testing.stubs import MODULE_NAME, ValidTestConfig, ValidTestTask +from data_designer.plugins.testing.utils import assert_valid_plugin @pytest.fixture def valid_plugin() -> Plugin: """Fixture providing a valid plugin instance for testing.""" return Plugin( - task_cls=ValidTestTask, - config_cls=ValidTestConfig, + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ValidTestConfig", plugin_type=PluginType.COLUMN_GENERATOR, ) @@ -81,18 +62,6 @@ def test_plugin_discriminator_field_from_type(valid_plugin: Plugin) -> None: assert valid_plugin.discriminator_field == "column_type" -def test_plugin_requires_all_fields() -> None: - """Test that Plugin creation fails without required fields.""" - with pytest.raises(ValidationError): - Plugin() # type: ignore - - with pytest.raises(ValidationError): - Plugin(task_cls=ValidTestTask) # type: ignore - - with pytest.raises(ValidationError): - Plugin(config_cls=ValidTestConfig) # type: ignore - - # ============================================================================= # Plugin Validation Tests # ============================================================================= @@ -101,58 +70,99 @@ def test_plugin_requires_all_fields() -> None: def test_validation_fails_when_config_missing_discriminator_field() -> None: """Test validation fails when config lacks the required discriminator field.""" - class ConfigWithoutDiscriminator(ConfigBase): - some_field: str - with pytest.raises(ValueError, match="Discriminator field 'column_type' not found in config class"): Plugin( - task_cls=ValidTestTask, - config_cls=ConfigWithoutDiscriminator, + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ConfigWithoutDiscriminator", plugin_type=PluginType.COLUMN_GENERATOR, ) def test_validation_fails_when_discriminator_field_not_literal_type() -> None: """Test validation fails when discriminator field is not a Literal type.""" - - class ConfigWithStringField(ConfigBase): - column_type: str = "test-generator" - with pytest.raises(ValueError, match="Field 'column_type' of .* must be a Literal type"): Plugin( - task_cls=ValidTestTask, - config_cls=ConfigWithStringField, + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ConfigWithStringField", plugin_type=PluginType.COLUMN_GENERATOR, ) def test_validation_fails_when_discriminator_default_not_string() -> None: """Test validation fails when discriminator field default is not a string.""" - - class ConfigWithNonStringDefault(ConfigBase): - column_type: Literal["test-generator"] = 123 # type: ignore - with pytest.raises(ValueError, match="The default of 'column_type' must be a string"): Plugin( - task_cls=ValidTestTask, - config_cls=ConfigWithNonStringDefault, + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ConfigWithNonStringDefault", plugin_type=PluginType.COLUMN_GENERATOR, ) def test_validation_fails_with_invalid_enum_key_conversion() -> None: """Test validation fails when default value cannot be converted to valid Python identifier.""" + with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): + Plugin( + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ConfigWithInvalidKey", + plugin_type=PluginType.COLUMN_GENERATOR, + ) - class ConfigWithInvalidKey(ConfigBase): - column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#" - with pytest.raises(ValueError, match="cannot be converted to a valid enum key"): +def test_validation_fails_with_invalid_modules() -> None: + """Test validation fails when task or config class modules are invalid.""" + with pytest.raises(PluginLoadError, match="Could not find module"): Plugin( - task_cls=ValidTestTask, - config_cls=ConfigWithInvalidKey, + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name="invalid.module.ValidTestConfig", plugin_type=PluginType.COLUMN_GENERATOR, ) + with pytest.raises(PluginLoadError, match="Could not find module"): + Plugin( + task_qualified_name="invalid.module.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ValidTestConfig", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + with pytest.raises(PluginLoadError, match="Expected a fully-qualified object name"): + Plugin( + task_qualified_name="ValidTestTask", + config_qualified_name="ValidTestConfig", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + with pytest.raises(PluginLoadError, match="Could not find class"): + Plugin( + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.NotADefinedClass", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + with pytest.raises(PluginLoadError, match="Could not find class"): + Plugin( + task_qualified_name=f"{MODULE_NAME}.NotADefinedClass", + config_qualified_name=f"{MODULE_NAME}.ValidTestConfig", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + + +def test_helper_utility_identifies_invalid_classes() -> None: + """Test the helper utility provides deeper validation of config classes.""" + valid_plugin = Plugin( + task_qualified_name=f"{MODULE_NAME}.ValidTestTask", + config_qualified_name=f"{MODULE_NAME}.ValidTestConfig", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + assert_valid_plugin(valid_plugin) + + plugin_with_improper_task_class_type = Plugin( + task_qualified_name=f"{MODULE_NAME}.ValidTestConfig", + config_qualified_name=f"{MODULE_NAME}.ValidTestConfig", + plugin_type=PluginType.COLUMN_GENERATOR, + ) + with pytest.raises(AssertionError): + assert_valid_plugin(plugin_with_improper_task_class_type) + # ============================================================================= # Integration Tests @@ -162,8 +172,8 @@ class ConfigWithInvalidKey(ConfigBase): def test_plugin_works_with_real_sampler_column_generator() -> None: """Test that Plugin works with actual SamplerColumnGenerator from the codebase.""" plugin = Plugin( - task_cls=SamplerColumnGenerator, - config_cls=SamplerColumnConfig, + task_qualified_name="data_designer.engine.column_generators.generators.samplers.SamplerColumnGenerator", + config_qualified_name="data_designer.config.column_configs.SamplerColumnConfig", plugin_type=PluginType.COLUMN_GENERATOR, ) diff --git a/tests/plugins/test_plugin_registry.py b/tests/plugins/test_plugin_registry.py index 5a7feb5c..215fe713 100644 --- a/tests/plugins/test_plugin_registry.py +++ b/tests/plugins/test_plugin_registry.py @@ -4,50 +4,15 @@ import threading from contextlib import contextmanager from importlib.metadata import EntryPoint -from typing import Literal from unittest.mock import MagicMock, patch import pytest from data_designer.config.base import ConfigBase -from data_designer.config.column_configs import SingleColumnConfig -from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata from data_designer.plugins.errors import PluginNotFoundError from data_designer.plugins.plugin import Plugin, PluginType from data_designer.plugins.registry import PluginRegistry - -# ============================================================================= -# Test Stubs -# ============================================================================= - - -class StubPluginConfigA(SingleColumnConfig): - column_type: Literal["test-plugin-a"] = "test-plugin-a" - - -class StubPluginConfigB(SingleColumnConfig): - column_type: Literal["test-plugin-b"] = "test-plugin-b" - - -class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="test_plugin_a", - description="Test plugin A", - required_resources=None, - ) - - -class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]): - @staticmethod - def metadata() -> ConfigurableTaskMetadata: - return ConfigurableTaskMetadata( - name="test_plugin_b", - description="Test plugin B", - required_resources=None, - ) - +from data_designer.plugins.testing.stubs import MODULE_NAME, StubPluginConfigA, StubPluginConfigB # ============================================================================= # Test Fixtures @@ -57,8 +22,8 @@ def metadata() -> ConfigurableTaskMetadata: @pytest.fixture def plugin_a() -> Plugin: return Plugin( - task_cls=StubPluginTaskA, - config_cls=StubPluginConfigA, + task_qualified_name=f"{MODULE_NAME}.StubPluginTaskA", + config_qualified_name=f"{MODULE_NAME}.StubPluginConfigA", plugin_type=PluginType.COLUMN_GENERATOR, ) @@ -66,8 +31,8 @@ def plugin_a() -> Plugin: @pytest.fixture def plugin_b() -> Plugin: return Plugin( - task_cls=StubPluginTaskB, - config_cls=StubPluginConfigB, + task_qualified_name=f"{MODULE_NAME}.StubPluginTaskB", + config_qualified_name=f"{MODULE_NAME}.StubPluginConfigB", plugin_type=PluginType.COLUMN_GENERATOR, )