Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/plugins/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ from data_designer.plugins import Plugin, PluginType

# Plugin instance - this is what gets loaded via entry point
plugin = Plugin(
task_cls=IndexMultiplierColumnGenerator,
config_cls=IndexMultiplierColumnConfig,
task_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnGenerator",
config_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnConfig",
plugin_type=PluginType.COLUMN_GENERATOR,
emoji="🔌",
)
Expand Down Expand Up @@ -204,8 +204,8 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig

# Plugin instance - this is what gets loaded via entry point
plugin = Plugin(
task_cls=IndexMultiplierColumnGenerator,
config_cls=IndexMultiplierColumnConfig,
task_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnGenerator",
config_qualified_name="data_designer_index_multiplier.plugin.IndexMultiplierColumnConfig",
plugin_type=PluginType.COLUMN_GENERATOR,
emoji="🔌",
)
Expand Down
3 changes: 3 additions & 0 deletions src/data_designer/plugins/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from data_designer.errors import DataDesignerError


class PluginLoadError(DataDesignerError):
    """Raised when a plugin's module or class reference cannot be resolved."""


# NOTE(review): presumably raised when registering a plugin fails; no raise
# site is visible in this file, so confirm against the registry code.
class PluginRegistrationError(DataDesignerError): ...


Expand Down
101 changes: 87 additions & 14 deletions src/data_designer/plugins/plugin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import ast
import importlib
import importlib.util
from enum import Enum
from typing import Literal, get_origin
from functools import cached_property
from typing import TYPE_CHECKING, Literal, get_origin

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self

from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask
from data_designer.plugins.errors import PluginLoadError

if TYPE_CHECKING:
from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask


class PluginType(str, Enum):
Expand All @@ -26,11 +35,42 @@ def display_name(self) -> str:
return self.value.replace("-", " ")


def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]:
    """Split a fully-qualified object name into its module path and object name.

    The split happens at the last period, so ``"my_plugin.config.MyConfig"``
    becomes ``("my_plugin.config", "MyConfig")``.

    Args:
        fully_qualified_object: Dotted path ending in the object's name.

    Returns:
        A ``(module_name, object_name)`` tuple.

    Raises:
        PluginLoadError: If the name contains no period, or if either side of
            the last period is empty (e.g. ``".MyConfig"`` or ``"my_plugin."``).
    """
    module_name, separator, object_name = fully_qualified_object.rpartition(".")
    # rpartition never raises: a missing period yields an empty separator, and a
    # leading or trailing period yields an empty module or object name. None of
    # those can be imported, so reject them all here with one clear message.
    if not (separator and module_name and object_name):
        raise PluginLoadError(
            "Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig', "
            f"got {fully_qualified_object!r}"
        )
    return module_name, object_name


def _check_class_exists_in_file(filepath: str, class_name: str) -> None:
    """Verify, without importing it, that a Python source file defines a class.

    The file is parsed with ``ast`` rather than imported, so none of its code
    runs. Every ``class`` statement in the file is considered, including nested
    ones; names that are only imported or assigned in the file do not count.

    Args:
        filepath: Path to the Python source file to inspect.
        class_name: Name of the class definition to look for.

    Raises:
        PluginLoadError: If the file cannot be read, cannot be parsed as Python
            source, or contains no class definition named ``class_name``.
    """
    try:
        # Read bytes so ast.parse applies the file's own encoding declaration
        # (UTF-8 by default) rather than the platform's locale encoding.
        with open(filepath, "rb") as file:
            source = file.read()
    except OSError as err:
        # OSError covers a missing file as well as directories and permission errors.
        raise PluginLoadError(f"Could not read source code at {filepath!r}") from err

    try:
        tree = ast.parse(source, filename=filepath)
    except (SyntaxError, ValueError) as err:
        # Malformed or non-source files must surface as a plugin load failure,
        # not as a raw parser error.
        raise PluginLoadError(f"Could not parse source code at {filepath!r}") from err

    if any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree)):
        return None

    raise PluginLoadError(f"Could not find class named {class_name!r} in {filepath!r}")


class Plugin(BaseModel):
task_cls: type[ConfigurableTask]
config_cls: type[ConfigBase]
plugin_type: PluginType
emoji: str = "🔌"
task_qualified_name: str = Field(
...,
description="The fully-qualified name of the task class object, e.g. 'my_plugin.generator.MyColumnGenerator'",
)
config_qualified_name: str = Field(
    ..., description="The fully-qualified name of the config class object, e.g. 'my_plugin.config.MyConfig'"
)
plugin_type: PluginType = Field(..., description="The type of plugin")
emoji: str = Field(default="🔌", description="The emoji to use in logs related to the plugin")

