Skip to content

Commit 202eba6

Browse files
authored
make sampler type a discriminated union; add injection validator (#71)
1 parent 2e9e4ff commit 202eba6

File tree

4 files changed

+177
-11
lines changed

4 files changed

+177
-11
lines changed

src/data_designer/config/column_configs.py

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from abc import ABC
5-
from typing import Literal, Optional, Type, Union
5+
from typing import Annotated, Literal, Optional, Type, Union
66

7-
from pydantic import BaseModel, Field, model_validator
7+
from pydantic import BaseModel, Discriminator, Field, model_validator
88
from typing_extensions import Self
99

1010
from .base import ConfigBase
@@ -89,11 +89,36 @@ class SamplerColumnConfig(SingleColumnConfig):
8989
"""
9090

9191
sampler_type: SamplerType
92-
params: SamplerParamsT
93-
conditional_params: dict[str, SamplerParamsT] = {}
92+
params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
93+
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
9494
convert_to: Optional[str] = None
9595
column_type: Literal["sampler"] = "sampler"
9696

97+
@model_validator(mode="before")
@classmethod
def inject_sampler_type_into_params(cls, data: dict) -> dict:
    """Copy the column's ``sampler_type`` into its params dicts before validation.

    ``params`` and every value of ``conditional_params`` are annotated with
    ``Discriminator("sampler_type")``, so a raw params dict needs a
    ``sampler_type`` key for the union to resolve. This hook lets callers omit
    that key and have it filled in from the outer ``sampler_type`` field.

    Args:
        data: Raw model input. Anything that is not a dict is returned untouched.

    Returns:
        ``data`` unchanged when it is not a dict or carries no ``sampler_type``;
        otherwise a shallow copy in which each params dict lacking a
        ``sampler_type`` key is replaced by a new dict that has one. The
        caller's dicts are never modified in place.
    """
    if not isinstance(data, dict):
        return data

    sampler_type = data.get("sampler_type")
    if not sampler_type:
        # Without an outer sampler_type there is nothing meaningful to inject;
        # injecting None would only obscure the real "missing field" error.
        return data

    def _with_sampler_type(raw_params):
        # Only plain dicts are tagged; an explicit sampler_type is preserved.
        if isinstance(raw_params, dict) and "sampler_type" not in raw_params:
            return {"sampler_type": sampler_type, **raw_params}
        return raw_params

    # Work on copies so the caller's input (and its nested dicts) stay intact.
    updated = dict(data)
    if "params" in data:
        updated["params"] = _with_sampler_type(data["params"])

    conditional_params = data.get("conditional_params")
    if isinstance(conditional_params, dict):
        updated["conditional_params"] = {
            condition: _with_sampler_type(cond_params)
            for condition, cond_params in conditional_params.items()
        }

    return updated
121+
97122

98123
class LLMTextColumnConfig(SingleColumnConfig):
99124
"""Configuration for text generation columns using Large Language Models.

src/data_designer/config/sampler_params.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ class CategorySamplerParams(ConfigBase):
6666
"Larger values will be sampled with higher probability."
6767
),
6868
)
69+
sampler_type: Literal[SamplerType.CATEGORY] = SamplerType.CATEGORY
6970

7071
@model_validator(mode="after")
7172
def _normalize_weights_if_needed(self) -> Self:
@@ -106,6 +107,7 @@ class DatetimeSamplerParams(ConfigBase):
106107
default="D",
107108
description="Sampling units, e.g. the smallest possible time interval between samples.",
108109
)
110+
sampler_type: Literal[SamplerType.DATETIME] = SamplerType.DATETIME
109111

110112
@field_validator("start", "end")
111113
@classmethod
@@ -136,6 +138,7 @@ class SubcategorySamplerParams(ConfigBase):
136138
...,
137139
description="Mapping from each value of parent category to a list of subcategory values.",
138140
)
141+
sampler_type: Literal[SamplerType.SUBCATEGORY] = SamplerType.SUBCATEGORY
139142

140143

