Skip to content

Commit e141966

Browse files
committed
Add comprehensive tests for Pydantic to SQLModel conversion and relationship updates
- Implement tests for single and list relationships in Pydantic to SQLModel conversion. - Validate mixed assignments of Pydantic and SQLModel instances. - Test database integration to ensure converted models work with database operations. - Add tests for edge cases and performance characteristics of relationship updates. - Ensure proper handling of forward references in relationships. - Create simple tests for basic relationship updates with Pydantic models.
1 parent 683b0c7 commit e141966

24 files changed

+2611
-57
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ site
1313
.cache
1414
.venv*
1515
uv.lock
16+
.timetracker

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,12 @@ known-third-party = ["sqlmodel", "sqlalchemy", "pydantic", "fastapi"]
134134
[tool.ruff.lint.pyupgrade]
135135
# Preserve types, even if a file imports `from __future__ import annotations`.
136136
keep-runtime-typing = true
137+
138+
[dependency-groups]
139+
dev = [
140+
"coverage>=7.2.7",
141+
"dirty-equals>=0.7.1.post0",
142+
"fastapi>=0.103.2",
143+
"httpx>=0.24.1",
144+
"pytest>=7.4.4",
145+
]

sqlmodel/main.py

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@
5555
registry,
5656
relationship,
5757
)
58-
from sqlalchemy.orm.properties import MappedSQLExpression
5958
from sqlalchemy.orm.attributes import set_attribute
6059
from sqlalchemy.orm.decl_api import DeclarativeMeta
6160
from sqlalchemy.orm.instrumentation import is_instrumented
61+
from sqlalchemy.orm.properties import MappedSQLExpression
6262
from sqlalchemy.sql.schema import MetaData
6363
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
6464
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
@@ -918,6 +918,14 @@ def __setattr__(self, name: str, value: Any) -> None:
918918
self.__dict__[name] = value
919919
return
920920
else:
921+
# Convert Pydantic objects to table models for relationships
922+
if (
923+
is_table_model_class(self.__class__)
924+
and name in self.__sqlmodel_relationships__
925+
and value is not None
926+
):
927+
value = _convert_pydantic_to_table_model(value, name, self.__class__)
928+
921929
# Set in SQLAlchemy, before Pydantic to trigger events and updates
922930
if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call]
923931
set_attribute(self, name, value)
@@ -1116,3 +1124,163 @@ def sqlmodel_update(
11161124
f"is not a dict or SQLModel or Pydantic model: {obj}"
11171125
)
11181126
return self
1127+
1128+
1129+
def _convert_pydantic_to_table_model(
1130+
value: Any, relationship_name: str, owner_class: Type["SQLModel"]
1131+
) -> Any:
1132+
"""
1133+
Convert Pydantic objects to table models for relationship assignments.
1134+
1135+
Args:
1136+
value: The value being assigned to the relationship
1137+
relationship_name: Name of the relationship attribute
1138+
owner_class: The class that owns the relationship
1139+
1140+
Returns:
1141+
Converted value(s) - table model instances instead of Pydantic objects
1142+
"""
1143+
from typing import get_args, get_origin
1144+
1145+
# Get the relationship annotation to determine target type
1146+
if relationship_name not in owner_class.__annotations__:
1147+
return value
1148+
1149+
raw_ann = owner_class.__annotations__[relationship_name]
1150+
origin = get_origin(raw_ann)
1151+
1152+
# Handle Mapped[...] annotations
1153+
if origin is Mapped:
1154+
ann = raw_ann.__args__[0]
1155+
else:
1156+
ann = raw_ann
1157+
1158+
# Get the target relationship type
1159+
try:
1160+
rel_info = owner_class.__sqlmodel_relationships__[relationship_name]
1161+
relationship_to = get_relationship_to(
1162+
name=relationship_name, rel_info=rel_info, annotation=ann
1163+
)
1164+
except (KeyError, AttributeError):
1165+
return value
1166+
1167+
# Handle list/sequence relationships
1168+
list_origin = get_origin(ann)
1169+
if list_origin is list:
1170+
target_type = get_args(ann)[0]
1171+
if isinstance(target_type, str):
1172+
# Forward reference - try to resolve from SQLAlchemy's registry
1173+
try:
1174+
resolved_type = default_registry._class_registry.get(target_type)
1175+
if resolved_type is not None:
1176+
target_type = resolved_type
1177+
else:
1178+
target_type = relationship_to
1179+
except Exception:
1180+
target_type = relationship_to
1181+
else:
1182+
target_type = relationship_to
1183+
1184+
if isinstance(value, (list, tuple)):
1185+
converted_items = []
1186+
for item in value:
1187+
converted_item = _convert_single_pydantic_to_table_model(
1188+
item, target_type
1189+
)
1190+
converted_items.append(converted_item)
1191+
return converted_items
1192+
else:
1193+
# Single relationship
1194+
target_type = relationship_to
1195+
if isinstance(target_type, str):
1196+
# Forward reference - try to resolve from SQLAlchemy's registry
1197+
try:
1198+
resolved_type = default_registry._class_registry.get(target_type)
1199+
if resolved_type is not None:
1200+
target_type = resolved_type
1201+
except:
1202+
pass
1203+
1204+
return _convert_single_pydantic_to_table_model(value, target_type)
1205+
1206+
return value
1207+
1208+
1209+
def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
1210+
"""
1211+
Convert a single Pydantic object to a table model.
1212+
1213+
Args:
1214+
item: The Pydantic object to convert
1215+
target_type: The target table model type
1216+
1217+
Returns:
1218+
Converted table model instance or original item if no conversion needed
1219+
"""
1220+
# If item is None, return as-is
1221+
if item is None:
1222+
return item
1223+
1224+
# If target_type is a string (forward reference), try to resolve it
1225+
if isinstance(target_type, str):
1226+
try:
1227+
resolved_type = default_registry._class_registry.get(target_type)
1228+
if resolved_type is not None:
1229+
target_type = resolved_type
1230+
except Exception:
1231+
pass
1232+
1233+
# If target_type is still a string after resolution attempt,
1234+
# we can't perform type checks or conversions
1235+
if isinstance(target_type, str):
1236+
# If item is a BaseModel but not a table model, try conversion
1237+
if (
1238+
isinstance(item, BaseModel)
1239+
and hasattr(item, "__class__")
1240+
and not is_table_model_class(item.__class__)
1241+
):
1242+
# Can't convert without knowing the actual target type
1243+
return item
1244+
else:
1245+
return item
1246+
1247+
# If item is already the correct type, return as-is
1248+
if isinstance(item, target_type):
1249+
return item
1250+
1251+
# Check if target_type is a SQLModel table class
1252+
if not (
1253+
hasattr(target_type, "__mro__")
1254+
and any(
1255+
hasattr(cls, "__sqlmodel_relationships__") for cls in target_type.__mro__
1256+
)
1257+
):
1258+
return item
1259+
1260+
# Check if target is a table model
1261+
if not is_table_model_class(target_type):
1262+
return item
1263+
1264+
# Check if item is a BaseModel (Pydantic model) but not a table model
1265+
if (
1266+
isinstance(item, BaseModel)
1267+
and hasattr(item, "__class__")
1268+
and not is_table_model_class(item.__class__)
1269+
):
1270+
# Convert Pydantic model to table model
1271+
try:
1272+
# Get the data from the Pydantic model
1273+
if hasattr(item, "model_dump"):
1274+
# Pydantic v2
1275+
data = item.model_dump()
1276+
else:
1277+
# Pydantic v1
1278+
data = item.dict()
1279+
1280+
# Create new table model instance
1281+
return target_type(**data)
1282+
except Exception:
1283+
# If conversion fails, return original item
1284+
return item
1285+
1286+
return item

0 commit comments

Comments
 (0)