Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8d99604
separate column configs and types
johnnygreco Nov 7, 2025
8604dc1
create plugin object
johnnygreco Nov 7, 2025
3223d12
create plugin manager
johnnygreco Nov 8, 2025
ef5ece6
fix config integration
johnnygreco Nov 8, 2025
e648aa9
make base task registry raise on collision false by default
johnnygreco Nov 8, 2025
15ed0f8
update registry test after raise on collision default update
johnnygreco Nov 8, 2025
feb9817
make analysis work using general stats calculation
johnnygreco Nov 9, 2025
a00e858
default -> builtin
johnnygreco Nov 10, 2025
74d3308
use entry point approach instead
johnnygreco Nov 10, 2025
1ec27fd
rewire using plugin helpers
johnnygreco Nov 10, 2025
e648b0f
add env var to disable plugins
johnnygreco Nov 10, 2025
f7e708a
fix tests
johnnygreco Nov 10, 2025
f3e392e
update plugin manager tests
johnnygreco Nov 10, 2025
6a9b011
add tests for plugin helpers
johnnygreco Nov 10, 2025
be273b0
update license headers
johnnygreco Nov 10, 2025
808fd0c
add emoji
johnnygreco Nov 10, 2025
c987509
not using the pm in the builder code
johnnygreco Nov 10, 2025
dcc6ee8
Update src/data_designer/plugins/manager.py
johnnygreco Nov 10, 2025
43c8f5d
Update src/data_designer/plugins/manager.py
johnnygreco Nov 10, 2025
34084ad
Update src/data_designer/plugins/manager.py
johnnygreco Nov 10, 2025
ba06651
merge plugin registry into the manager
johnnygreco Nov 10, 2025
31a1d9b
small pr feedback
johnnygreco Nov 11, 2025
cd4183b
client side plugin manager
johnnygreco Nov 11, 2025
4bccea7
builtin -> default; move adding plugins to registry
johnnygreco Nov 11, 2025
c752bd4
update method names to better match what they do
johnnygreco Nov 11, 2025
cea9209
use register verb for consistency with other registries
johnnygreco Nov 11, 2025
2b44686
thread safety updates; make discover private
johnnygreco Nov 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dev = [
"pytest>=8.3.3",
"pytest-asyncio>=0.24.0",
"pytest-cov>=7.0.0",
"pytest-env>=1.2.0",
"pytest-httpx>=0.35.0",
]
docs = [
Expand Down Expand Up @@ -89,6 +90,9 @@ version-file = "src/data_designer/_version.py"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "session"
env = [
"DISABLE_DATA_DESIGNER_PLUGINS=true",
]

[tool.uv]
package = true
Expand Down
57 changes: 41 additions & 16 deletions src/data_designer/config/analysis/column_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Union

from pandas import Series
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator
from typing_extensions import Self, TypeAlias

from ..columns import DataDesignerColumnType
from ...plugin_manager import PluginManager
from ..column_types import DataDesignerColumnType
from ..sampler_params import SamplerType
from ..utils.constants import EPSILON
from ..utils.numerical_helpers import is_float, is_int, prepare_number_for_reporting
Expand Down Expand Up @@ -238,17 +239,41 @@ def from_series(cls, series: Series) -> Self:
)


