Skip to content

Commit e836832

Browse files
committed
comments
1 parent 8e6bc66 commit e836832

File tree

5 files changed

+24
-43
lines changed

5 files changed

+24
-43
lines changed

src/data_designer/config/processors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import json
45
from abc import ABC
56
from enum import Enum
67
from typing import Any, Literal
@@ -9,6 +10,7 @@
910

1011
from data_designer.config.base import ConfigBase
1112
from data_designer.config.dataset_builders import BuildStage
13+
from data_designer.config.errors import InvalidConfigError
1214

1315
SUPPORTED_STAGES = [BuildStage.POST_BATCH]
1416

@@ -72,3 +74,12 @@ class SchemaTransformProcessorConfig(ProcessorConfig):
7274
""",
7375
)
7476
processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = ProcessorType.SCHEMA_TRANSFORM
77+
78+
@field_validator("template")
@classmethod
def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
    """Reject templates that cannot be serialized to JSON.

    Args:
        v: The mapping supplied for the ``template`` field.

    Returns:
        The template, unchanged, when ``json.dumps`` accepts it.

    Raises:
        InvalidConfigError: If ``json.dumps`` raises ``TypeError`` for the
            template.
    """
    try:
        json.dumps(v)
    except TypeError as e:
        # json.dumps signals both unsupported values ("... is not JSON
        # serializable") and unsupported dict keys ("keys must be str, int,
        # ...") with TypeError. Filtering on the message text let the second
        # kind fall through and be returned as if it were valid, so every
        # TypeError is treated as a non-serializable template.
        raise InvalidConfigError("Template must be JSON serializable") from e
    return v

src/data_designer/config/utils/validation.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import json
76
from enum import Enum
87
from string import Formatter
98
from typing import Optional
@@ -37,7 +36,6 @@ class ViolationType(str, Enum):
3736
INVALID_COLUMN = "invalid_column"
3837
INVALID_MODEL_CONFIG = "invalid_model_config"
3938
INVALID_REFERENCE = "invalid_reference"
40-
INVALID_TEMPLATE = "invalid_template"
4139
PROMPT_WITHOUT_REFERENCES = "prompt_without_references"
4240

4341

@@ -303,20 +301,6 @@ def validate_schema_transform_processor(
303301
all_column_names = {c.name for c in columns}
304302
for processor_config in processor_configs:
305303
if processor_config.processor_type == ProcessorType.SCHEMA_TRANSFORM:
306-
try:
307-
json.dumps(processor_config.template)
308-
except TypeError as e:
309-
if "not JSON serializable" in str(e):
310-
violations.append(
311-
Violation(
312-
column=None,
313-
type=ViolationType.INVALID_TEMPLATE,
314-
message=f"Ancillary dataset processor {processor_config.name} template is not a valid JSON object.",
315-
level=ViolationLevel.ERROR,
316-
)
317-
)
318-
continue
319-
320304
for col, template in processor_config.template.items():
321305
template_keywords = get_prompt_template_keywords(template)
322306
invalid_keywords = set(template_keywords) - all_column_names

src/data_designer/engine/dataset_builders/artifact_storage.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,6 @@ def write_batch_to_parquet_file(
178178
batch_stage: BatchStage,
179179
subfolder: str | None = None,
180180
) -> Path:
181-
if subfolder is None:
182-
subfolder = ""
183181
file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
184182
self.write_parquet_file(file_path.name, dataframe, batch_stage, subfolder=subfolder)
185183
return file_path
@@ -191,8 +189,7 @@ def write_parquet_file(
191189
batch_stage: BatchStage,
192190
subfolder: str | None = None,
193191
) -> Path:
194-
if subfolder is None:
195-
subfolder = ""
192+
subfolder = subfolder or ""
196193
self.mkdir_if_needed(self._get_stage_path(batch_stage) / subfolder)
197194
file_path = self._get_stage_path(batch_stage) / subfolder / parquet_file_name
198195
dataframe.to_parquet(file_path, index=False)

tests/config/test_processors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import ValidationError
66

77
from data_designer.config.dataset_builders import BuildStage
8+
from data_designer.config.errors import InvalidConfigError
89
from data_designer.config.processors import (
910
DropColumnsProcessorConfig,
1011
ProcessorConfig,
@@ -79,10 +80,16 @@ def test_schema_transform_processor_config_validation():
7980
with pytest.raises(ValidationError, match="Field required"):
8081
SchemaTransformProcessorConfig(name="schema_transform_processor", build_stage=BuildStage.POST_BATCH)
8182

83+
# Test invalid template raises error
84+
with pytest.raises(InvalidConfigError, match="Template must be JSON serializable"):
85+
SchemaTransformProcessorConfig(
86+
name="schema_transform_processor", build_stage=BuildStage.POST_BATCH, template={"text": {1, 2, 3}}
87+
)
8288

83-
def test_output_format_processor_config_serialization():
89+
90+
def test_schema_transform_processor_config_serialization():
8491
config = SchemaTransformProcessorConfig(
85-
name="output_format_processor",
92+
name="schema_transform_processor",
8693
build_stage=BuildStage.POST_BATCH,
8794
template={"text": "{{ col1 }}"},
8895
)

tests/config/utils/test_validation.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,6 @@
111111
template={"text": "{{ invalid_reference }}"},
112112
build_stage=BuildStage.POST_BATCH,
113113
),
114-
SchemaTransformProcessorConfig(
115-
name="schema_transform_processor_invalid_template",
116-
template={"text": {1, 2, 3}},
117-
build_stage=BuildStage.POST_BATCH,
118-
),
119114
]
120115
ALLOWED_REFERENCE = [c.name for c in COLUMNS]
121116

@@ -180,17 +175,11 @@ def test_validate_data_designer_config(
180175
type=ViolationType.INVALID_REFERENCE,
181176
message="Ancillary dataset processor attempts to reference columns 'invalid_reference' in the template for 'text', but the columns are not defined in the dataset.",
182177
level=ViolationLevel.ERROR,
183-
),
184-
Violation(
185-
column="text",
186-
type=ViolationType.INVALID_TEMPLATE,
187-
message="Ancillary dataset processor template is not a valid JSON object.",
188-
level=ViolationLevel.ERROR,
189-
),
178+
)
190179
]
191180

192181
violations = validate_data_designer_config(COLUMNS, PROCESSOR_CONFIGS, ALLOWED_REFERENCE)
193-
assert len(violations) == 7
182+
assert len(violations) == 6
194183
mock_validate_columns_not_all_dropped.assert_called_once()
195184
mock_validate_expression_references.assert_called_once()
196185
mock_validate_code_validation.assert_called_once()
@@ -283,21 +272,14 @@ def test_validate_expression_references():
283272

284273
def test_validate_schema_transform_processor():
285274
violations = validate_schema_transform_processor(COLUMNS, PROCESSOR_CONFIGS)
286-
assert len(violations) == 2
275+
assert len(violations) == 1
287276
assert violations[0].type == ViolationType.INVALID_REFERENCE
288277
assert violations[0].column is None
289278
assert (
290279
violations[0].message
291280
== "Ancillary dataset processor attempts to reference columns 'invalid_reference' in the template for 'text', but the columns are not defined in the dataset."
292281
)
293282
assert violations[0].level == ViolationLevel.ERROR
294-
assert violations[1].type == ViolationType.INVALID_TEMPLATE
295-
assert violations[1].column is None
296-
assert (
297-
violations[1].message
298-
== "Ancillary dataset processor schema_transform_processor_invalid_template template is not a valid JSON object."
299-
)
300-
assert violations[1].level == ViolationLevel.ERROR
301283

302284

303285
@patch("data_designer.config.utils.validation.Console.print")

0 commit comments

Comments
 (0)