@property
def config_type_as_class_name(self) -> str:
Expand All @@ -48,22 +88,55 @@ def name(self) -> str:
def discriminator_field(self) -> str:
return self.plugin_type.discriminator_field

@field_validator("task_qualified_name", "config_qualified_name", mode="after")
@classmethod
def validate_class_name(cls, value: str) -> str:
    """Check that a qualified name points at a class defined in a locatable module.

    Applied to both ``task_qualified_name`` and ``config_qualified_name``. The
    target module is located with ``importlib.util.find_spec`` and its source
    file is scanned for a matching class definition; the target module itself
    is not imported here.

    Args:
        value: The fully-qualified class name being validated.

    Returns:
        ``value`` unchanged when the check passes.

    Raises:
        PluginLoadError: If the name is not fully qualified, the module cannot
            be found, the module has no source origin, or the source file does
            not define the class.
    """
    module_name, object_name = _get_module_and_object_names(value)
    try:
        spec = importlib.util.find_spec(module_name)
    except Exception as err:
        # find_spec imports parent packages for dotted names, so it can raise
        # more than ImportError. Catch Exception rather than using a bare
        # except so KeyboardInterrupt and SystemExit still propagate.
        raise PluginLoadError(f"Could not find module {module_name!r}") from err

    if spec is None or spec.origin is None:
        raise PluginLoadError(f"Error finding source for module {module_name!r}")

    _check_class_exists_in_file(spec.origin, object_name)

    return value

@model_validator(mode="after")
def validate_discriminator_field(self) -> Self:
cfg = self.config_cls.__name__
_, cfg = _get_module_and_object_names(self.config_qualified_name)
field = self.plugin_type.discriminator_field
if field not in self.config_cls.model_fields:
raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}")
raise ValueError(f"Discriminator field {field!r} not found in config class {cfg!r}")
field_info = self.config_cls.model_fields[field]
if get_origin(field_info.annotation) is not Literal:
raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.")
raise ValueError(f"Field {field!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.")
if not isinstance(field_info.default, str):
raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.")
raise ValueError(f"The default of {field!r} must be a string, not {type(field_info.default)!r}.")
enum_key = field_info.default.replace("-", "_").upper()
if not enum_key.isidentifier():
raise ValueError(
f"The default value '{field_info.default}' for discriminator field '{field}' "
f"cannot be converted to a valid enum key. The converted key '{enum_key}' "
f"The default value {field_info.default!r} for discriminator field {field!r} "
f"cannot be converted to a valid enum key. The converted key {enum_key!r} "
f"must be a valid Python identifier."
)
return self

@cached_property
def config_cls(self) -> type[ConfigBase]:
    """Config class named by ``config_qualified_name``; loaded on first access, then cached."""
    qualified_name = self.config_qualified_name
    return self._load(qualified_name)

@cached_property
def task_cls(self) -> type[ConfigurableTask]:
    """Task class named by ``task_qualified_name``; loaded on first access, then cached."""
    qualified_name = self.task_qualified_name
    return self._load(qualified_name)

@staticmethod
def _load(fully_qualified_object: str) -> type:
    """Import a module and return the attribute named by a fully-qualified path.

    Args:
        fully_qualified_object: Dotted path such as ``"my_plugin.config.MyConfig"``.

    Returns:
        The attribute named by the last path component, looked up on the
        imported module.

    Raises:
        PluginLoadError: If the name is not fully qualified, the module cannot
            be imported, or the module has no attribute with that name.
    """
    module_name, object_name = _get_module_and_object_names(fully_qualified_object)
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        # Report import failures with the same error type as every other
        # plugin load failure, keeping the original error as the cause.
        raise PluginLoadError(f"Could not import module {module_name!r}") from err
    try:
        return getattr(module, object_name)
    except AttributeError as err:
        raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}") from err
8 changes: 8 additions & 0 deletions src/data_designer/plugins/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.plugins.testing.utils import assert_valid_plugin

# Spell the public API out as string literals so static tools (type checkers,
# linters, IDEs) can read it; they cannot evaluate `.__name__` expressions.
__all__ = [
    "assert_valid_plugin",
]
73 changes: 73 additions & 0 deletions src/data_designer/plugins/testing/stubs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Literal

