Skip to content

Commit fdbc012

Browse files
feat: 🔌 Initial plugin system implementation (#23)
* separate column configs and types * create plugin object * create plugin manager * fix config integration * make base task registry raise on collision false by default * update registry test after raise on collision default update * make analysis work using general stats calculation * default -> builtin * use entry point approach instead * rewire using plugin helpers * add env var to disable plugins * fix tests * update plugin manager tests * add tests for plugin helpers * update license headers * add emoji * not using the pm in the builder code * Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati <[email protected]> * Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati <[email protected]> * Update src/data_designer/plugins/manager.py Co-authored-by: Nabin Mulepati <[email protected]> * merge plugin registry into the manager * small pr feedback * client side plugin manager * builtin -> default; move adding plugins to registry * update method names to better match what they do * use register verb for consistency with other registries * thread safety updates; make discover private --------- Co-authored-by: Nabin Mulepati <[email protected]>
1 parent aa22900 commit fdbc012

File tree

67 files changed

+1418
-425
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+1418
-425
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ dev = [
5959
"pytest>=8.3.3",
6060
"pytest-asyncio>=0.24.0",
6161
"pytest-cov>=7.0.0",
62+
"pytest-env>=1.2.0",
6263
"pytest-httpx>=0.35.0",
6364
]
6465
docs = [
@@ -89,6 +90,9 @@ version-file = "src/data_designer/_version.py"
8990
[tool.pytest.ini_options]
9091
testpaths = ["tests"]
9192
asyncio_default_fixture_loop_scope = "session"
93+
env = [
94+
"DISABLE_DATA_DESIGNER_PLUGINS=true",
95+
]
9296

9397
[tool.uv]
9498
package = true

src/data_designer/config/analysis/column_statistics.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
from abc import ABC, abstractmethod
77
from enum import Enum
8-
from typing import Annotated, Any, Literal, Optional, Union
8+
from typing import Any, Literal, Optional, Union
99

1010
from pandas import Series
11-
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
11+
from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator
1212
from typing_extensions import Self, TypeAlias
1313

14-
from ..columns import DataDesignerColumnType
14+
from ...plugin_manager import PluginManager
15+
from ..column_types import DataDesignerColumnType
1516
from ..sampler_params import SamplerType
1617
from ..utils.constants import EPSILON
1718
from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting
@@ -238,17 +239,41 @@ def from_series(cls, series: Series) -> Self:
238239
)
239240

240241

241-
ColumnStatisticsT: TypeAlias = Annotated[
242-
Union[
243-
GeneralColumnStatistics,
244-
LLMTextColumnStatistics,
245-
LLMCodeColumnStatistics,
246-
LLMStructuredColumnStatistics,
247-
LLMJudgedColumnStatistics,
248-
SamplerColumnStatistics,
249-
SeedDatasetColumnStatistics,
250-
ValidationColumnStatistics,
251-
ExpressionColumnStatistics,
252-
],
253-
Field(discriminator="column_type"),
242+
ColumnStatisticsT: TypeAlias = Union[
243+
GeneralColumnStatistics,
244+
LLMTextColumnStatistics,
245+
LLMCodeColumnStatistics,
246+
LLMStructuredColumnStatistics,
247+
LLMJudgedColumnStatistics,
248+
SamplerColumnStatistics,
249+
SeedDatasetColumnStatistics,
250+
ValidationColumnStatistics,
251+
ExpressionColumnStatistics,
254252
]
253+
254+
255+
DEFAULT_COLUMN_STATISTICS_MAP = {
256+
DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics,
257+
DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics,
258+
DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics,
259+
DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics,
260+
DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics,
261+
DataDesignerColumnType.SAMPLER: SamplerColumnStatistics,
262+
DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics,
263+
DataDesignerColumnType.VALIDATION: ValidationColumnStatistics,
264+
}
265+
266+
for plugin in PluginManager().get_column_generator_plugins():
267+
# Dynamically create a statistics class for this plugin using Pydantic's create_model
268+
plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics"
269+
270+
# Create the class with proper Pydantic field
271+
plugin_stats_cls = create_model(
272+
plugin_stats_cls_name,
273+
__base__=GeneralColumnStatistics,
274+
column_type=(Literal[plugin.name], plugin.name),
275+
)
276+
277+
# Add the plugin statistics class to the union
278+
ColumnStatisticsT |= plugin_stats_cls
279+
DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls

src/data_designer/config/analysis/dataset_profiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
from functools import cached_property
55
from pathlib import Path
6-
from typing import Optional, Union
6+
from typing import Annotated, Optional, Union
77

88
from pydantic import BaseModel, Field, field_validator
99

10-
from ..columns import DataDesignerColumnType, get_column_display_order
10+
from ..column_types import DataDesignerColumnType, get_column_display_order
1111
from ..utils.constants import EPSILON
1212
from ..utils.numerical_helpers import prepare_number_for_reporting
1313
from .column_profilers import ColumnProfilerResultsT
@@ -18,7 +18,7 @@
1818
class DatasetProfilerResults(BaseModel):
1919
num_records: int
2020
target_num_records: int
21-
column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1)
21+
column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1)
2222
side_effect_column_names: Optional[list[str]] = None
2323
column_profiles: Optional[list[ColumnProfilerResultsT]] = None
2424

