108 changes: 91 additions & 17 deletions src/sparkdantic/model.py
@@ -12,8 +12,10 @@
Annotated,
Any,
Dict,
ForwardRef,
Literal,
Optional,
Set,
Type,
Union,
get_args,
@@ -31,6 +33,7 @@
SecretBytes,
SecretStr,
)
from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation
from pydantic.fields import ComputedFieldInfo, FieldInfo
from pydantic.json_schema import JsonSchemaMode

@@ -173,6 +176,7 @@ def create_json_spark_schema(
by_alias: bool = True,
mode: JsonSchemaMode = 'validation',
exclude_fields: bool = False,
_visited_models: Optional[Set[Type[BaseModel]]] = None,
) -> Dict[str, Any]:
"""Generates a PySpark JSON compatible schema from the model fields. This operates similarly to
`pydantic.BaseModel.model_json_schema()`.
@@ -184,6 +188,7 @@
mode (pydantic.json_schema.JsonSchemaMode): The mode in which to generate the schema.
exclude_fields (bool): Indicates whether to exclude fields from the schema. Fields to be excluded should
be annotated with `Field(exclude=True)` field attribute
_visited_models: Internal parameter to track visited models and prevent infinite recursion

Returns:
Dict[str, Any]: The generated PySpark JSON schema
@@ -192,7 +197,29 @@
raise TypeError('`model` must be of type `SparkModel` or `pydantic.BaseModel`')

if mode not in get_args(JsonSchemaMode):
raise ValueError(f'`mode` must be one of {get_args(JsonSchemaMode)}')
raise ValueError(f"`mode` must be one of {get_args(JsonSchemaMode)}")

# Initialize visited models set if not provided
if _visited_models is None:
_visited_models = set()

# Check for circular references
if model in _visited_models:
# Return a placeholder for circular references
return {'type': 'struct', 'fields': []}

# Add current model to visited set
_visited_models = _visited_models.copy() # Make a copy to avoid modifying the original
_visited_models.add(model)

# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
try:
model.model_rebuild()
except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
# If rebuilding fails due to undefined annotations or schema generation errors,
# continue anyway as the model might still be usable
pass

fields = []
for name, info in _get_schema_items(model, mode):
Comment on lines +211 to 225
Copilot AI Sep 13, 2025

The circular reference detection logic creates a copy of the visited models set for each recursive call, which could be inefficient for deeply nested structures. Consider using a context manager or tracking depth instead of copying the entire set each time.

Suggested change
# Add current model to visited set
_visited_models = _visited_models.copy() # Make a copy to avoid modifying the original
_visited_models.add(model)
# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
try:
model.model_rebuild()
except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
# If rebuilding fails due to undefined annotations or schema generation errors,
# continue anyway as the model might still be usable
pass
fields = []
for name, info in _get_schema_items(model, mode):
# Add current model to visited set (in-place), and ensure removal after processing
_visited_models.add(model)
try:
# Resolve forward references in the model before processing
if hasattr(model, 'model_rebuild'):
try:
model.model_rebuild()
except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
# If rebuilding fails due to undefined annotations or schema generation errors,
# continue anyway as the model might still be usable
pass
fields = []
for name, info in _get_schema_items(model, mode):
# ... rest of the loop and function body ...
return {
'type': 'struct',
'fields': fields,
}
finally:
_visited_models.remove(model)

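For context on the guard discussed above: the scenario it addresses is a self-referencing model. A minimal sketch follows (the `Node` model and its fields are hypothetical, not from this PR; it assumes `create_json_spark_schema` is importable from `sparkdantic.model` as defined in this diff):

from typing import List, Optional

from pydantic import BaseModel

from sparkdantic.model import create_json_spark_schema


class Node(BaseModel):
    name: str
    # Self-referencing fields: without the visited-models check, schema
    # generation would recurse without terminating.
    children: List['Node'] = []
    parent: Optional['Node'] = None


# Nested occurrences of Node are emitted as the empty-struct placeholder
# ({'type': 'struct', 'fields': []}) rather than recursing forever.
schema = create_json_spark_schema(Node)
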
@@ -219,7 +246,7 @@ def create_json_spark_schema(
try:
if _is_base_model(field_type):
spark_type = create_json_spark_schema(
field_type, safe_casting, by_alias, mode, exclude_fields
field_type, safe_casting, by_alias, mode, exclude_fields, _visited_models
)
elif override is not None:
if isinstance(override, str):
@@ -232,20 +259,51 @@
msg = '`spark_type` override should be a `str` type name (e.g. long)'
if utils.have_pyspark:
msg += ' or `pyspark.sql.types.DataType` (e.g. LongType)'
msg += f', but got {override}'
msg += f", but got {override}"
raise TypeError(msg)
elif isinstance(field_type, str):
spark_type = field_type
# field_type is a string (likely an unresolved forward reference)
# Try to get it from the model's namespace
if hasattr(model, '__module__'):
import sys
Copilot AI Sep 13, 2025

The import sys statement should be moved to the top of the file with other imports rather than being imported inline. This follows Python best practices and improves code clarity.

Suggested change
import sys


module = sys.modules.get(model.__module__)
if module and hasattr(module, field_type):
resolved_type = getattr(module, field_type)
if _is_base_model(resolved_type):
spark_type = create_json_spark_schema(
resolved_type,
safe_casting,
by_alias,
mode,
exclude_fields,
_visited_models,
)
else:
# If it's not a BaseModel, treat as string type
spark_type = 'string'
else:
# Could not resolve, default to string type
spark_type = 'string'
else:
# No module info, default to string type
spark_type = 'string'
elif utils.have_pyspark and _is_spark_datatype(field_type):
spark_type = field_type.typeName()
else:
metadata = _get_metadata(info)
spark_type = _from_python_type(
field_type, metadata, safe_casting, by_alias, mode, exclude_fields
field_type,
metadata,
safe_casting,
by_alias,
mode,
exclude_fields,
_visited_models,
)
except Exception as raised_error:
raise TypeConversionError(
f'Error converting field `{name}` to PySpark type'
f"Error converting field `{name}` to PySpark type"
) from raised_error

nullable = _is_optional(annotation_or_return_type)
@@ -326,7 +384,7 @@ def _get_spark_type(t: Type) -> str:
"""
spark_type = _type_mapping.get(t)
if spark_type is None:
raise TypeError(f'Type {t} not recognized')
raise TypeError(f"Type {t} not recognized")
return spark_type


@@ -347,7 +405,7 @@ def _get_enum_mixin_type(t: EnumType) -> MixinType:
elif issubclass(t, str):
return str
else:
raise TypeError(f'Enum {t} is not supported. Only int and str mixins are supported.')
raise TypeError(f"Enum {t} is not supported. Only int and str mixins are supported.")


def _from_python_type(
@@ -357,6 +415,7 @@
by_alias: bool = True,
mode: JsonSchemaMode = 'validation',
exclude_fields: bool = False,
_visited_models: Optional[Set[Type[BaseModel]]] = None,
) -> Union[str, Dict[str, Any]]:
"""Converts a Python type to a corresponding PySpark data type.

@@ -369,20 +428,30 @@
Returns:
Union[str, Dict[str, Any]]: The corresponding PySpark data type (dict for complex types).
"""
# Handle ForwardRef types (unresolved forward references)
if isinstance(type_, ForwardRef):
# For unresolved forward references, default to string type
# This is safer than trying to resolve them which might fail
return 'string'

py_type = _get_union_type_arg(type_)

if _is_base_model(py_type):
return create_json_spark_schema(py_type, safe_casting, by_alias, mode, exclude_fields)
return create_json_spark_schema(
py_type, safe_casting, by_alias, mode, exclude_fields, _visited_models
)

args = get_args(py_type)
origin = get_origin(py_type)

if origin is None and py_type in (list, dict):
raise TypeError(f'Type argument(s) missing from {py_type.__name__}')
raise TypeError(f"Type argument(s) missing from {py_type.__name__}")

# Convert complex types
if origin is list:
element_type = _from_python_type(args[0], [], safe_casting, by_alias, mode, exclude_fields)
element_type = _from_python_type(
args[0], [], safe_casting, by_alias, mode, exclude_fields, _visited_models
)
contains_null = _is_optional(args[0])
return {
'type': 'array',
@@ -391,8 +460,12 @@
}

if origin is dict:
key_type = _from_python_type(args[0], [], safe_casting, by_alias, mode, exclude_fields)
value_type = _from_python_type(args[1], [], safe_casting, by_alias, mode, exclude_fields)
key_type = _from_python_type(
args[0], [], safe_casting, by_alias, mode, exclude_fields, _visited_models
)
value_type = _from_python_type(
args[1], [], safe_casting, by_alias, mode, exclude_fields, _visited_models
)
value_contains_null = _is_optional(args[1])
return {
'type': 'map',
@@ -415,7 +488,8 @@
# first arg of annotated type is the type, second is metadata that we don't do anything with (yet)
py_type = args[0]

if issubclass(py_type, Enum):
# Check if py_type is a class before using issubclass
if inspect.isclass(py_type) and issubclass(py_type, Enum):
py_type = _get_enum_mixin_type(py_type)

spark_type = _get_spark_type(py_type)
@@ -427,7 +501,7 @@
meta = None if len(metadata) < 1 else deepcopy(metadata).pop()
max_digits = getattr(meta, 'max_digits', 10)
decimal_places = getattr(meta, 'decimal_places', 0)
return f'decimal({max_digits}, {decimal_places})'
return f"decimal({max_digits}, {decimal_places})"

return spark_type

@@ -604,12 +678,12 @@ def _json_type_to_ddl(json_type: Union[str, Dict[str, Any]]) -> str:

elif json_type['type'] == 'array':
element_type = _json_type_to_ddl(json_type['elementType'])
return f'ARRAY<{element_type}>'
return f"ARRAY<{element_type}>"

elif json_type['type'] == 'map':
key_type = _json_type_to_ddl(json_type['keyType'])
value_type = _json_type_to_ddl(json_type['valueType'])
return f'MAP<{key_type}, {value_type}>'
return f"MAP<{key_type}, {value_type}>"
else:
raise TypeError(f"Unsupported JSON type: {json_type['type']}")

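Usage note (not part of the diff): a hedged sketch of the forward-reference scenario the new resolution logic targets, where a field's annotation names a model that is defined later in the same module. `Order` and `Customer` are hypothetical names.

from pydantic import BaseModel

from sparkdantic.model import create_json_spark_schema


class Order(BaseModel):
    # 'Customer' is a forward reference; at class-creation time it may still
    # be an unresolved string annotation.
    customer: 'Customer'
    total: float


class Customer(BaseModel):
    id: int
    name: str


# With the handling added above (model_rebuild plus the module-namespace
# lookup for string annotations), `customer` expands into a nested struct;
# if the name could not be resolved, the field would fall back to 'string'.
schema = create_json_spark_schema(Order)
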
6 changes: 3 additions & 3 deletions tests/test_computed_field.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel, computed_field
from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType

from sparkdantic import create_spark_schema, SparkField
from sparkdantic import SparkField, create_spark_schema


class ComputedOnlyModel(BaseModel):
@@ -97,7 +97,7 @@ def d(self) -> int:

def test_computed_field_with_spark_type():
class ComputedWithSparkType(BaseModel):
@computed_field(json_schema_extra={"spark_type": LongType})
@computed_field(json_schema_extra={'spark_type': LongType})
@property
def d(self) -> int:
return 4
@@ -129,7 +129,7 @@ def d(self) -> Annotated[int, SparkField(spark_type=LongType)]:

def test_computed_field_with_spark_type_over_annotated_return():
class ComputedWithSparkType(BaseModel):
@computed_field(json_schema_extra={"spark_type": LongType})
@computed_field(json_schema_extra={'spark_type': LongType})
@property
def d(self) -> Annotated[int, SparkField(spark_type=StringType)]:
return 4