Skip to content

Commit 4212986

Browse files
committed
scaffold
1 parent 7268290 commit 4212986

File tree

7 files changed

+139
-2
lines changed

7 files changed

+139
-2
lines changed

examples/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
**/artifacts

examples/custom_column.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pandas as pd
2+
3+
from data_designer.essentials import (
4+
CustomColumnConfig,
5+
DataDesigner,
6+
DataDesignerConfigBuilder,
7+
InferenceParameters,
8+
LoggingConfig,
9+
ModelConfig,
10+
configure_logging,
11+
)
12+
13+
# Apply the debug logging configuration before anything else runs.
configure_logging(LoggingConfig.debug())

# Initialize NDD and add columns
MODEL_ALIAS = "nano"
SYSTEM_PROMPT = "/no_think"

# Inference settings attached to the single model used by this example.
inference_parameters = InferenceParameters(
    temperature=0.5,
    top_p=1.0,
    max_tokens=1024,
    max_parallel_requests=4,
)

model_configs = [
    ModelConfig(
        alias=MODEL_ALIAS,
        model="nvidia/nvidia-nemotron-nano-9b-v2",
        inference_parameters=inference_parameters,
        provider="nvidia",
    ),
]

builder = DataDesignerConfigBuilder(model_configs=model_configs)

# Category sampler column: each record gets one of the listed topics.
builder.add_column(
    name="topic",
    column_type="sampler",
    sampler_type="category",
    params={"values": ["healthcare", "finance", "technology"]},
)

# LLM text column; the prompt template refers to the "topic" column above.
builder.add_column(
    name="text",
    column_type="llm-text",
    model_alias=MODEL_ALIAS,
    prompt="Write me a paragraph about {{ topic }}.",
    system_prompt=SYSTEM_PROMPT,
)
55+
56+
def generator_function(df: pd.DataFrame) -> pd.DataFrame:
    """Add a ``length_frac`` column: character count of ``text`` divided by 1000.

    Args:
        df: Input records; must contain a ``text`` column.

    Returns:
        A new DataFrame with every input column plus ``length_frac``. The
        input frame is left untouched so callers can safely reuse it.
    """
    # Vectorized length; missing text yields NaN instead of raising TypeError.
    return df.assign(length_frac=df["text"].str.len() / 1000)
59+
60+
# Register the custom column; its values come from generator_function.
length_frac_column = CustomColumnConfig(
    name="length_frac",
    generator_function=generator_function,
)
builder.add_column(length_frac_column)

# Generate dataset
dd = DataDesigner(artifact_path="./artifacts")

# Preview a small batch and show one sample record before the full run.
dd_preview = dd.preview(builder, num_records=10)
dd_preview.display_sample_record()

dd.create(builder, num_records=20)

src/data_designer/config/columns.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
from abc import ABC, abstractmethod
55
from enum import Enum
6-
from typing import Literal, Optional, Type, Union
6+
from typing import Callable, Literal, Optional, Type, Union
77

8-
from pydantic import BaseModel, Field, model_validator
8+
import pandas as pd
9+
from pydantic import BaseModel, Field, field_serializer, model_validator
910
from typing_extensions import Self, TypeAlias
1011

1112
from .base import ConfigBase
@@ -28,6 +29,7 @@ class DataDesignerColumnType(str, Enum):
2829
EXPRESSION = "expression"
2930
VALIDATION = "validation"
3031
SEED_DATASET = "seed-dataset"
32+
CUSTOM = "custom"
3133

3234
@staticmethod
3335
def get_display_order() -> list[Self]:
@@ -40,6 +42,7 @@ def get_display_order() -> list[Self]:
4042
DataDesignerColumnType.LLM_JUDGE,
4143
DataDesignerColumnType.VALIDATION,
4244
DataDesignerColumnType.EXPRESSION,
45+
DataDesignerColumnType.CUSTOM,
4346
]
4447

4548
@property
@@ -55,6 +58,7 @@ def is_dag_column_type(self) -> bool:
5558
self.LLM_STRUCTURED,
5659
self.LLM_TEXT,
5760
self.VALIDATION,
61+
self.CUSTOM,
5862
]
5963

6064

@@ -195,6 +199,18 @@ def column_type(self) -> DataDesignerColumnType:
195199
return DataDesignerColumnType.SEED_DATASET
196200

197201

202+
class CustomColumnConfig(SingleColumnConfig):
    """Configuration for a column produced by a user-supplied callable.

    The callable is annotated as taking a ``pandas.DataFrame`` and returning a
    ``pandas.DataFrame``.

    NOTE(review): serialization emits only the callable's name, so a serialized
    config cannot be turned back into a working ``CustomColumnConfig`` without
    re-supplying the function — confirm this is acceptable for saved configs.
    """

    # User-provided function that receives the current records and returns them
    # with the custom column populated.
    generator_function: Callable[[pd.DataFrame], pd.DataFrame]

    @property
    def column_type(self) -> DataDesignerColumnType:
        """Return the column type tag for custom columns."""
        return DataDesignerColumnType.CUSTOM

    @field_serializer("generator_function")
    def serialize_generator_function(self, v: Callable[[pd.DataFrame], pd.DataFrame]) -> str:
        """Serialize the callable as a readable name.

        Plain functions serialize to ``__name__`` as before. Callables without
        ``__name__`` (``functools.partial`` objects, callable instances) fall
        back to the wrapped function's name or the callable's type name rather
        than raising ``AttributeError``.
        """
        name = getattr(v, "__name__", None)
        if name is None:
            # functools.partial exposes the wrapped callable as ``.func``.
            wrapped = getattr(v, "func", None)
            name = getattr(wrapped, "__name__", None) or type(v).__name__
        return name