src/data_designer/config/analysis/utils/reporting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from rich.text import Text
1616

1717
from ...analysis.column_statistics import CategoricalHistogramData
18-
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
18+
from ...column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
1919
from ...utils.visualization import (
2020
ColorPalette,
2121
convert_to_row_element,
@@ -27,7 +27,6 @@
2727
if TYPE_CHECKING:
2828
from ...analysis.dataset_profiler import DatasetProfilerResults
2929

30-
3130
HEADER_STYLE = "dim"
3231
RULE_STYLE = f"bold {ColorPalette.NVIDIA_GREEN.value}"
3332
ACCENT_STYLE = f"bold {ColorPalette.BLUE.value}"
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from abc import ABC
5+
from typing import Literal, Optional, Type, Union
6+
7+
from pydantic import BaseModel, Field, model_validator
8+
from typing_extensions import Self
9+
10+
from .base import ConfigBase
11+
from .errors import InvalidConfigError
12+
from .models import ImageContext
13+
from .sampler_params import SamplerParamsT, SamplerType
14+
from .utils.code_lang import CodeLang
15+
from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX
16+
from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
17+
from .validator_params import ValidatorParamsT, ValidatorType
18+
19+
20+
class SingleColumnConfig(ConfigBase, ABC):
21+
name: str
22+
drop: bool = False
23+
column_type: str
24+
25+
@property
26+
def required_columns(self) -> list[str]:
27+
return []
28+
29+
@property
30+
def side_effect_columns(self) -> list[str]:
31+
return []
32+
33+
34+
class SamplerColumnConfig(SingleColumnConfig):
35+
sampler_type: SamplerType
36+
params: SamplerParamsT
37+
conditional_params: dict[str, SamplerParamsT] = {}
38+
convert_to: Optional[str] = None
39+
column_type: Literal["sampler"] = "sampler"
40+
41+
42+
class LLMTextColumnConfig(SingleColumnConfig):
43+
prompt: str
44+
model_alias: str
45+
system_prompt: Optional[str] = None
46+
multi_modal_context: Optional[list[ImageContext]] = None
47+
column_type: Literal["llm-text"] = "llm-text"
48+
49+
@property
50+
def required_columns(self) -> list[str]:
51+
required_cols = list(get_prompt_template_keywords(self.prompt))
52+
if self.system_prompt:
53+
required_cols.extend(list(get_prompt_template_keywords(self.system_prompt)))
54+
return list(set(required_cols))
55+
56+
@property
57+
def side_effect_columns(self) -> list[str]:
58+
return [f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"]
59+
60+
@model_validator(mode="after")
61+
def assert_prompt_valid_jinja(self) -> Self:
62+
assert_valid_jinja2_template(self.prompt)
63+
if self.system_prompt:
64+
assert_valid_jinja2_template(self.system_prompt)
65+
return self
66+
67+
68+
class LLMCodeColumnConfig(LLMTextColumnConfig):
69+
code_lang: CodeLang
70+
column_type: Literal["llm-code"] = "llm-code"
71+
72+
73+
class LLMStructuredColumnConfig(LLMTextColumnConfig):
74+
output_format: Union[dict, Type[BaseModel]]
75+
column_type: Literal["llm-structured"] = "llm-structured"
76+
77+
@model_validator(mode="after")
78+
def validate_output_format(self) -> Self:
79+
if not isinstance(self.output_format, dict) and issubclass(self.output_format, BaseModel):
80+
self.output_format = self.output_format.model_json_schema()
81+
return self
82+
83+
84+
class Score(ConfigBase):
85+
name: str = Field(..., description="A clear name for this score.")
86+
description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
87+
options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")
88+
89+
90+
class LLMJudgeColumnConfig(LLMTextColumnConfig):
91+
scores: list[Score] = Field(..., min_length=1)
92+
column_type: Literal["llm-judge"] = "llm-judge"
93+
94+
95+
class ExpressionColumnConfig(SingleColumnConfig):
96+
name: str
97+
expr: str
98+
dtype: Literal["int", "float", "str", "bool"] = "str"
99+
column_type: Literal["expression"] = "expression"
100+
101+
@property
102+
def required_columns(self) -> list[str]:
103+
return list(get_prompt_template_keywords(self.expr))
104+
105+
@model_validator(mode="after")
106+
def assert_expression_valid_jinja(self) -> Self:
107+
if not self.expr.strip():
108+
raise InvalidConfigError(
109+
f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. "
110+
f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') "
111+
"or remove this column if not needed."
112+
)
113+
assert_valid_jinja2_template(self.expr)
114+
return self
115+
116+
117+
class ValidationColumnConfig(SingleColumnConfig):
118+
target_columns: list[str]
119+
validator_type: ValidatorType
120+
validator_params: ValidatorParamsT
121+
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
122+
column_type: Literal["validation"] = "validation"
123+
124+
@property
125+
def required_columns(self) -> list[str]:
126+
return self.target_columns
127+
128+
129+
class SeedDatasetColumnConfig(SingleColumnConfig):
130+
column_type: Literal["seed-dataset"] = "seed-dataset"

0 commit comments

Comments
 (0)