ColumnStatisticsT: TypeAlias = Annotated[
Union[
GeneralColumnStatistics,
LLMTextColumnStatistics,
LLMCodeColumnStatistics,
LLMStructuredColumnStatistics,
LLMJudgedColumnStatistics,
SamplerColumnStatistics,
SeedDatasetColumnStatistics,
ValidationColumnStatistics,
ExpressionColumnStatistics,
],
Field(discriminator="column_type"),
ColumnStatisticsT: TypeAlias = Union[
GeneralColumnStatistics,
LLMTextColumnStatistics,
LLMCodeColumnStatistics,
LLMStructuredColumnStatistics,
LLMJudgedColumnStatistics,
SamplerColumnStatistics,
SeedDatasetColumnStatistics,
ValidationColumnStatistics,
ExpressionColumnStatistics,
]


# Statistics model registered for each DataDesignerColumnType member listed here.
DEFAULT_COLUMN_STATISTICS_MAP = {
    DataDesignerColumnType.EXPRESSION: ExpressionColumnStatistics,
    DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatistics,
    DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatistics,
    DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatistics,
    DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatistics,
    DataDesignerColumnType.SAMPLER: SamplerColumnStatistics,
    DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatistics,
    DataDesignerColumnType.VALIDATION: ValidationColumnStatistics,
}

# Register one statistics class per column-generator plugin, extending both the
# ColumnStatisticsT union and DEFAULT_COLUMN_STATISTICS_MAP.
# NOTE(review): this loop appears to run at module import time — confirm.
for plugin in PluginManager().get_column_generator_plugins():
    # Dynamically create a statistics class for this plugin using Pydantic's create_model
    plugin_stats_cls_name = f"{plugin.config_type_as_class_name}ColumnStatistics"

    # Subclass GeneralColumnStatistics and pin `column_type` to the plugin's name:
    # the field type is Literal[plugin.name] and its default is plugin.name.
    plugin_stats_cls = create_model(
        plugin_stats_cls_name,
        __base__=GeneralColumnStatistics,
        column_type=(Literal[plugin.name], plugin.name),
    )

    # Add the plugin statistics class to the union
    # NOTE(review): `|` on a typing.Union alias needs Python 3.10+ — confirm the
    # project's minimum supported version.
    ColumnStatisticsT |= plugin_stats_cls
    # The key is looked up by value, so DataDesignerColumnType must already have a
    # member whose value equals plugin.name.
    DEFAULT_COLUMN_STATISTICS_MAP[DataDesignerColumnType(plugin.name)] = plugin_stats_cls
6 changes: 3 additions & 3 deletions src/data_designer/config/analysis/dataset_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from functools import cached_property
from pathlib import Path
from typing import Optional, Union
from typing import Annotated, Optional, Union

from pydantic import BaseModel, Field, field_validator

from ..columns import DataDesignerColumnType, get_column_display_order
from ..column_types import DataDesignerColumnType, get_column_display_order
from ..utils.constants import EPSILON
from ..utils.numerical_helpers import prepare_number_for_reporting
from .column_profilers import ColumnProfilerResultsT
Expand All @@ -18,7 +18,7 @@
class DatasetProfilerResults(BaseModel):
num_records: int
target_num_records: int
column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1)
column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1)
side_effect_column_names: Optional[list[str]] = None
column_profiles: Optional[list[ColumnProfilerResultsT]] = None

Expand Down
3 changes: 1 addition & 2 deletions src/data_designer/config/analysis/utils/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rich.text import Text

