33
44from abc import ABC , abstractmethod
55from enum import Enum
6- from typing import Literal , Optional , Type , Union
6+ from typing import Callable , Literal , Optional , Type , Union
77
8- from pydantic import BaseModel , Field , model_validator
8+ import pandas as pd
9+ from pydantic import BaseModel , Field , field_serializer , model_validator
910from typing_extensions import Self , TypeAlias
1011
1112from .base import ConfigBase
@@ -28,6 +29,7 @@ class DataDesignerColumnType(str, Enum):
2829 EXPRESSION = "expression"
2930 VALIDATION = "validation"
3031 SEED_DATASET = "seed-dataset"
32+ CUSTOM = "custom"
3133
3234 @staticmethod
3335 def get_display_order () -> list [Self ]:
@@ -40,6 +42,7 @@ def get_display_order() -> list[Self]:
4042 DataDesignerColumnType .LLM_JUDGE ,
4143 DataDesignerColumnType .VALIDATION ,
4244 DataDesignerColumnType .EXPRESSION ,
45+ DataDesignerColumnType .CUSTOM ,
4346 ]
4447
4548 @property
@@ -55,6 +58,7 @@ def is_dag_column_type(self) -> bool:
5558 self .LLM_STRUCTURED ,
5659 self .LLM_TEXT ,
5760 self .VALIDATION ,
61+ self .CUSTOM ,
5862 ]
5963
6064
@@ -195,6 +199,18 @@ def column_type(self) -> DataDesignerColumnType:
195199 return DataDesignerColumnType .SEED_DATASET
196200
197201
202+ class CustomColumnConfig (SingleColumnConfig ):
203+ generator_function : Callable [[pd .DataFrame ], pd .DataFrame ]
204+
205+ @property
206+ def column_type (self ) -> DataDesignerColumnType :
207+ return DataDesignerColumnType .CUSTOM
208+
209+ @field_serializer ("generator_function" )
210+ def serialize_generator_function (self , v : Callable [[pd .DataFrame ], pd .DataFrame ]) -> str :
211+ return v .__name__
212+
213+
198214COLUMN_TYPE_EMOJI_MAP = {
199215 "general" : "⚛️" , # possible analysis column type
200216 DataDesignerColumnType .EXPRESSION : "🧩" ,
@@ -205,6 +221,7 @@ def column_type(self) -> DataDesignerColumnType:
205221 DataDesignerColumnType .SEED_DATASET : "🌱" ,
206222 DataDesignerColumnType .SAMPLER : "🎲" ,
207223 DataDesignerColumnType .VALIDATION : "🔍" ,
224+ DataDesignerColumnType .CUSTOM : "🛠️" ,
208225}
209226
210227
@@ -217,6 +234,7 @@ def column_type(self) -> DataDesignerColumnType:
217234 SamplerColumnConfig ,
218235 SeedDatasetColumnConfig ,
219236 ValidationColumnConfig ,
237+ CustomColumnConfig ,
220238]
221239
222240
@@ -248,6 +266,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
248266 return SamplerColumnConfig (name = name , ** _resolve_sampler_kwargs (name , kwargs ))
249267 elif column_type == DataDesignerColumnType .SEED_DATASET :
250268 return SeedDatasetColumnConfig (name = name , ** kwargs )
269+ elif column_type == DataDesignerColumnType .CUSTOM :
270+ return CustomColumnConfig (name = name , ** kwargs )
251271 raise InvalidColumnTypeError (f"🛑 { column_type } is not a valid column type." ) # pragma: no cover
252272
253273
0 commit comments