Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/dco-assistant.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
steps:
- name: "DCO Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the Contributor Agreement including DCO and I hereby sign the Contributor Agreement and DCO') || github.event_name == 'pull_request_target'
uses: contributor-assistant/github-action@v2.6.1
uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PERSONAL_ACCESS_TOKEN: ${{ secrets.DCO_ASSISTANT_TOKEN }}
Expand Down
4 changes: 2 additions & 2 deletions src/data_designer/config/analysis/dataset_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pydantic import BaseModel, Field, field_validator

from ..columns import DataDesignerColumnType
from ..columns import DataDesignerColumnType, get_column_display_order
from ..utils.constants import EPSILON
from ..utils.numerical_helpers import prepare_number_for_reporting
from .column_profilers import ColumnProfilerResultsT
Expand All @@ -32,7 +32,7 @@ def percent_complete(self) -> float:

@cached_property
def column_types(self) -> list[str]:
display_order = DataDesignerColumnType.get_display_order()
display_order = get_column_display_order()
return sorted(
list(set([c.column_type for c in self.column_statistics])),
key=lambda x: display_order.index(x) if x in display_order else len(display_order),
Expand Down
4 changes: 2 additions & 2 deletions src/data_designer/config/analysis/utils/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rich.text import Text

from ...analysis.column_statistics import CategoricalHistogramData
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
from ...utils.visualization import (
ColorPalette,
convert_to_row_element,
Expand Down Expand Up @@ -44,7 +44,7 @@ class ReportSection(str, Enum):
DEFAULT_INCLUDE_SECTIONS = [
ReportSection.OVERVIEW,
ReportSection.COLUMN_PROFILERS,
] + DataDesignerColumnType.get_display_order()
] + get_column_display_order()


def generate_analysis_report(
Expand Down
157 changes: 69 additions & 88 deletions src/data_designer/config/columns.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from enum import Enum
from abc import ABC
from typing import Literal, Optional, Type, Union

from pydantic import BaseModel, Field, model_validator
Expand All @@ -15,56 +14,14 @@
from .utils.code_lang import CodeLang
from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX
from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
from .utils.type_helpers import SAMPLER_PARAMS, resolve_string_enum
from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum
from .validator_params import ValidatorParamsT, ValidatorType


class DataDesignerColumnType(str, Enum):
SAMPLER = "sampler"
LLM_TEXT = "llm-text"
LLM_CODE = "llm-code"
LLM_STRUCTURED = "llm-structured"
LLM_JUDGE = "llm-judge"
EXPRESSION = "expression"
VALIDATION = "validation"
SEED_DATASET = "seed-dataset"

@staticmethod
def get_display_order() -> list[Self]:
return [
DataDesignerColumnType.SEED_DATASET,
DataDesignerColumnType.SAMPLER,
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
]

@property
def has_prompt_templates(self) -> bool:
return self in [self.LLM_TEXT, self.LLM_CODE, self.LLM_STRUCTURED, self.LLM_JUDGE]

@property
def is_dag_column_type(self) -> bool:
return self in [
self.EXPRESSION,
self.LLM_CODE,
self.LLM_JUDGE,
self.LLM_STRUCTURED,
self.LLM_TEXT,
self.VALIDATION,
]


class SingleColumnConfig(ConfigBase, ABC):
name: str
drop: bool = False

@property
@abstractmethod
def column_type(self) -> DataDesignerColumnType: ...
column_type: str

@property
def required_columns(self) -> list[str]:
Expand All @@ -80,21 +37,15 @@ class SamplerColumnConfig(SingleColumnConfig):
params: SamplerParamsT
conditional_params: dict[str, SamplerParamsT] = {}
convert_to: Optional[str] = None

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.SAMPLER
column_type: Literal["sampler"] = "sampler"


class LLMTextColumnConfig(SingleColumnConfig):
prompt: str
model_alias: str
system_prompt: Optional[str] = None
multi_modal_context: Optional[list[ImageContext]] = None

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.LLM_TEXT
column_type: Literal["llm-text"] = "llm-text"

@property
def required_columns(self) -> list[str]:
Expand All @@ -117,18 +68,12 @@ def assert_prompt_valid_jinja(self) -> Self:

class LLMCodeColumnConfig(LLMTextColumnConfig):
code_lang: CodeLang

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.LLM_CODE
column_type: Literal["llm-code"] = "llm-code"


class LLMStructuredColumnConfig(LLMTextColumnConfig):
output_format: Union[dict, Type[BaseModel]]

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.LLM_STRUCTURED
column_type: Literal["llm-structured"] = "llm-structured"

@model_validator(mode="after")
def validate_output_format(self) -> Self:
Expand All @@ -145,20 +90,14 @@ class Score(ConfigBase):

class LLMJudgeColumnConfig(LLMTextColumnConfig):
scores: list[Score] = Field(..., min_length=1)

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.LLM_JUDGE
column_type: Literal["llm-judge"] = "llm-judge"


class ExpressionColumnConfig(SingleColumnConfig):
name: str
expr: str
dtype: Literal["int", "float", "str", "bool"] = "str"

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.EXPRESSION
column_type: Literal["expression"] = "expression"

@property
def required_columns(self) -> list[str]:
Expand All @@ -168,7 +107,9 @@ def required_columns(self) -> list[str]:
def assert_expression_valid_jinja(self) -> Self:
if not self.expr.strip():
raise InvalidConfigError(
f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') or remove this column if not needed."
f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. "
f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') "
"or remove this column if not needed."
)
assert_valid_jinja2_template(self.expr)
return self
Expand All @@ -179,20 +120,34 @@ class ValidationColumnConfig(SingleColumnConfig):
validator_type: ValidatorType
validator_params: ValidatorParamsT
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")

@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.VALIDATION
column_type: Literal["validation"] = "validation"

@property
def required_columns(self) -> list[str]:
return self.target_columns


class SeedDatasetColumnConfig(SingleColumnConfig):
@property
def column_type(self) -> DataDesignerColumnType:
return DataDesignerColumnType.SEED_DATASET
column_type: Literal["seed-dataset"] = "seed-dataset"


ColumnConfigT: TypeAlias = Union[
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
LLMTextColumnConfig,
SamplerColumnConfig,
SeedDatasetColumnConfig,
ValidationColumnConfig,
]


DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
enum_name="DataDesignerColumnType",
type_union=ColumnConfigT,
discriminator_field_name="column_type",
)
Comment on lines +146 to +150
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's where we dynamically create this StrEnum.



COLUMN_TYPE_EMOJI_MAP = {
Expand All @@ -208,16 +163,28 @@ def column_type(self) -> DataDesignerColumnType:
}


ColumnConfigT: TypeAlias = Union[
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
LLMTextColumnConfig,
SamplerColumnConfig,
SeedDatasetColumnConfig,
ValidationColumnConfig,
]
def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
"""Return True if the column type is used in the workflow execution DAG."""
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
return column_type in {
DataDesignerColumnType.EXPRESSION,
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.VALIDATION,
Comment on lines +170 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does our type checker have an issue with this, since DataDesignerColumnType is dynamically resolved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not running ty yet, but my IDE seems happy with all of it. If you can pull the branch to check your IDE settings, that might be helpful.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's change these to sets

}


def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
    """Report whether a column type belongs to the LLM-generated family.

    Args:
        column_type: Either a ``DataDesignerColumnType`` member or a string.
            The value is passed through ``resolve_string_enum`` before the
            membership check (presumably normalizing strings into enum
            members -- verify against that helper).

    Returns:
        True when the resolved type is one of LLM_TEXT, LLM_CODE,
        LLM_STRUCTURED or LLM_JUDGE; False otherwise.
    """
    resolved_type = resolve_string_enum(column_type, DataDesignerColumnType)
    llm_generated_types = {
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_JUDGE,
    }
    return resolved_type in llm_generated_types


def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
Expand Down Expand Up @@ -251,6 +218,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover


def get_column_display_order() -> list[DataDesignerColumnType]:
    """Return the column types in their preferred display order.

    A new list is built on every call, so callers may concatenate or mutate
    the result without affecting later calls.
    """
    col = DataDesignerColumnType
    display_order = [
        col.SEED_DATASET,
        col.SAMPLER,
        col.LLM_TEXT,
        col.LLM_CODE,
        col.LLM_STRUCTURED,
        col.LLM_JUDGE,
        col.VALIDATION,
        col.EXPRESSION,
    ]
    return display_order


def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict:
if "sampler_type" not in kwargs:
raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.")
Expand Down
10 changes: 8 additions & 2 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

from .analysis.column_profilers import ColumnProfilerConfigT
from .base import ExportableConfigBase
from .columns import ColumnConfigT, DataDesignerColumnType, SeedDatasetColumnConfig, get_column_config_from_kwargs
from .columns import (
ColumnConfigT,
DataDesignerColumnType,
SeedDatasetColumnConfig,
column_type_is_llm_generated,
get_column_config_from_kwargs,
)
from .data_designer_config import DataDesignerConfig
from .dataset_builders import BuildStage
from .datastore import DatastoreSettings, fetch_seed_dataset_column_names
Expand Down Expand Up @@ -449,7 +455,7 @@ def get_llm_gen_columns(self) -> list[ColumnConfigT]:
Returns:
A list of column configurations that use LLM generation.
"""
return [c for c in self._column_configs.values() if c.column_type.has_prompt_templates]
return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)]

def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
"""Get all column configurations of the specified type.
Expand Down
4 changes: 2 additions & 2 deletions src/data_designer/config/data_designer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import Optional
from typing import Annotated, Optional

from pydantic import Field

Expand Down Expand Up @@ -32,7 +32,7 @@ class DataDesignerConfig(ExportableConfigBase):
profilers: Optional list of column profilers for analyzing generated data characteristics.
"""

columns: list[ColumnConfigT] = Field(min_length=1)
columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we can do this

model_configs: Optional[list[ModelConfig]] = None
seed_config: Optional[SeedConfig] = None
constraints: Optional[list[ColumnConstraintT]] = None
Expand Down
6 changes: 6 additions & 0 deletions src/data_designer/config/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ class UserJinjaTemplateSyntaxError(DataDesignerError): ...
class InvalidEnumValueError(DataDesignerError): ...


class InvalidTypeUnionError(DataDesignerError): ...


class InvalidDiscriminatorFieldError(DataDesignerError): ...


class DatasetSampleDisplayError(DataDesignerError): ...
Loading
Loading