Fix forward references and circular dependencies in Pydantic models #799
Conversation
- Add `model_rebuild()` call to resolve forward references before processing
- Handle `ForwardRef` types explicitly, defaulting to string type
- Add `isclass()` check before `issubclass()` to prevent TypeError with non-class types
- Implement circular reference detection to prevent infinite recursion
- Add comprehensive test suite for forward references, enums, and circular dependencies
- Catch specific Pydantic exceptions (`PydanticUndefinedAnnotation`, `PydanticSchemaGenerationError`)

Fixes mitchelllisle#798
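A minimal sketch of how these pieces could fit together (illustrative only, not this PR's actual diff; the helper name `_spark_type_name` and the nested-dict return shape are assumptions):

```python
import inspect
from typing import ForwardRef

from pydantic import BaseModel
from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation


def _spark_type_name(field_type, visited=frozenset()):
    """Illustrative resolution of one annotation to a Spark-ish type name."""
    # Unresolved forward references (ForwardRef objects or bare strings)
    # default to a string type instead of crashing.
    if isinstance(field_type, (ForwardRef, str)):
        return 'string'
    # isclass() guard: issubclass() raises TypeError on typing constructs
    # such as Optional[...] or List[...].
    if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
        # Circular reference detection: stop instead of recursing forever.
        if field_type in visited:
            return 'string'
        # Try to resolve forward references before walking the nested model.
        try:
            field_type.model_rebuild()
        except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
            pass  # the model may still be usable without a full rebuild
        # Recurse into the nested model's fields with an extended visited set.
        return {
            name: _spark_type_name(info.annotation, visited | {field_type})
            for name, info in field_type.model_fields.items()
        }
    # All other annotations would go through the normal scalar mapping (elided here).
    return 'string'
```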
This test demonstrates the fix working with real-world BAML-generated models
test_forward_ref_fix.py
Can you remove this file? It looks like a personal test script that isn't relevant for this repo.
Pull Request Overview
This PR fixes forward references and circular dependencies in Pydantic models when converting to PySpark schemas. The fix ensures BAML-generated models with forward references can properly convert to PySpark schemas without crashes or infinite recursion.
- Adds forward reference resolution via `model_rebuild()` calls before schema processing
- Implements circular reference detection to prevent infinite recursion
- Provides fallback handling for unresolved forward references
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/sparkdantic/model.py | Core implementation of forward reference resolution, circular dependency tracking, and improved error handling |
| tests/test_forward_references.py | Comprehensive test suite covering forward references, circular dependencies, and enum handling scenarios |
| tests/test_computed_field.py | Minor code style improvements (quote consistency) |
```python
# field_type is a string (likely an unresolved forward reference)
# Try to get it from the model's namespace
if hasattr(model, '__module__'):
    import sys
```
Copilot AI (Sep 13, 2025)
The import sys statement should be moved to the top of the file with other imports rather than being imported inline. This follows Python best practices and improves code clarity.
```python
# Add current model to visited set
_visited_models = _visited_models.copy()  # Make a copy to avoid modifying the original
_visited_models.add(model)

# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
    try:
        model.model_rebuild()
    except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
        # If rebuilding fails due to undefined annotations or schema generation errors,
        # continue anyway as the model might still be usable
        pass

fields = []
for name, info in _get_schema_items(model, mode):
```
Copilot AI (Sep 13, 2025)
The circular reference detection logic creates a copy of the visited models set for each recursive call, which could be inefficient for deeply nested structures. Consider using a context manager or tracking depth instead of copying the entire set each time.
Suggested change (replacing the copy-based visited-set handling quoted above):

```python
# Add current model to visited set (in-place), and ensure removal after processing
_visited_models.add(model)
try:
    # Resolve forward references in the model before processing
    if hasattr(model, 'model_rebuild'):
        try:
            model.model_rebuild()
        except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
            # If rebuilding fails due to undefined annotations or schema generation errors,
            # continue anyway as the model might still be usable
            pass

    fields = []
    for name, info in _get_schema_items(model, mode):
        ...  # rest of the loop and function body
    return {
        'type': 'struct',
        'fields': fields,
    }
finally:
    _visited_models.remove(model)
```
Thanks for identifying an issue and raising a PR to address it, @rjurney! There appear to be three main changes in this PR:
On 1, if a model's forward references can be resolved via `model_rebuild()` ...

On 2 (related to 1 due to forward references), unbounded recursive models don't make much sense to me w.r.t. generating structured Spark schemas. A recursive model could either map to an unstructured type (i.e. ...)

On 3, I'm not sure if there was a problem with the existing code. After rebuilding the model in your original issue, did the string ...
For one thing, you have to instantiate an object and run that method by calling model rebuild, which isn't the published class-based API and will trip many people up. The recursive stuff I can remove; Claude did that. As to the enum, yes, your string name of the class was the issue with a string enum. It wasn't pulling the class.
Thanks 👍
I see, it's because `BaseModel.model_rebuild` is a classmethod.

```python
from enum import StrEnum
from typing import Optional

from pydantic import BaseModel

from sparkdantic import SparkModel


class Parent(BaseModel):
    child: Optional["Child"] = None


class Child(BaseModel):
    bar: Optional["Bar"] = None


class Bar(StrEnum):
    FOO = "FOO"


class SparkParent(Parent, SparkModel):
    pass


Child.model_rebuild()
SparkParent.model_rebuild(force=True)
SparkParent.model_json_spark_schema()
# {'type': 'struct', 'fields': [{'name': 'child', 'type': {'type': 'struct', 'fields': [{'name': 'bar', 'type': 'string', 'nullable': True, 'metadata': {}}]}, 'nullable': True, 'metadata': {}}]}
```

Note:
I'm aware that this might require a complex workaround if the hierarchy of your models is unknown at runtime, which would be the case for generated models. My strong preference would be that, given that ...

I'm reluctant to default unresolved forward references to strings. I would prefer this to be an explicit option/flag for clients instead, either in the field definition (as an override) or another parameter in `model_spark_schema`.
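For illustration, the kind of explicit opt-in described here could look like the sketch below; the `unresolved_forward_refs` parameter is hypothetical and not an existing sparkdantic option:

```python
from typing import Literal


def model_spark_schema(cls, unresolved_forward_refs: Literal['error', 'string'] = 'error'):
    """Hypothetical signature sketch: raise on unresolved forward references by
    default, and only map them to a string type when the caller explicitly opts in."""
    ...
```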
Let me double check, but the only way I could use the model at all was to subclass both my model and SparkModel, instantiate it, then serialize the schema.
Summary
`SparkModel.model_spark_schema()` can't convert a nested `BaseModel` field (#798): BAML-generated models with forward references fail to convert to PySpark schemas.

Changes
- Add `model_rebuild()` call to resolve forward references before processing
- Handle `ForwardRef` types, defaulting them to string type when unresolved
- Add `inspect.isclass()` check before `issubclass()` to prevent TypeError with non-class types
- Catch specific Pydantic exceptions (`PydanticUndefinedAnnotation`, `PydanticSchemaGenerationError`) instead of generic Exception

Test plan
Added comprehensive test suite (`tests/test_forward_references.py`) covering forward references, circular dependencies, and enum handling scenarios.

Also included integration test (`test_forward_ref_fix.py`) demonstrating the fix works with real BAML-generated Company/Ticker models.

All tests pass successfully.
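For a rough idea of the shape of such a test (a sketch only, not the PR's actual test code; the field names on the Company/Ticker models here are assumed):

```python
from typing import List, Optional

from pydantic import BaseModel

from sparkdantic import SparkModel


class Company(SparkModel):
    name: str
    tickers: Optional[List["Ticker"]] = None  # forward reference to a later class


class Ticker(BaseModel):
    symbol: str


def test_forward_reference_schema():
    # Resolve the "Ticker" forward reference, then build the Spark schema.
    Company.model_rebuild()
    schema = Company.model_spark_schema()
    assert 'tickers' in schema.fieldNames()
```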
Related Issues
Fixes #798