from data_designer.config.base import ConfigBase
from data_designer.config.column_configs import SingleColumnConfig
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata

MODULE_NAME = __name__


class ValidTestConfig(SingleColumnConfig):
    """Valid config for testing plugin creation."""

    # Literal-typed field with a string default; Plugin's discriminator
    # validation requires both of these properties.
    column_type: Literal["test-generator"] = "test-generator"
    # Required field: no default is provided.
    name: str


class ValidTestTask(ConfigurableTask[ValidTestConfig]):
    """Valid task for testing plugin creation."""

    @staticmethod
    def metadata() -> ConfigurableTaskMetadata:
        """Return fixed metadata describing this stub task."""
        task_metadata = ConfigurableTaskMetadata(
            name="test_generator",
            description="Test generator",
            required_resources=None,
        )
        return task_metadata


class ConfigWithoutDiscriminator(ConfigBase):
    """Config stub that declares no ``column_type`` field.

    NOTE(review): presumably used to trigger Plugin's "Discriminator field ...
    not found" validation error; confirm against the tests.
    """

    some_field: str


class ConfigWithStringField(ConfigBase):
    """Config stub whose ``column_type`` is annotated as plain ``str``, not ``Literal``.

    NOTE(review): presumably used to trigger Plugin's "must be a Literal type"
    validation error; confirm against the tests.
    """

    column_type: str = "test-generator"


class ConfigWithNonStringDefault(ConfigBase):
    """Config stub whose ``column_type`` is a ``Literal`` with a non-string default.

    NOTE(review): presumably used to trigger Plugin's "default ... must be a
    string" validation error; confirm against the tests.
    """

    # The int default deliberately contradicts the annotation, hence the type-ignore.
    column_type: Literal["test-generator"] = 123  # type: ignore


class ConfigWithInvalidKey(ConfigBase):
    """Config stub whose ``column_type`` default cannot become a valid identifier.

    Replacing "-" with "_" and upper-casing the default gives "INVALID_KEY_!@#",
    which is not a Python identifier.

    NOTE(review): presumably used to trigger Plugin's "cannot be converted to a
    valid enum key" validation error; confirm against the tests.
    """

    column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#"


class StubPluginConfigA(SingleColumnConfig):
    """Stub column config with the ``column_type`` value "test-plugin-a"."""

    column_type: Literal["test-plugin-a"] = "test-plugin-a"


class StubPluginConfigB(SingleColumnConfig):
    """Stub column config with the ``column_type`` value "test-plugin-b"."""

    column_type: Literal["test-plugin-b"] = "test-plugin-b"


class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]):
    """Stub task parameterized by ``StubPluginConfigA``."""

    @staticmethod
    def metadata() -> ConfigurableTaskMetadata:
        """Return fixed metadata describing this stub task."""
        task_metadata = ConfigurableTaskMetadata(
            name="test_plugin_a",
            description="Test plugin A",
            required_resources=None,
        )
        return task_metadata


class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]):
    """Stub task parameterized by ``StubPluginConfigB``."""

    @staticmethod
    def metadata() -> ConfigurableTaskMetadata:
        """Return fixed metadata describing this stub task."""
        task_metadata = ConfigurableTaskMetadata(
            name="test_plugin_b",
            description="Test plugin B",
            required_resources=None,
        )
        return task_metadata
11 changes: 11 additions & 0 deletions src/data_designer/plugins/testing/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask
from data_designer.plugins.plugin import Plugin


def assert_valid_plugin(plugin: Plugin) -> None:
    """Assert that a plugin's config and task classes derive from the expected bases.

    Args:
        plugin: The plugin whose ``config_cls`` and ``task_cls`` are checked.

    Raises:
        AssertionError: If ``plugin.config_cls`` is not a subclass of
            ``ConfigBase`` or ``plugin.task_cls`` is not a subclass of
            ``ConfigurableTask``.
    """
    # Raise AssertionError explicitly instead of using `assert` statements so
    # the checks still run under `python -O`, which strips assert statements.
    config_cls = plugin.config_cls
    # The isinstance guard turns a non-class object into the same assertion
    # failure instead of a TypeError from issubclass.
    if not (isinstance(config_cls, type) and issubclass(config_cls, ConfigBase)):
        raise AssertionError("Plugin config class is not a subclass of ConfigBase")

    task_cls = plugin.task_cls
    if not (isinstance(task_cls, type) and issubclass(task_cls, ConfigurableTask)):
        raise AssertionError("Plugin task class is not a subclass of ConfigurableTask")
Loading