Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/data_designer/engine/processing/gsonschema/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import re
from copy import deepcopy
from decimal import ROUND_HALF_UP, Decimal
from typing import Any, overload

from jsonschema import Draft202012Validator, ValidationError, validators
Expand Down Expand Up @@ -70,6 +72,57 @@ def extend_jsonschema_validator_with_pruning(validator):
return validators.extend(validator, {"additionalProperties": prune_additional_properties})


def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
"""Check if schema is a Decimal anyOf and extract decimal places.

Returns (is_decimal, decimal_places) where decimal_places is None if no constraint.
"""
any_of = schema.get("anyOf")
if not isinstance(any_of, list):
return False, None

has_number = any(item.get("type") == "number" for item in any_of)
if not has_number:
return False, None

for item in any_of:
if item.get("type") == "string" and "pattern" in item:
match = re.search(r"\\d\{0,(\d+)\}", item["pattern"])
if match:
return True, int(match.group(1))
return True, None # Decimal without precision constraint
return False, None


def normalize_decimal_fields(obj: DataObjectT, schema: JSONSchemaT) -> DataObjectT:
"""Normalize Decimal-like anyOf fields to floats with proper precision."""
if not isinstance(obj, dict):
return obj

defs = schema.get("$defs", {})
obj_schema = defs.get(schema.get("$ref", "")[len("#/$defs/") :], schema)
props = obj_schema.get("properties", {})

for key, value in obj.items():
field_schema = props.get(key, {})
if "$ref" in field_schema:
field_schema = defs.get(field_schema["$ref"][len("#/$defs/") :], {})

if isinstance(value, dict):
obj[key] = normalize_decimal_fields(value, schema)
elif isinstance(value, list):
obj[key] = [normalize_decimal_fields(v, schema) if isinstance(v, dict) else v for v in value]
elif isinstance(value, (int, float, str)) and not isinstance(value, bool):
is_decimal, decimal_places = _get_decimal_info_from_anyof(field_schema)
if is_decimal:
d = Decimal(str(value))
if decimal_places is not None:
d = d.quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP)
obj[key] = float(d)

return obj


## We don't expect the outer data type (e.g. dict, list, or const) to be
## modified by the pruning action.
@overload
Expand Down Expand Up @@ -140,4 +193,6 @@ def validate(
except ValidationError as exc:
raise JSONSchemaValidationError(str(exc)) from exc

final_object = normalize_decimal_fields(final_object, schema)

return final_object
31 changes: 31 additions & 0 deletions tests/engine/processing/gsonschema/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,34 @@ def test_invalid_data_type():
data = {"num": "not a number", "extra": "should be removed"}
with pytest.raises(JSONSchemaValidationError):
validate(data, schema, pruning=True, no_extra_properties=True)


def test_normalize_decimal_anyof_fields() -> None:
"""Test that Decimal-like anyOf fields are normalized to floats with proper precision."""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"price": {
"anyOf": [
{"type": "number"},
{"type": "string", "pattern": r"^(?!^[-+.]*$)[+-]?0*\d*\.?\d{0,2}0*$"},
]
},
},
}

# Numeric value with extra precision should be rounded to 2 decimal places
result1 = validate({"name": "Widget", "price": 189.999}, schema)
assert result1["price"] == 190.0
assert isinstance(result1["price"], float)

# Numeric value should be converted to float
result2 = validate({"name": "Gadget", "price": 50.5}, schema)
assert result2["price"] == 50.5
assert isinstance(result2["price"], float)

# String value should be converted to float
result3 = validate({"name": "Gizmo", "price": "249.99"}, schema)
assert result3["price"] == 249.99
assert isinstance(result3["price"], float)