
Commit 805363f

better jsonschema validation
1 parent cb12418 commit 805363f

4 files changed: +403 -14 lines changed


lib/idp_common_pkg/idp_common/extraction/agentic_idp.py

Lines changed: 4 additions & 5 deletions
@@ -199,6 +199,10 @@ def extraction_tool(
         extraction: model_class,  # pyright: ignore[reportInvalidTypeForm]
         agent: Agent,  # pyright: ignore[reportInvalidTypeForm]
     ) -> str:  # pyright: ignore[reportInvalidTypeForm]
+        """Use this tool to return the requested data extraction.
+        When you call this tool it overwrites the previous extraction, if you want to expand the extraction use jsonpatch.
+        This tool needs to be Successfully invoked before the patch tool can be used."""
+
         logger.info("extraction_tool called", extra={"models_extraction": extraction})
         extraction_model = model_class(**extraction)  # pyright: ignore[reportAssignmentType]
         extraction_dict = extraction_model.model_dump()
@@ -238,11 +242,6 @@ def apply_json_patches(
             "patches_applied": len(patches),
         }
 
-    extraction_tool.__doc__ = f"""
-Use this tool to return the requested data extraction.
-When you call this tool it overwrites the previous extraction, if you want to expand the extraction use jsonpatch.
-This tool needs to be Successfully invoked before the patch tool can be used.
-required extraction schema is: {model_class.model_json_schema()}"""
     return extraction_tool, apply_json_patches
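To illustrate the workflow the new docstring describes, a minimal sketch follows: a first extraction is returned whole, and later additions go through JSON Patch rather than another full overwrite. It assumes the patch tool applies standard RFC 6902 operations via the jsonpatch package (not shown in this diff); the field names and values are hypothetical.

import jsonpatch

# Hypothetical first result returned through extraction_tool.
extraction = {"invoice_number": "INV-001", "total": 125.0}

# Calling extraction_tool again would replace this dict wholesale; to extend it
# incrementally instead, RFC 6902 operations add or change individual fields.
patches = [{"op": "add", "path": "/vendor", "value": "ACME Corp"}]
expanded = jsonpatch.apply_patch(extraction, patches)

print(expanded)  # {'invoice_number': 'INV-001', 'total': 125.0, 'vendor': 'ACME Corp'}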

lib/idp_common_pkg/idp_common/schema/pydantic_generator.py

Lines changed: 125 additions & 9 deletions
@@ -14,10 +14,11 @@
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
+import jsonschema
 from datamodel_code_generator import DataModelType, InputFileType, generate
-from pydantic import BaseModel, ConfigDict, create_model
+from pydantic import BaseModel, ConfigDict, create_model, model_validator
 
 logger = logging.getLogger(__name__)
 
@@ -88,6 +89,83 @@ def _normalize_class_name(name: str) -> str:
     )
 
 
+def has_advanced_constraints(schema: Dict[str, Any]) -> bool:
+    """
+    Check if schema has constraints that Pydantic doesn't enforce natively.
+
+    Args:
+        schema: JSON Schema definition
+
+    Returns:
+        True if schema has advanced constraints requiring JSON Schema validation
+    """
+    advanced_keywords = {
+        "contains",
+        "minContains",
+        "maxContains",
+        "contentMediaType",
+        "contentEncoding",
+        "dependentSchemas",
+        "dependentRequired",
+        "if",
+        "then",
+        "else",
+    }
+
+    def check_recursive(obj: Any) -> bool:
+        if isinstance(obj, dict):
+            # Check for advanced keywords at this level
+            if any(key in obj for key in advanced_keywords):
+                return True
+            # Recursively check nested objects
+            for value in obj.values():
+                if check_recursive(value):
+                    return True
+        elif isinstance(obj, list):
+            for item in obj:
+                if check_recursive(item):
+                    return True
+        return False
+
+    return check_recursive(schema)
+
+
+def create_json_schema_validator(
+    original_schema: Dict[str, Any],
+) -> Callable[[BaseModel], BaseModel]:
+    """
+    Create a Pydantic model validator that enforces JSON Schema constraints.
+
+    Args:
+        original_schema: The original JSON Schema definition
+
+    Returns:
+        A validator function that can be used with Pydantic's @model_validator
+    """
+
+    def validate_against_json_schema(value: BaseModel) -> BaseModel:
+        """Validate model data against the original JSON Schema."""
+        # Convert Pydantic model to dict for JSON Schema validation
+        data = value.model_dump()
+
+        try:
+            # Validate against JSON Schema
+            jsonschema.validate(data, original_schema)
+            return value
+        except jsonschema.ValidationError as e:
+            # Re-raise as ValueError which Pydantic will catch and convert
+            raise ValueError(
+                f"JSON Schema validation failed: {e.message}. "
+                f"Path: {'.'.join(str(p) for p in e.path) if e.path else 'root'}"
+            )
+        except jsonschema.SchemaError as e:
+            # Schema itself is invalid
+            logger.error(f"Invalid JSON Schema: {e}")
+            raise PydanticModelGenerationError(f"Invalid JSON Schema: {e}")
+
+    return validate_against_json_schema
+
+
 def _find_model_in_module(
     generated_module: Any,
     schema_dict: Dict[str, Any],
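A short usage sketch for the new has_advanced_constraints helper. The import path assumes the idp_common package is installed as laid out in this repository; both schemas are hypothetical.

from idp_common.schema.pydantic_generator import has_advanced_constraints

schema_with_contains = {
    "type": "object",
    "properties": {
        "line_items": {
            "type": "array",
            "contains": {"type": "object"},
            "minContains": 1,
        },
    },
}
plain_schema = {"type": "object", "properties": {"name": {"type": "string"}}}

print(has_advanced_constraints(schema_with_contains))  # True: "contains"/"minContains" found
print(has_advanced_constraints(plain_schema))          # False: plain Pydantic validation suffices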
@@ -164,6 +242,7 @@ def create_pydantic_model_from_json_schema(
     class_label: str,
     clean_schema: bool = True,
     fields_to_remove: Optional[List[str]] = None,
+    enable_json_schema_validation: bool = True,
 ) -> Type[BaseModel]:
     """
     Dynamically create a Pydantic v2 model from JSON Schema.
@@ -172,11 +251,16 @@
     from a JSON Schema definition. The model is generated in a temporary
     file and then dynamically imported.
 
+    When advanced JSON Schema constraints are detected (e.g., contains, minContains,
+    if/then/else), a model validator is automatically added to enforce these
+    constraints at runtime.
+
     Args:
         schema: JSON Schema definition (dict or JSON string)
         class_label: Label/name for the class (used for module naming and fallback)
         clean_schema: Whether to clean custom fields before generation (default: True)
         fields_to_remove: List of field prefixes to remove when cleaning (default: ["x-aws-idp-"])
+        enable_json_schema_validation: Add JSON Schema validation for advanced constraints (default: True)
 
     Returns:
         Dynamically created Pydantic BaseModel class
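A hypothetical call showing the new enable_json_schema_validation flag. The invoice schema is illustrative and the import path assumes the installed idp_common package; because the schema uses contains/minContains, the generated model also receives the JSON Schema validator described above.

from idp_common.schema.pydantic_generator import create_pydantic_model_from_json_schema

invoice_schema = {
    "title": "Invoice",
    "type": "object",
    "properties": {
        "line_items": {
            "type": "array",
            "items": {"type": "object", "properties": {"amount": {"type": "number"}}},
            "contains": {"type": "object", "required": ["amount"]},
            "minContains": 1,
        },
    },
}

InvoiceModel = create_pydantic_model_from_json_schema(
    schema=invoice_schema,
    class_label="Invoice",
    enable_json_schema_validation=True,  # new flag added in this commit; True is the default
)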
@@ -269,15 +353,47 @@ def create_pydantic_model_from_json_schema(
     normalized_name = _normalize_class_name(schema_title)
     selected_model.__name__ = normalized_name
 
-    # Configure model to use aliases for population and serialization
-    # This is critical for handling nested objects where datamodel-code-generator
-    # adds _1 suffixes to avoid naming conflicts with nested model classes
-    final_model = create_model(
-        selected_model.__name__,
-        __base__=selected_model,
-        __config__=ConfigDict(populate_by_name=True, serialize_by_alias=True),
+    # Check if we need to add JSON Schema validation
+    needs_validation = (
+        enable_json_schema_validation and has_advanced_constraints(schema)
     )
 
+    if needs_validation:
+        # Create a new model class with JSON Schema validation
+        validator_func = create_json_schema_validator(schema)
+
+        # Create class with validator using type() and decorator
+        class ModelWithValidation(selected_model):  # type: ignore
+            model_config = ConfigDict(
+                populate_by_name=True, serialize_by_alias=True
+            )
+
+            @model_validator(mode="after")  # type: ignore
+            def validate_json_schema(self):  # type: ignore
+                return validator_func(self)
+
+        # Set the correct name
+        ModelWithValidation.__name__ = selected_model.__name__
+        ModelWithValidation.__qualname__ = selected_model.__name__
+
+        final_model = ModelWithValidation
+
+        logger.info(
+            f"Added JSON Schema validation to model '{selected_model.__name__}' "
+            f"for class '{class_label}' due to advanced constraints"
+        )
+    else:
+        # Configure model to use aliases for population and serialization
+        # This is critical for handling nested objects where datamodel-code-generator
+        # adds _1 suffixes to avoid naming conflicts with nested model classes
+        final_model = create_model(
+            selected_model.__name__,
+            __base__=selected_model,
+            __config__=ConfigDict(
+                populate_by_name=True, serialize_by_alias=True
+            ),
+        )
+
     # Log the final model with its fields and aliases
     field_count = len(final_model.model_fields)
     field_info = []
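The effect of the ModelWithValidation branch can be approximated with this self-contained sketch; the schema and model are hypothetical and mirror, rather than reuse, the generated class. It assumes a jsonschema version whose default validator supports minContains (Draft 2019-09 or later).

import jsonschema
from pydantic import BaseModel, ValidationError, model_validator

# Hypothetical original schema with an advanced constraint (minContains)
# that plain Pydantic field types would not enforce.
PAGE_SCHEMA = {
    "type": "object",
    "properties": {
        "pages": {"type": "array", "contains": {"type": "integer"}, "minContains": 2},
    },
}

class PageExtraction(BaseModel):
    pages: list

    @model_validator(mode="after")
    def validate_json_schema(self):
        # Same pattern as the generated ModelWithValidation: re-validate the
        # dumped data against the original schema and convert failures to
        # ValueError so Pydantic reports them as a ValidationError.
        try:
            jsonschema.validate(self.model_dump(), PAGE_SCHEMA)
        except jsonschema.ValidationError as e:
            raise ValueError(f"JSON Schema validation failed: {e.message}")
        return self

PageExtraction(pages=[1, 2])      # accepted: two integers satisfy minContains: 2
try:
    PageExtraction(pages=["a"])   # rejected: no integer items in the array
except ValidationError as e:
    print(e)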

0 commit comments