141144
class TimeDeltaSamplerParams(ConfigBase):
@@ -187,6 +190,7 @@ class TimeDeltaSamplerParams(ConfigBase):
187190
default="D",
188191
description="Sampling units, e.g. the smallest possible time interval between samples.",
189192
)
193+
sampler_type: Literal[SamplerType.TIMEDELTA] = SamplerType.TIMEDELTA
190194

191195
@model_validator(mode="after")
192196
def _validate_min_less_than_max(self) -> Self:
@@ -219,6 +223,7 @@ class UUIDSamplerParams(ConfigBase):
219223
default=False,
220224
description="If true, all letters in the UUID will be capitalized.",
221225
)
226+
sampler_type: Literal[SamplerType.UUID] = SamplerType.UUID
222227

223228
@property
224229
def last_index(self) -> int:
@@ -257,6 +262,7 @@ class ScipySamplerParams(ConfigBase):
257262
decimal_places: Optional[int] = Field(
258263
default=None, description="Number of decimal places to round the sampled values to."
259264
)
265+
sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
260266

261267

262268
class BinomialSamplerParams(ConfigBase):
@@ -273,6 +279,7 @@ class BinomialSamplerParams(ConfigBase):
273279

274280
n: int = Field(..., description="Number of trials.")
275281
p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)
282+
sampler_type: Literal[SamplerType.BINOMIAL] = SamplerType.BINOMIAL
276283

277284

278285
class BernoulliSamplerParams(ConfigBase):
@@ -288,6 +295,7 @@ class BernoulliSamplerParams(ConfigBase):
288295
"""
289296

290297
p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0)
298+
sampler_type: Literal[SamplerType.BERNOULLI] = SamplerType.BERNOULLI
291299

292300

293301
class BernoulliMixtureSamplerParams(ConfigBase):
@@ -327,6 +335,7 @@ class BernoulliMixtureSamplerParams(ConfigBase):
327335
...,
328336
description="Parameters of the scipy.stats distribution given in `dist_name`.",
329337
)
338+
sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = SamplerType.BERNOULLI_MIXTURE
330339

331340

332341
class GaussianSamplerParams(ConfigBase):
@@ -350,6 +359,7 @@ class GaussianSamplerParams(ConfigBase):
350359
decimal_places: Optional[int] = Field(
351360
default=None, description="Number of decimal places to round the sampled values to."
352361
)
362+
sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
353363

354364

355365
class PoissonSamplerParams(ConfigBase):
@@ -369,6 +379,7 @@ class PoissonSamplerParams(ConfigBase):
369379
"""
370380

371381
mean: float = Field(..., description="Mean number of events in a fixed interval.")
382+
sampler_type: Literal[SamplerType.POISSON] = SamplerType.POISSON
372383

373384

374385
class UniformSamplerParams(ConfigBase):
@@ -390,6 +401,7 @@ class UniformSamplerParams(ConfigBase):
390401
decimal_places: Optional[int] = Field(
391402
default=None, description="Number of decimal places to round the sampled values to."
392403
)
404+
sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
393405

394406

395407
#########################################
@@ -470,11 +482,12 @@ class PersonSamplerParams(ConfigBase):
470482
default=False,
471483
description="If True, then append synthetic persona columns to each generated person.",
472484
)
485+
sampler_type: Literal[SamplerType.PERSON] = SamplerType.PERSON
473486

474487
@property
475488
def generator_kwargs(self) -> list[str]:
476489
"""Keyword arguments to pass to the person generator."""
477-
return [f for f in list(PersonSamplerParams.model_fields) if f != "locale"]
490+
return [f for f in list(PersonSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
478491

479492
@property
480493
def people_gen_key(self) -> str:
@@ -533,11 +546,12 @@ class PersonFromFakerSamplerParams(ConfigBase):
533546
min_length=2,
534547
max_length=2,
535548
)
549+
sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = SamplerType.PERSON_FROM_FAKER
536550

537551
@property
538552
def generator_kwargs(self) -> list[str]:
539553
"""Keyword arguments to pass to the person generator."""
540-
return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f != "locale"]
554+
return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
541555

542556
@property
543557
def people_gen_key(self) -> str:

