diff --git a/src/data_designer/engine/processing/gsonschema/validators.py b/src/data_designer/engine/processing/gsonschema/validators.py index d16ad69b..52ca337d 100644 --- a/src/data_designer/engine/processing/gsonschema/validators.py +++ b/src/data_designer/engine/processing/gsonschema/validators.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import logging +import re from copy import deepcopy +from decimal import ROUND_HALF_UP, Decimal from typing import Any, overload from jsonschema import Draft202012Validator, ValidationError, validators @@ -70,6 +72,57 @@ def extend_jsonschema_validator_with_pruning(validator): return validators.extend(validator, {"additionalProperties": prune_additional_properties}) +def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]: + """Check if schema is a Decimal anyOf and extract decimal places. + + Returns (is_decimal, decimal_places) where decimal_places is None if no constraint. + """ + any_of = schema.get("anyOf") + if not isinstance(any_of, list): + return False, None + + has_number = any(item.get("type") == "number" for item in any_of) + if not has_number: + return False, None + + for item in any_of: + if item.get("type") == "string" and "pattern" in item: + match = re.search(r"\\d\{0,(\d+)\}", item["pattern"]) + if match: + return True, int(match.group(1)) + return True, None # Decimal without precision constraint + return False, None + + +def normalize_decimal_fields(obj: DataObjectT, schema: JSONSchemaT) -> DataObjectT: + """Normalize Decimal-like anyOf fields to floats with proper precision.""" + if not isinstance(obj, dict): + return obj + + defs = schema.get("$defs", {}) + obj_schema = defs.get(schema.get("$ref", "")[len("#/$defs/") :], schema) + props = obj_schema.get("properties", {}) + + for key, value in obj.items(): + field_schema = props.get(key, {}) + if "$ref" in field_schema: + field_schema = defs.get(field_schema["$ref"][len("#/$defs/") :], {}) + + if isinstance(value, dict): + 
obj[key] = normalize_decimal_fields(value, schema) + elif isinstance(value, list): + obj[key] = [normalize_decimal_fields(v, schema) if isinstance(v, dict) else v for v in value] + elif isinstance(value, (int, float, str)) and not isinstance(value, bool): + is_decimal, decimal_places = _get_decimal_info_from_anyof(field_schema) + if is_decimal: + d = Decimal(str(value)) + if decimal_places is not None: + d = d.quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP) + obj[key] = float(d) + + return obj + + ## We don't expect the outer data type (e.g. dict, list, or const) to be ## modified by the pruning action. @overload @@ -140,4 +193,6 @@ def validate( except ValidationError as exc: raise JSONSchemaValidationError(str(exc)) from exc + final_object = normalize_decimal_fields(final_object, schema) + return final_object diff --git a/tests/engine/processing/gsonschema/test_validators.py b/tests/engine/processing/gsonschema/test_validators.py index f2ca70f0..b746d3b3 100644 --- a/tests/engine/processing/gsonschema/test_validators.py +++ b/tests/engine/processing/gsonschema/test_validators.py @@ -196,3 +196,34 @@ def test_invalid_data_type(): data = {"num": "not a number", "extra": "should be removed"} with pytest.raises(JSONSchemaValidationError): validate(data, schema, pruning=True, no_extra_properties=True) + + +def test_normalize_decimal_anyof_fields() -> None: + """Test that Decimal-like anyOf fields are normalized to floats with proper precision.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "price": { + "anyOf": [ + {"type": "number"}, + {"type": "string", "pattern": r"^(?!^[-+.]*$)[+-]?0*\d*\.?\d{0,2}0*$"}, + ] + }, + }, + } + + # Numeric value with extra precision should be rounded to 2 decimal places + result1 = validate({"name": "Widget", "price": 189.999}, schema) + assert result1["price"] == 190.0 + assert isinstance(result1["price"], float) + + # Numeric value should be converted to float + result2 = 
def test_normalize_decimal_anyof_fields() -> None:
    """Decimal-like anyOf fields come back from ``validate`` as floats with the schema's precision."""
    price_schema = {
        "anyOf": [
            {"type": "number"},
            {"type": "string", "pattern": r"^(?!^[-+.]*$)[+-]?0*\d*\.?\d{0,2}0*$"},
        ]
    }
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "price": price_schema,
        },
    }

    # (name, raw price, expected normalized price)
    cases = [
        # Numeric value with extra precision is rounded to 2 decimal places.
        ("Widget", 189.999, 190.0),
        # Numeric value already within precision stays the same float.
        ("Gadget", 50.5, 50.5),
        # String value is converted to float.
        ("Gizmo", "249.99", 249.99),
    ]

    for name, raw_price, expected in cases:
        result = validate({"name": name, "price": raw_price}, schema)
        assert result["price"] == expected
        assert isinstance(result["price"], float)