from ...analysis.column_statistics import CategoricalHistogramData
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
from ...column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
from ...utils.visualization import (
ColorPalette,
convert_to_row_element,
Expand All @@ -27,7 +27,6 @@
if TYPE_CHECKING:
from ...analysis.dataset_profiler import DatasetProfilerResults


HEADER_STYLE = "dim"
RULE_STYLE = f"bold {ColorPalette.NVIDIA_GREEN.value}"
ACCENT_STYLE = f"bold {ColorPalette.BLUE.value}"
Expand Down
130 changes: 130 additions & 0 deletions src/data_designer/config/column_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC
from typing import Literal, Optional, Type, Union

from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from .base import ConfigBase
from .errors import InvalidConfigError
from .models import ImageContext
from .sampler_params import SamplerParamsT, SamplerType
from .utils.code_lang import CodeLang
from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX
from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
from .validator_params import ValidatorParamsT, ValidatorType


class SingleColumnConfig(ConfigBase, ABC):
    # Abstract base for the column config models in this file. Subclasses narrow
    # `column_type` to a Literal tag and may override the two properties below.
    name: str
    drop: bool = False
    column_type: str

    @property
    def required_columns(self) -> list[str]:
        """Return the names of columns this config requires; empty by default."""
        return []

    @property
    def side_effect_columns(self) -> list[str]:
        """Return the names of side-effect columns for this config; empty by default."""
        return []


class SamplerColumnConfig(SingleColumnConfig):
    # Column config tagged with column_type "sampler".
    # `conditional_params` maps string keys to further SamplerParamsT values and
    # defaults to an empty dict; the meaning of the keys is not visible here —
    # TODO confirm against the sampler implementation.
    sampler_type: SamplerType
    params: SamplerParamsT
    conditional_params: dict[str, SamplerParamsT] = {}
    convert_to: Optional[str] = None
    column_type: Literal["sampler"] = "sampler"


class LLMTextColumnConfig(SingleColumnConfig):
    # Column config tagged with column_type "llm-text". `prompt` (and the optional
    # `system_prompt`) are validated as Jinja2 templates after model construction.
    prompt: str
    model_alias: str
    system_prompt: Optional[str] = None
    multi_modal_context: Optional[list[ImageContext]] = None
    column_type: Literal["llm-text"] = "llm-text"

    @property
    def required_columns(self) -> list[str]:
        """Return the de-duplicated template keywords of the prompt and system prompt.

        The result passes through a ``set``, so its order is not guaranteed.
        """
        templates = [self.prompt]
        if self.system_prompt:
            templates.append(self.system_prompt)
        keywords = [
            keyword
            for template in templates
            for keyword in get_prompt_template_keywords(template)
        ]
        return list(set(keywords))

    @property
    def side_effect_columns(self) -> list[str]:
        """Return the single side-effect column: this column's name plus the trace postfix."""
        trace_column = f"{self.name}{REASONING_TRACE_COLUMN_POSTFIX}"
        return [trace_column]

    @model_validator(mode="after")
    def assert_prompt_valid_jinja(self) -> Self:
        """Check that ``prompt`` and, when set, ``system_prompt`` are valid Jinja2 templates."""
        assert_valid_jinja2_template(self.prompt)
        if not self.system_prompt:
            return self
        assert_valid_jinja2_template(self.system_prompt)
        return self


class LLMCodeColumnConfig(LLMTextColumnConfig):
    # LLM text column config tagged with column_type "llm-code"; adds the
    # required `code_lang` field.
    code_lang: CodeLang
    column_type: Literal["llm-code"] = "llm-code"


class LLMStructuredColumnConfig(LLMTextColumnConfig):
    # LLM text column config tagged with column_type "llm-structured".
    # `output_format` is either a dict or a pydantic model class; a model class
    # is replaced by its JSON schema after validation.
    output_format: Union[dict, Type[BaseModel]]
    column_type: Literal["llm-structured"] = "llm-structured"

    @model_validator(mode="after")
    def validate_output_format(self) -> Self:
        """Convert a ``BaseModel`` subclass in ``output_format`` to its JSON schema dict."""
        declared_format = self.output_format
        if isinstance(declared_format, dict):
            # Already a plain dict; nothing to convert.
            return self
        if issubclass(declared_format, BaseModel):
            self.output_format = declared_format.model_json_schema()
        return self


class Score(ConfigBase):
    # One scoring dimension used by LLMJudgeColumnConfig.scores. All three fields
    # are required; `options` maps each allowed score (int or str) to its description.
    name: str = Field(..., description="A clear name for this score.")
    description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
    options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")


class LLMJudgeColumnConfig(LLMTextColumnConfig):
    # LLM text column config tagged with column_type "llm-judge"; requires at
    # least one Score (min_length=1).
    scores: list[Score] = Field(..., min_length=1)
    column_type: Literal["llm-judge"] = "llm-judge"


class ExpressionColumnConfig(SingleColumnConfig):
    # Column config tagged with column_type "expression". `expr` is validated as a
    # non-blank Jinja2 template; `dtype` is one of "int", "float", "str", "bool"
    # and defaults to "str".
    name: str
    expr: str
    dtype: Literal["int", "float", "str", "bool"] = "str"
    column_type: Literal["expression"] = "expression"

    @property
    def required_columns(self) -> list[str]:
        """Return the template keywords found in ``expr``."""
        return list(get_prompt_template_keywords(self.expr))

    @model_validator(mode="after")
    def assert_expression_valid_jinja(self) -> Self:
        """Reject blank expressions, then check that ``expr`` is a valid Jinja2 template.

        Raises:
            InvalidConfigError: If ``expr`` is empty or whitespace-only.
        """
        if not self.expr.strip():
            raise InvalidConfigError(
                f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. "
                # Deliberately a plain string, not an f-string: inside an f-string
                # "{{" collapses to "{", so the examples would be shown to the user
                # as '{ column_name }' instead of real Jinja2 syntax.
                "Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') "
                "or remove this column if not needed."
            )
        assert_valid_jinja2_template(self.expr)
        return self


class ValidationColumnConfig(SingleColumnConfig):
    # Column config tagged with column_type "validation". `batch_size` defaults to
    # 10 and must be at least 1 (ge=1).
    target_columns: list[str]
    validator_type: ValidatorType
    validator_params: ValidatorParamsT
    batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
    column_type: Literal["validation"] = "validation"

    @property
    def required_columns(self) -> list[str]:
        """Return ``target_columns`` (the same list object, not a copy)."""
        return self.target_columns


class SeedDatasetColumnConfig(SingleColumnConfig):
    # Column config tagged with column_type "seed-dataset"; adds no fields beyond
    # those inherited from SingleColumnConfig.
    column_type: Literal["seed-dataset"] = "seed-dataset"
Loading