
Conversation

@rjurney commented Sep 11, 2025

Summary

Changes

  1. Forward Reference Resolution: Added model_rebuild() call to resolve forward references before processing
  2. ForwardRef Type Handling: Explicitly handle ForwardRef types, defaulting them to string type when unresolved
  3. Type Safety: Added inspect.isclass() check before issubclass() to prevent TypeError with non-class types
  4. Circular Reference Detection: Implemented visited models tracking to prevent infinite recursion in self-referential models
  5. Specific Exception Handling: Catch specific Pydantic exceptions (PydanticUndefinedAnnotation, PydanticSchemaGenerationError) instead of generic Exception
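Taken together, the changes have roughly the following shape. This is a simplified sketch, not the exact diff: the helper name is illustrative, and Optional/Union unwrapping plus the full scalar type mapping are elided.

import inspect
from typing import ForwardRef, Optional, Set, Type

from pydantic import BaseModel
from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation


def _schema_for(model: Type[BaseModel], _visited: Optional[Set[type]] = None) -> dict:
    # (4) Track visited models so self-referential models terminate
    _visited = set(_visited or ()) | {model}

    # (1) Resolve forward references before processing
    try:
        model.model_rebuild()
    except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
        # (5) Catch only the specific Pydantic errors; the model may
        # still be usable even if rebuilding fails
        pass

    fields = []
    for name, info in model.model_fields.items():
        annotation = info.annotation
        # (2) Unresolved forward references default to string type
        if isinstance(annotation, (str, ForwardRef)):
            spark_type: object = 'string'
        # (3) isclass() guard so issubclass() never sees a non-class
        elif inspect.isclass(annotation) and issubclass(annotation, BaseModel):
            # (4) A model already seen on this path becomes a string
            # instead of recursing forever
            spark_type = 'string' if annotation in _visited else _schema_for(annotation, _visited)
        else:
            spark_type = 'string'  # simplified: real code maps int, bool, etc.
        fields.append({'name': name, 'type': spark_type, 'nullable': True, 'metadata': {}})

    return {'type': 'struct', 'fields': fields}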

Test plan

Added comprehensive test suite (tests/test_forward_references.py) covering:

  • Models with forward references
  • Enums in nested models
  • Undefined forward references
  • Circular references (self-referential models)

Also included an integration test (test_forward_ref_fix.py) demonstrating that the fix works with real BAML-generated Company/Ticker models.

All tests pass successfully.
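For flavor, a minimal sketch of the self-referential case (test name and assertion are illustrative, not copied from the suite):

from typing import Optional

from sparkdantic import SparkModel


class Node(SparkModel):
    name: str
    parent: Optional['Node'] = None


def test_self_referential_model_does_not_recurse():
    # Before the fix this recursed until RecursionError; with circular
    # reference detection the self-reference falls back to a string type
    schema = Node.model_spark_schema()
    assert schema is not None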

Related Issues

Fixes #798

On test_forward_ref_fix.py ("This test demonstrates the fix working with real-world BAML-generated models"):

Owner commented:

Can you remove this file? It looks like a personal test script that isn't relevant for this repo.

Copilot AI left a comment

Pull Request Overview

This PR fixes forward references and circular dependencies in Pydantic models when converting to PySpark schemas. The fix ensures that BAML-generated models with forward references can be converted to PySpark schemas without crashes or infinite recursion.

  • Adds forward reference resolution via model_rebuild() calls before schema processing
  • Implements circular reference detection to prevent infinite recursion
  • Provides fallback handling for unresolved forward references

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

  • src/sparkdantic/model.py — Core implementation of forward reference resolution, circular dependency tracking, and improved error handling
  • tests/test_forward_references.py — Comprehensive test suite covering forward references, circular dependencies, and enum handling scenarios
  • tests/test_computed_field.py — Minor code style improvements (quote consistency)


# field_type is a string (likely an unresolved forward reference)
# Try to get it from the model's namespace
if hasattr(model, '__module__'):
    import sys
Copilot AI commented Sep 13, 2025:

The import sys statement should be moved to the top of the file with other imports rather than being imported inline. This follows Python best practices and improves code clarity.

Suggested change (delete this line):

    import sys

Comment on lines +211 to 225
# Add current model to visited set
_visited_models = _visited_models.copy()  # Make a copy to avoid modifying the original
_visited_models.add(model)

# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
    try:
        model.model_rebuild()
    except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
        # If rebuilding fails due to undefined annotations or schema generation errors,
        # continue anyway as the model might still be usable
        pass

fields = []
for name, info in _get_schema_items(model, mode):
Copilot AI commented Sep 13, 2025:

The circular reference detection logic creates a copy of the visited models set for each recursive call, which could be inefficient for deeply nested structures. Consider using a context manager or tracking depth instead of copying the entire set each time.

Suggested change

Before:

# Add current model to visited set
_visited_models = _visited_models.copy()  # Make a copy to avoid modifying the original
_visited_models.add(model)

# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
    try:
        model.model_rebuild()
    except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
        # If rebuilding fails due to undefined annotations or schema generation errors,
        # continue anyway as the model might still be usable
        pass

fields = []
for name, info in _get_schema_items(model, mode):

After:

# Add current model to visited set (in-place), and ensure removal after processing
_visited_models.add(model)
try:
    # Resolve forward references in the model before processing
    if hasattr(model, 'model_rebuild'):
        try:
            model.model_rebuild()
        except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
            # If rebuilding fails due to undefined annotations or schema generation errors,
            # continue anyway as the model might still be usable
            pass

    fields = []
    for name, info in _get_schema_items(model, mode):
        # ... rest of the loop and function body ...

    return {
        'type': 'struct',
        'fields': fields,
    }
finally:
    _visited_models.remove(model)

@mitchstockdale (Contributor) commented:

Thanks for identifying an issue and raising a PR to address it, @rjurney!

There appear to be three main changes in this PR:

  1. Resolve forward references before schema generation
  2. Fix recursive (self-referencing) models
  3. Fix string Enum

On 1, if a model's forward references can be resolved via BaseModel.model_rebuild, is this change required? If a client expects forward refs, can they just call this method before creating the Spark schema? Correct me if I’m wrong, but this solves the original problem in the issue you raised.

On 2 (related to 1 due to forward references), unbounded recursive models don't make much sense to me w.r.t. generating structured Spark schemas. A recursive model could either map to an unstructured type (i.e. StringType, I like your choice here 👍) or a semi-structured type, e.g. VariantType. I'm reluctant to handle this scenario by extending create_json_spark_schema with a private function parameter. Happy to discuss, design, and reach agreement in a separate issue on this. Additionally, do you currently face this (self-reference) problem in your original issue?
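For concreteness, a minimal recursive model of the kind in question (illustrative, not from the PR):

from typing import Optional

from pydantic import BaseModel


class Node(BaseModel):
    value: str
    next: Optional['Node'] = None

# A fully structured schema for Node would have to nest itself without
# bound, hence the suggestion to map it to an unstructured StringType or
# a semi-structured VariantType instead.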

On 3, I’m not sure if there was a problem with the existing code. After rebuilding the model in your original issue, did the string Enum field still fail schema generation?

@rjurney (Author) commented Sep 14, 2025:

For one thing, you have to instantiate an object and run that method by calling model rebuild, which isn't the published class-based API and will trip many people up. The recursive stuff I can remove, Claude did that. As to the Enum, yes, your string name of the class was the issue with a string enum. It wasn't pulling the class.

@mitchstockdale (Contributor) commented:

The recursive stuff I can remove, Claude did that.

Thanks 👍

For one thing, you have to instantiate an object and run that method by calling model rebuild, which isn't the published class-based API and will trip many people up... As to the Enum, yes, your string name of the class was the issue with a string enum. It wasn't pulling the class.

I see, it's because the Enum is also a forward reference (not quite what the PR summary suggests). Thanks for clarifying the issue.

BaseModel.model_rebuild is a classmethod, is quite well documented, and tries to rebuild the schema. It doesn't require model instantiation. Using a similar example to yours, the workaround for this is:

from enum import StrEnum
from typing import Optional
from pydantic import BaseModel
from sparkdantic import SparkModel

class Parent(BaseModel):
    child: Optional["Child"] = None

class Child(BaseModel):
    bar: Optional["Bar"] = None

class Bar(StrEnum):
    FOO = "FOO"

class SparkParent(Parent, SparkModel):
    pass

Child.model_rebuild()
SparkParent.model_rebuild(force=True)
SparkParent.model_json_spark_schema()
# `{'type': 'struct', 'fields': [{'name': 'child', 'type': {'type': 'struct', 'fields': [{'name': 'bar', 'type': 'string', 'nullable': True, 'metadata': {}}]}, 'nullable': True, 'metadata': {}}]}`

Note:

  • Unfortunately, it doesn't seem possible to resolve forward references for recursive models in one step (i.e. you can't just call model_rebuild on SparkParent; Child must be rebuilt too).
  • If Bar and Child were defined in the right order, model_rebuild is not required.
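For illustration, the same models defined bottom-up so that no forward references remain and model_rebuild is unnecessary:

from enum import StrEnum
from typing import Optional

from pydantic import BaseModel

from sparkdantic import SparkModel


class Bar(StrEnum):
    FOO = "FOO"


class Child(BaseModel):
    bar: Optional[Bar] = None


class Parent(BaseModel):
    child: Optional[Child] = None


class SparkParent(Parent, SparkModel):
    pass


SparkParent.model_json_spark_schema()  # no model_rebuild() calls needed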

I'm aware that this might require a complex workaround if the hierarchy of your models is unknown at runtime, which would be the case for generated models.

My strong preference would be that, given that model_rebuild is already provided by the pydantic BaseModel, we should encourage its use. If a ForwardRef field is encountered, we could instead raise an exception with a message that includes:

  • the model name; and
  • the model field; and
  • a suggestion that the client call model_rebuild to fix the forward references before Spark schema generation; or
  • a suggestion that the client use the spark_type override
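A sketch of what that error might look like; the exception type, message wording, and helper name here are placeholders, not an agreed design:

def _raise_unresolved_forward_ref(model: type, field_name: str) -> None:
    # Hypothetical helper; exception type and wording are illustrative
    raise TypeError(
        f"Field '{field_name}' on model '{model.__name__}' is an unresolved "
        f"forward reference. Call {model.__name__}.model_rebuild() (and "
        "model_rebuild() on any models it references) before generating the "
        "Spark schema, or set a spark_type override on the field."
    )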

I'm reluctant to default ForwardRef fields to string type as is done here. In the above example, the schema would be:
{'type': 'struct', 'fields': [{'name': 'child', 'type': 'string', 'nullable': True, 'metadata': {}}]}

I would prefer this to be an explicit option/flag for clients instead, either in the field definition (as an override) or another parameter in create_json_spark_schema.
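For example, the function-parameter variant might look like this; the flag is entirely hypothetical and does not exist today, and the import path is assumed:

from sparkdantic import create_json_spark_schema  # assumed import path

# Hypothetical opt-in flag (name illustrative, not implemented);
# Parent as defined in the example above
schema = create_json_spark_schema(Parent, forward_refs_as_string=True)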

@rjurney (Author) commented Sep 16, 2025:

Let me double-check, but the only way I could use the model at all was to subclass both my model and SparkModel, instantiate it, then serialize the schema.



Development

Successfully merging this pull request may close this issue:

  • SparkModel.model_spark_schema() can't convert a nested BaseModel field (#798)