tests/config/test_columns.py

Lines changed: 120 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,15 @@
2323
get_column_display_order,
2424
)
2525
from data_designer.config.errors import InvalidConfigError
26-
from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams
26+
from data_designer.config.sampler_params import (
27+
CategorySamplerParams,
28+
GaussianSamplerParams,
29+
PersonFromFakerSamplerParams,
30+
PersonSamplerParams,
31+
SamplerType,
32+
UniformSamplerParams,
33+
UUIDSamplerParams,
34+
)
2735
from data_designer.config.utils.code_lang import CodeLang
2836
from data_designer.config.utils.errors import UserJinjaTemplateSyntaxError
2937
from data_designer.config.validator_params import CodeValidatorParams
@@ -324,3 +332,114 @@ def test_get_column_config_from_kwargs():
324332
),
325333
SeedDatasetColumnConfig,
326334
)
335+
336+
337+
def test_sampler_column_config_discriminated_union_with_dict_params():
    """A params dict without sampler_type resolves to the matching params model."""
    raw_params = {"low": 0.0, "high": 1.0, "decimal_places": 2}
    column = SamplerColumnConfig(name="test_uniform", sampler_type=SamplerType.UNIFORM, params=raw_params)

    # Outer fields are kept as given.
    assert column.name == "test_uniform"
    assert column.sampler_type == SamplerType.UNIFORM

    # The params dict was resolved to the uniform model and tagged with its type.
    resolved = column.params
    assert isinstance(resolved, UniformSamplerParams)
    assert resolved.sampler_type == SamplerType.UNIFORM
    assert resolved.low == 0.0
    assert resolved.high == 1.0
    assert resolved.decimal_places == 2
351+
352+
353+
def test_sampler_column_config_discriminated_union_with_explicit_sampler_type():
    """A sampler_type already present inside the params dict is kept as-is."""
    raw_params = {
        "sampler_type": "category",
        "values": ["A", "B", "C"],
        "weights": [0.5, 0.3, 0.2],
    }
    column = SamplerColumnConfig(name="test_category", sampler_type=SamplerType.CATEGORY, params=raw_params)

    assert column.name == "test_category"
    assert column.sampler_type == SamplerType.CATEGORY

    resolved = column.params
    assert isinstance(resolved, CategorySamplerParams)
    assert resolved.sampler_type == SamplerType.CATEGORY
    assert resolved.values == ["A", "B", "C"]
365+
366+
367+
def test_sampler_column_config_discriminated_union_serialization():
    """The params model survives a model_dump() / re-construction round trip."""
    original = SamplerColumnConfig(
        name="test_person",
        sampler_type=SamplerType.PERSON,
        params={"locale": "en_US", "sex": "Female", "age_range": [25, 45]},
    )

    # Dumping must emit the discriminator inside the nested params.
    dumped = original.model_dump()
    dumped_params = dumped["params"]
    assert "sampler_type" in dumped_params
    assert dumped_params["sampler_type"] == "person"

    # Rebuilding from the dump must resolve to the same params model.
    restored_params = SamplerColumnConfig(**dumped).params
    assert isinstance(restored_params, PersonSamplerParams)
    assert restored_params.locale == "en_US"
    assert restored_params.sex == "Female"
    assert restored_params.age_range == [25, 45]
386+
387+
388+
def test_sampler_column_config_discriminated_union_person_vs_person_from_faker():
    """The person and person_from_faker samplers resolve to distinct params models.

    Both samplers are given dicts with the same keys, so only the injected
    sampler_type can tell the two params models apart.
    """
    # Person sampler (managed datasets)
    person_config = SamplerColumnConfig(
        name="test_person",
        sampler_type=SamplerType.PERSON,
        params={"locale": "en_US", "sex": "Male", "age_range": [30, 50]},
    )
    assert isinstance(person_config.params, PersonSamplerParams)
    assert person_config.params.sampler_type == SamplerType.PERSON
    assert person_config.params.locale == "en_US"

    # Person-from-Faker sampler (Faker-based)
    person_faker_config = SamplerColumnConfig(
        name="test_person_faker",
        sampler_type=SamplerType.PERSON_FROM_FAKER,
        params={"locale": "en_GB", "sex": "Female", "age_range": [20, 40]},
    )
    assert isinstance(person_faker_config.params, PersonFromFakerSamplerParams)
    assert person_faker_config.params.sampler_type == SamplerType.PERSON_FROM_FAKER
    assert person_faker_config.params.locale == "en_GB"

    # Classes are compared by identity, not equality; the isinstance checks
    # above already pin each side to its expected model.
    assert type(person_config.params) is not type(person_faker_config.params)
