Commit f4c501e

committed
plugin system updates
1 parent d7e93c5 commit f4c501e

8 files changed

+312
-114
lines changed

_tmp_notes.md

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Plugin system updates

## Requirements

1. Plugins MUST support defining both a configuration object (a Pydantic model) and some `engine`-related implementation object (`ConfigurableTask`, `ColumnGenerator`, etc.).
2. The UX for making plugins discoverable MUST be simple. We should only require that users define a single `Plugin` object that gets referenced in a single entry point.
3. The plugin system MUST NOT introduce a dependency chain that makes any `config` module depend on any `engine` module. Such a dependency:
    a. Breaks "slim install" support, because `engine` code may include third-party deps that a `config`-only slim install will not include.
    b. Introduces a high risk of circular imports, because `engine` code depends on `config` modules.
4. A client using a slim install of the library SHOULD be able to use plugins.

## Current state

The current plugin system violates REQ 3 (and by extension REQ 4):

```
config.column_types -> data_designer.plugin_manager -> data_designer.plugins.plugin -> data_designer.engine.configurable_task
```
(`->` means "imports", i.e. "depends on")

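One way to detect a chain like this is to import a module in a fresh interpreter and list which modules it pulled in. The helper below is a hypothetical sketch, not part of this commit; it uses only the standard library.

```python
import subprocess
import sys


def modules_loaded_by(module_name: str, forbidden_prefix: str) -> list[str]:
    """Import `module_name` in a fresh interpreter and return the loaded
    modules whose names start with `forbidden_prefix`."""
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({module_name!r})\n"
        f"hits = sorted(m for m in sys.modules if m.startswith({forbidden_prefix!r}))\n"
        "print('\\n'.join(hits))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, check=True
    )
    # An empty list means nothing with the forbidden prefix was imported.
    return result.stdout.split()
```

With the library installed, a regression test for REQ 3 could assert that `modules_loaded_by("data_designer.config.column_types", "data_designer.engine")` is empty.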
## Blessed engine modules?

One idea that was floated is to refactor existing `engine` code so that base classes live in some "blessed" module(s) that we would guarantee never create circular imports with `config` modules, but this seems:
- hard to enforce/guarantee
- potentially restrictive (what if a plugin author wants to use objects from other parts of `engine`?)
- conceptually not ideal (it's simpler to say "`engine` can import/depend on `config`, but not vice versa", full stop, instead of carving out exceptions)
- potentially complicated with respect to however we restructure packages to support slim installs

## Idea proposed in this branch

Make the `Plugin` object "lazy" by defining the config and task types as fully-qualified strings rather than objects.

By using strings in the `Plugin` object fields, **if** the plugin is structured with multiple files (e.g. `config.py` and `task.py`)*,
then the core library's `config` code that uses plugins (to extend discriminated union types) can load the plugin and resolve **only**
the config class type; it would not need to resolve/load/import the plugin's task-related module where `engine` base classes are imported and subclassed.

> *This multi-file setup wouldn't be **required** out of the box; see "Plugin development lifecycle" below.

Example:
```python
# src/my_plugin/config.py
from data_designer.config.column_types import SingleColumnConfig

class MyPluginConfig(SingleColumnConfig):
    foo: str



# src/my_plugin/generator.py
from data_designer.engine.column_generators.generators.base import ColumnGenerator
from my_plugin.config import MyPluginConfig

class MyPluginGenerator(ColumnGenerator[MyPluginConfig]):
    pass  # generator implementation omitted



# src/my_plugin/plugin.py
from data_designer.plugins.plugin import Plugin, PluginType

# Field names match the Plugin model in this commit: both classes are
# referenced by fully-qualified string, not by imported object.
plugin = Plugin(
    config_class_name="my_plugin.config.MyPluginConfig",
    task_class_name="my_plugin.generator.MyPluginGenerator",
    plugin_type=PluginType.COLUMN_GENERATOR,
)
```
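
The laziness itself can be shown without the library. In this sketch, `LazyRef` is a hypothetical stand-in for a string-typed `Plugin` field, and a standard-library class stands in for a plugin class; the point is only that the import runs when the reference is resolved, not when it is created.

```python
import importlib
import sys
from dataclasses import dataclass


@dataclass(frozen=True)
class LazyRef:
    """Hypothetical stand-in for a string-typed Plugin field."""

    path: str  # fully-qualified name, e.g. "package.module.ClassName"

    def resolve(self) -> type:
        module_name, _, object_name = self.path.rpartition(".")
        module = importlib.import_module(module_name)  # the import happens here
        return getattr(module, object_name)


ref = LazyRef("fractions.Fraction")  # constructing the reference imports nothing
cls = ref.resolve()  # only now is the target module imported
imported_after = "fractions" in sys.modules
```

Because the config and task references are resolved separately, code that only needs the config class never triggers the import of the task module.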


### Strings instead of concrete types?

Yeah, a little sad, but it seems a reasonable compromise given the benefits this unlocks.

To guard against mistakes like typos, I suggest we ship a test helper function that we'd encourage plugin authors to use in their unit tests:
```python
# my_plugin/tests/test_plugin.py
from data_designer.plugins.testing import is_valid_plugin
from my_plugin.plugin import plugin


def test_plugin_validity():
    assert is_valid_plugin(plugin)
```
(Similar to `pd.testing.assert_frame_equal`.)

To start, that test helper would ensure two things:
1. The string class names resolve to concrete types that do exist
2. The resolved concrete types are subclasses of the expected base classes

In the future, we could extend the helper to validate other things that are more complex than just Pydantic field type validations.

Remember: we can't implement this validation as a Pydantic validator, because the validator would have to import the task module (and with it `engine` code) whenever a `Plugin` is constructed, which would break the laziness.
We **can** at least validate that the module exists (and this branch does so), but only the test helper
can go further and actually fully resolve the two fields.
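
The two checks listed above can be sketched with the standard library alone. `resolves_to_subclass` is a hypothetical name and is not the helper shipped in this commit; standard-library classes stand in for plugin classes.

```python
import importlib


def resolves_to_subclass(path: str, expected_base: type) -> bool:
    """Return True if `path` names an importable class that subclasses `expected_base`."""
    module_name, _, object_name = path.rpartition(".")
    if not module_name:
        # No dot in the path, so it cannot be a fully-qualified name.
        return False
    try:
        obj = getattr(importlib.import_module(module_name), object_name)
    except (ImportError, AttributeError):
        # Check 1 failed: the module or the attribute does not exist.
        return False
    # Check 2: the resolved object must be a class deriving from the base.
    return isinstance(obj, type) and issubclass(obj, expected_base)
```

For example, `resolves_to_subclass("collections.OrderedDict", dict)` holds, while a typo such as `"collections.OrderedDic"` fails the first check.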


### Plugin development lifecycle

A plugin author _could_ continue defining everything in one Python file and things would still work in the library.
The limitation is that a plugin defined that way would not support slim installs, so clients like NMP would not be able to use it.
**This might be perfectly fine for many plugins**, especially in the early going.
A reasonable "plugin development lifecycle" might be:
1. Develop everything in one file and get it working with the library
2. Refactor the plugin to support slim installs (if ever desired)

Plugin authors would only need to do step 2 if/when we want to make the plugin available in NMP.
That step 2 refactor would involve breaking the plugin implementation up into multiple files _and_ (if necessary) making sure any heavyweight,
task-only third-party dependencies are included under an `engine` extra.
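
If a heavyweight dependency does move under an extra, the task module could import it through a small guard so that a slim install fails with a clear message. This is only a sketch: `require_extra` is a hypothetical helper, and the extra name `engine` is taken from the note above.

```python
import importlib
from types import ModuleType


def require_extra(module_name: str, extra: str = "engine") -> ModuleType:
    """Import an optional dependency, or explain which extra provides it."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            f"{module_name!r} is not installed; install the plugin with its {extra!r} extra."
        ) from exc
```

The guard belongs only in the task-side module, so the plugin's config module stays importable in a slim install.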

src/data_designer/plugins/errors.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@
 from data_designer.errors import DataDesignerError


+class PluginLoadError(DataDesignerError): ...
+
+
 class PluginRegistrationError(DataDesignerError): ...


src/data_designer/plugins/plugin.py

Lines changed: 58 additions & 12 deletions
@@ -1,14 +1,22 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
+import importlib
+import importlib.util
 from enum import Enum
-from typing import Literal, get_origin
+from functools import cached_property
+from typing import TYPE_CHECKING, Annotated, Literal, get_origin

-from pydantic import BaseModel, model_validator
+from pydantic import AfterValidator, BaseModel, model_validator
 from typing_extensions import Self

-from data_designer.config.base import ConfigBase
-from data_designer.engine.configurable_task import ConfigurableTask
+from data_designer.plugins.errors import PluginLoadError
+
+if TYPE_CHECKING:
+    from data_designer.config.base import ConfigBase
+    from data_designer.engine.configurable_task import ConfigurableTask


 class PluginType(str, Enum):
@@ -26,9 +34,30 @@ def display_name(self) -> str:
         return self.value.replace("-", " ")


+def _get_module_and_object_names(fully_qualified_object: str) -> tuple[str, str]:
+    try:
+        module_name, object_name = fully_qualified_object.rsplit(".", 1)
+    except ValueError:
+        # If fully_qualified_object does not have any periods, the rsplit call will return
+        # a list of length 1 and the variable assignment above will raise ValueError
+        raise PluginLoadError("Expected a fully-qualified object name, e.g. 'my_plugin.config.MyConfig'")
+
+    return module_name, object_name
+
+
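+def _is_valid_module(value: str) -> str:
+    module_name, _ = _get_module_and_object_names(value)
+    try:
+        # find_spec returns None for a missing top-level module and raises
+        # ModuleNotFoundError when a parent package cannot be imported.
+        spec = importlib.util.find_spec(module_name)
+    except (ImportError, ValueError) as exc:
+        raise PluginLoadError(f"Could not find module {module_name!r}.") from exc
+    if spec is None:
+        raise PluginLoadError(f"Could not find module {module_name!r}.")
+
+    return value
+
+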
 class Plugin(BaseModel):
-    task_cls: type[ConfigurableTask]
-    config_cls: type[ConfigBase]
+    task_class_name: Annotated[str, AfterValidator(_is_valid_module)]
+    config_class_name: Annotated[str, AfterValidator(_is_valid_module)]
     plugin_type: PluginType
     emoji: str = "🔌"

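The `_is_valid_module` validator in this hunk relies on `importlib.util.find_spec`. As a brief sketch of how that standard-library function reports missing modules (the module names below are made up):

```python
import importlib.util

# A missing top-level module yields None rather than an exception.
missing_top_level = importlib.util.find_spec("no_such_top_level_module_xyz")

# For a dotted name, find_spec imports the parent package first, so a
# missing parent surfaces as ModuleNotFoundError.
try:
    importlib.util.find_spec("no_such_parent_pkg_xyz.child")
    parent_missing_raised = False
except ModuleNotFoundError:
    parent_missing_raised = True

# An existing module yields a ModuleSpec describing it.
found = importlib.util.find_spec("json")
```

A module-existence check therefore has to test the return value for `None` as well as catch import errors.
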
@@ -50,20 +79,37 @@ def discriminator_field(self) -> str:

     @model_validator(mode="after")
     def validate_discriminator_field(self) -> Self:
-        cfg = self.config_cls.__name__
+        _, cfg = _get_module_and_object_names(self.config_class_name)
         field = self.plugin_type.discriminator_field
         if field not in self.config_cls.model_fields:
-            raise ValueError(f"Discriminator field '{field}' not found in config class {cfg}")
+            raise ValueError(f"Discriminator field {field!r} not found in config class {cfg!r}")
         field_info = self.config_cls.model_fields[field]
         if get_origin(field_info.annotation) is not Literal:
-            raise ValueError(f"Field '{field}' of {cfg} must be a Literal type, not {field_info.annotation}.")
+            raise ValueError(f"Field {field!r} of {cfg!r} must be a Literal type, not {field_info.annotation!r}.")
         if not isinstance(field_info.default, str):
-            raise ValueError(f"The default of '{field}' must be a string, not {type(field_info.default)}.")
+            raise ValueError(f"The default of {field!r} must be a string, not {type(field_info.default)!r}.")
         enum_key = field_info.default.replace("-", "_").upper()
         if not enum_key.isidentifier():
             raise ValueError(
-                f"The default value '{field_info.default}' for discriminator field '{field}' "
-                f"cannot be converted to a valid enum key. The converted key '{enum_key}' "
+                f"The default value {field_info.default!r} for discriminator field {field!r} "
+                f"cannot be converted to a valid enum key. The converted key {enum_key!r} "
                 f"must be a valid Python identifier."
             )
         return self
+
+    @cached_property
+    def config_cls(self) -> type[ConfigBase]:
+        return self._load(self.config_class_name)
+
+    @cached_property
+    def task_cls(self) -> type[ConfigurableTask]:
+        return self._load(self.task_class_name)
+
+    @staticmethod
+    def _load(fully_qualified_object: str) -> type:
+        module_name, object_name = _get_module_and_object_names(fully_qualified_object)
+        module = importlib.import_module(module_name)
+        try:
+            return getattr(module, object_name)
+        except AttributeError:
+            raise PluginLoadError(f"Could not find class {object_name!r} in module {module_name!r}")
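
The discriminator rules enforced by `validate_discriminator_field` can be tried in isolation. The helper names `is_literal` and `to_enum_key` below are hypothetical; the logic mirrors the `get_origin(...) is Literal` test and the dash-to-underscore key conversion in the hunk above.

```python
from typing import Literal, get_origin


def is_literal(annotation: object) -> bool:
    """True for annotations such as Literal["x"], False for plain types such as str."""
    return get_origin(annotation) is Literal


def to_enum_key(default: str) -> str:
    """Convert a discriminator default such as "test-generator" to an enum key."""
    enum_key = default.replace("-", "_").upper()
    if not enum_key.isidentifier():
        raise ValueError(f"{default!r} converts to {enum_key!r}, which is not a valid identifier")
    return enum_key
```

A default of `"test-generator"` converts to `TEST_GENERATOR`, whereas `"invalid-key-!@#"` is rejected because the converted key still contains characters that cannot appear in a Python identifier.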
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.plugins.testing.utils import is_valid_plugin
+
+__all__ = [
+    is_valid_plugin.__name__,
+]
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Literal
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.column_configs import SingleColumnConfig
+from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata
+
+MODULE_NAME = __name__
+
+
+class ValidTestConfig(SingleColumnConfig):
+    """Valid config for testing plugin creation."""
+
+    column_type: Literal["test-generator"] = "test-generator"
+    name: str
+
+
+class ValidTestTask(ConfigurableTask[ValidTestConfig]):
+    """Valid task for testing plugin creation."""
+
+    @staticmethod
+    def metadata() -> ConfigurableTaskMetadata:
+        return ConfigurableTaskMetadata(
+            name="test_generator",
+            description="Test generator",
+            required_resources=None,
+        )
+
+
+class ConfigWithoutDiscriminator(ConfigBase):
+    some_field: str
+
+
+class ConfigWithStringField(ConfigBase):
+    column_type: str = "test-generator"
+
+
+class ConfigWithNonStringDefault(ConfigBase):
+    column_type: Literal["test-generator"] = 123  # type: ignore
+
+
+class ConfigWithInvalidKey(ConfigBase):
+    column_type: Literal["invalid-key-!@#"] = "invalid-key-!@#"
+
+
+class StubPluginConfigA(SingleColumnConfig):
+    column_type: Literal["test-plugin-a"] = "test-plugin-a"
+
+
+class StubPluginConfigB(SingleColumnConfig):
+    column_type: Literal["test-plugin-b"] = "test-plugin-b"
+
+
+class StubPluginTaskA(ConfigurableTask[StubPluginConfigA]):
+    @staticmethod
+    def metadata() -> ConfigurableTaskMetadata:
+        return ConfigurableTaskMetadata(
+            name="test_plugin_a",
+            description="Test plugin A",
+            required_resources=None,
+        )
+
+
+class StubPluginTaskB(ConfigurableTask[StubPluginConfigB]):
+    @staticmethod
+    def metadata() -> ConfigurableTaskMetadata:
+        return ConfigurableTaskMetadata(
+            name="test_plugin_b",
+            description="Test plugin B",
+            required_resources=None,
+        )
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.config.base import ConfigBase
+from data_designer.engine.configurable_task import ConfigurableTask
+from data_designer.plugins.plugin import Plugin
+
+
+def is_valid_plugin(plugin: Plugin) -> bool:
+    if not issubclass(plugin.config_cls, ConfigBase):
+        return False
+    if not issubclass(plugin.task_cls, ConfigurableTask):
+        return False
+
+    return True