212+
213+
198214
COLUMN_TYPE_EMOJI_MAP = {
199215
"general": "⚛️", # possible analysis column type
200216
DataDesignerColumnType.EXPRESSION: "🧩",
@@ -205,6 +221,7 @@ def column_type(self) -> DataDesignerColumnType:
205221
DataDesignerColumnType.SEED_DATASET: "🌱",
206222
DataDesignerColumnType.SAMPLER: "🎲",
207223
DataDesignerColumnType.VALIDATION: "🔍",
224+
DataDesignerColumnType.CUSTOM: "🛠️",
208225
}
209226

210227

@@ -217,6 +234,7 @@ def column_type(self) -> DataDesignerColumnType:
217234
SamplerColumnConfig,
218235
SeedDatasetColumnConfig,
219236
ValidationColumnConfig,
237+
CustomColumnConfig,
220238
]
221239

222240

@@ -248,6 +266,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
248266
return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
249267
elif column_type == DataDesignerColumnType.SEED_DATASET:
250268
return SeedDatasetColumnConfig(name=name, **kwargs)
269+
elif column_type == DataDesignerColumnType.CUSTOM:
270+
return CustomColumnConfig(name=name, **kwargs)
251271
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
252272

253273

src/data_designer/config/utils/visualization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def display_sample_record(
160160
+ config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
161161
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
162162
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
163+
+ config_builder.get_columns_of_type(DataDesignerColumnType.CUSTOM)
163164
)
164165
if len(non_code_columns) > 0:
165166
table = Table(title="Generated Columns", **table_kws)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import logging
5+
6+
import pandas as pd
7+
8+
from data_designer.config.columns import CustomColumnConfig
9+
from data_designer.engine.column_generators.generators.base import (
10+
ColumnGenerator,
11+
GenerationStrategy,
12+
GeneratorMetadata,
13+
)
14+
from data_designer.engine.errors import DataDesignerRuntimeError
15+
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]):
    """Column generator that delegates to the config's ``generator_function``."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator: full-column strategy, no required resources."""
        return GeneratorMetadata(
            name="custom",
            description="Generate a custom column.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=None,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Run the user-supplied generator function on ``data``.

        Args:
            data: Records generated so far.

        Returns:
            The DataFrame returned by the generator function.

        Raises:
            DataDesignerRuntimeError: If the generator function raises, or if
                it returns something other than a ``pandas.DataFrame``. The
                original exception is chained as ``__cause__``.
        """
        generator_function = self.config.generator_function
        # partial objects and callable instances have no __name__; fall back to repr.
        function_name = getattr(generator_function, "__name__", repr(generator_function))

        logger.info(f"🛠️ Generating custom column {self.config.name!r} with {len(data)} records")
        logger.info(f" |-- generator function: {function_name}")

        try:
            result = generator_function(data)
        except Exception as e:
            # User code can raise anything; wrap it but keep the original traceback.
            raise DataDesignerRuntimeError(f"Error generating custom column {self.config.name!r}: {e}") from e

        if not isinstance(result, pd.DataFrame):
            # Fail here with a clear message instead of an obscure error downstream.
            raise DataDesignerRuntimeError(
                f"Error generating custom column {self.config.name!r}: generator function "
                f"{function_name} returned {type(result).__name__}, expected a pandas DataFrame."
            )

        return result

src/data_designer/engine/column_generators/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from data_designer.config.base import ConfigBase
55
from data_designer.config.columns import (
6+
CustomColumnConfig,
67
DataDesignerColumnType,
78
ExpressionColumnConfig,
89
LLMCodeColumnConfig,
@@ -12,6 +13,7 @@
1213
ValidationColumnConfig,
1314
)
1415
from data_designer.engine.column_generators.generators.base import ColumnGenerator
16+
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
1517
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
1618
from data_designer.engine.column_generators.generators.llm_generators import (
1719
LLMCodeCellGenerator,
@@ -39,6 +41,7 @@ def create_default_column_generator_registry() -> ColumnGeneratorRegistry:
3941
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig, False)
4042
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig, False)
4143
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig, False)
44+
registry.register(DataDesignerColumnType.CUSTOM, CustomColumnGenerator, CustomColumnConfig, False)
4245
registry.register(
4346
DataDesignerColumnType.SEED_DATASET,
4447
SeedDatasetColumnGenerator,

src/data_designer/essentials/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ..config.analysis.column_profilers import JudgeScoreProfilerConfig
55
from ..config.columns import (
6+
CustomColumnConfig,
67
DataDesignerColumnType,
78
ExpressionColumnConfig,
89
LLMCodeColumnConfig,

0 commit comments

Comments
 (0)