414+
415+
416+
def test_sampler_column_config_discriminated_union_with_conditional_params():
    """Each conditional_params entry is tagged with sampler_type and resolved too."""
    condition = "age > 21"
    column = SamplerColumnConfig(
        name="test_gaussian",
        sampler_type=SamplerType.GAUSSIAN,
        params={"mean": 0.0, "stddev": 1.0},
        conditional_params={condition: {"mean": 5.0, "stddev": 2.0}},
    )

    # Default params
    base = column.params
    assert isinstance(base, GaussianSamplerParams)
    assert base.mean == 0.0
    assert base.stddev == 1.0

    # Conditional override
    assert condition in column.conditional_params
    override = column.conditional_params[condition]
    assert isinstance(override, GaussianSamplerParams)
    assert override.sampler_type == SamplerType.GAUSSIAN
    assert override.mean == 5.0
    assert override.stddev == 2.0
436+
437+
438+
def test_sampler_column_config_discriminated_union_wrong_params_type():
    """Params shaped for a different sampler than sampler_type fail validation."""
    # Category-style params handed to a uniform sampler
    mismatched_kwargs = {
        "name": "test_wrong_params",
        "sampler_type": SamplerType.UNIFORM,
        "params": {"values": ["A", "B"]},
    }
    with pytest.raises(ValidationError):
        SamplerColumnConfig(**mismatched_kwargs)

tests/engine/analysis/column_profilers/test_base.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,9 @@
1616

1717
def test_column_config_with_dataframe_valid_column_config_with_dataframe():
1818
df = pd.DataFrame({"test_column": [1, 2, 3]})
19-
column_config = SamplerColumnConfig(name="test_column", sampler_type=SamplerType.CATEGORY, params={})
19+
column_config = SamplerColumnConfig(
20+
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
21+
)
2022

2123
config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)
2224

@@ -27,15 +29,19 @@ def test_column_config_with_dataframe_valid_column_config_with_dataframe():
2729

2830
def test_column_config_with_dataframe_column_not_found_validation_error():
2931
df = pd.DataFrame({"other_column": [1, 2, 3]})
30-
column_config = SamplerColumnConfig(name="test_column", sampler_type=SamplerType.CATEGORY, params={})
32+
column_config = SamplerColumnConfig(
33+
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
34+
)
3135

3236
with pytest.raises(ValidationError, match="Column 'test_column' not found in DataFrame"):
3337
ColumnConfigWithDataFrame(column_config=column_config, df=df)
3438

3539

3640
def test_column_config_with_dataframe_pyarrow_backend_conversion():
3741
df = pd.DataFrame({"test_column": [1, 2, 3]})
38-
column_config = SamplerColumnConfig(name="test_column", sampler_type=SamplerType.CATEGORY, params={})
42+
column_config = SamplerColumnConfig(
43+
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
44+
)
3945

4046
config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)
4147

@@ -44,7 +50,9 @@ def test_column_config_with_dataframe_pyarrow_backend_conversion():
4450

4551
def test_column_config_with_dataframe_as_tuple_method():
4652
df = pd.DataFrame({"test_column": [1, 2, 3]})
47-
column_config = SamplerColumnConfig(name="test_column", sampler_type=SamplerType.CATEGORY, params={})
53+
column_config = SamplerColumnConfig(
54+
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
55+
)
4856

4957
config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)
5058
column_config_result, df_result = config_with_df.as_tuple()

0 commit comments

Comments
 (0)