Skip to content

Commit c875330

Browse files
committed
feat: enhance relationship handling for Pydantic to SQLModel conversion with support for dicts
1 parent b5d7b2d commit c875330

File tree

4 files changed

+188
-20
lines changed

4 files changed

+188
-20
lines changed

sqlmodel/_compat.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,11 @@ def sqlmodel_validate(
343343
# Get and set any relationship objects
344344
if is_table_model_class(cls):
345345
for key in new_obj.__sqlmodel_relationships__:
346-
value = getattr(use_obj, key, Undefined)
346+
# Handle both dict and object access
347+
if isinstance(use_obj, dict):
348+
value = use_obj.get(key, Undefined)
349+
else:
350+
value = getattr(use_obj, key, Undefined)
347351
if value is not Undefined:
348352
setattr(new_obj, key, value)
349353
return new_obj

sqlmodel/main.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,9 @@ def __setattr__(self, name: str, value: Any) -> None:
924924
and name in self.__sqlmodel_relationships__
925925
and value is not None
926926
):
927-
value = _convert_pydantic_to_table_model(value, name, self.__class__)
927+
value = _convert_pydantic_to_table_model(
928+
value, name, self.__class__, self
929+
)
928930

929931
# Set in SQLAlchemy, before Pydantic to trigger events and updates
930932
if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call]
@@ -1127,7 +1129,10 @@ def sqlmodel_update(
11271129

11281130

11291131
def _convert_pydantic_to_table_model(
1130-
value: Any, relationship_name: str, owner_class: Type["SQLModel"]
1132+
value: Any,
1133+
relationship_name: str,
1134+
owner_class: Type["SQLModel"],
1135+
instance: Optional["SQLModel"] = None,
11311136
) -> Any:
11321137
"""
11331138
Convert Pydantic objects to table models for relationship assignments.
@@ -1136,6 +1141,7 @@ def _convert_pydantic_to_table_model(
11361141
value: The value being assigned to the relationship
11371142
relationship_name: Name of the relationship attribute
11381143
owner_class: The class that owns the relationship
1144+
instance: The SQLModel instance (for session context)
11391145
11401146
Returns:
11411147
Converted value(s) - table model instances instead of Pydantic objects
@@ -1185,7 +1191,7 @@ def _convert_pydantic_to_table_model(
11851191
converted_items = []
11861192
for item in value:
11871193
converted_item = _convert_single_pydantic_to_table_model(
1188-
item, target_type
1194+
item, target_type, instance
11891195
)
11901196
converted_items.append(converted_item)
11911197
return converted_items
@@ -1198,21 +1204,24 @@ def _convert_pydantic_to_table_model(
11981204
resolved_type = default_registry._class_registry.get(target_type)
11991205
if resolved_type is not None:
12001206
target_type = resolved_type
1201-
except:
1207+
except Exception:
12021208
pass
12031209

1204-
return _convert_single_pydantic_to_table_model(value, target_type)
1210+
return _convert_single_pydantic_to_table_model(value, target_type, instance)
12051211

12061212
return value
12071213

12081214

1209-
def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
1215+
def _convert_single_pydantic_to_table_model(
1216+
item: Any, target_type: Any, instance: Optional["SQLModel"] = None
1217+
) -> Any:
12101218
"""
12111219
Convert a single Pydantic object to a table model.
12121220
12131221
Args:
12141222
item: The Pydantic object to convert
12151223
target_type: The target table model type
1224+
instance: The SQLModel instance (for session context)
12161225
12171226
Returns:
12181227
Converted table model instance or original item if no conversion needed
@@ -1226,7 +1235,9 @@ def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
12261235
try:
12271236
# Attempt to resolve forward reference from the default registry
12281237
# This was part of the original logic and should be kept
1229-
resolved_type_from_registry = default_registry._class_registry.get(target_type)
1238+
resolved_type_from_registry = default_registry._class_registry.get(
1239+
target_type
1240+
)
12301241
if resolved_type_from_registry is not None:
12311242
resolved_target_type = resolved_type_from_registry
12321243
except Exception:
@@ -1235,9 +1246,14 @@ def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
12351246
# `_convert_pydantic_to_table_model` should provide a resolved type.
12361247
# For safety, if it's still a string here, and item is a simple Pydantic model,
12371248
# it's best to return item to avoid errors if no concrete type is found.
1238-
if isinstance(resolved_target_type, str) and isinstance(item, BaseModel) and hasattr(item, "__class__") and not is_table_model_class(item.__class__):
1239-
return item # Fallback if no concrete type can be determined
1240-
pass # Continue if resolved_target_type is now a class or item is not a simple Pydantic model
1249+
if (
1250+
isinstance(resolved_target_type, str)
1251+
and isinstance(item, BaseModel)
1252+
and hasattr(item, "__class__")
1253+
and not is_table_model_class(item.__class__)
1254+
):
1255+
return item # Fallback if no concrete type can be determined
1256+
pass # Continue if resolved_target_type is now a class or item is not a simple Pydantic model
12411257

12421258
# If resolved_target_type is still a string and not a class, we cannot proceed with conversion.
12431259
# This can happen if the forward reference cannot be resolved.
@@ -1253,7 +1269,8 @@ def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
12531269
if not (
12541270
hasattr(resolved_target_type, "__mro__")
12551271
and any(
1256-
hasattr(cls, "__sqlmodel_relationships__") for cls in resolved_target_type.__mro__
1272+
hasattr(cls, "__sqlmodel_relationships__")
1273+
for cls in resolved_target_type.__mro__
12571274
)
12581275
):
12591276
return item
@@ -1278,10 +1295,49 @@ def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any:
12781295
# Pydantic v1
12791296
data = item.dict()
12801297

1298+
# If instance is available and item has an ID, try to find existing record
1299+
if instance is not None and "id" in data and data["id"] is not None:
1300+
from sqlalchemy.orm import object_session
1301+
1302+
session = object_session(instance)
1303+
if session is not None:
1304+
# Try to find existing record by ID
1305+
existing_record = session.get(resolved_target_type, data["id"])
1306+
if existing_record is not None:
1307+
# Update existing record with new data
1308+
for key, value in data.items():
1309+
if key != "id" and hasattr(existing_record, key):
1310+
setattr(existing_record, key, value)
1311+
return existing_record
1312+
12811313
# Create new table model instance using resolved_target_type
12821314
return resolved_target_type(**data)
12831315
except Exception:
12841316
# If conversion fails, return original item
12851317
return item
12861318

1319+
# Check if item is a dictionary that should be converted to table model
1320+
elif isinstance(item, dict):
1321+
try:
1322+
# If instance is available and item has an ID, try to find existing record
1323+
if instance is not None and "id" in item and item["id"] is not None:
1324+
from sqlalchemy.orm import object_session
1325+
1326+
session = object_session(instance)
1327+
if session is not None:
1328+
# Try to find existing record by ID
1329+
existing_record = session.get(resolved_target_type, item["id"])
1330+
if existing_record is not None:
1331+
# Update existing record with new data
1332+
for key, value in item.items():
1333+
if key != "id" and hasattr(existing_record, key):
1334+
setattr(existing_record, key, value)
1335+
return existing_record
1336+
1337+
# Create new table model instance from dictionary
1338+
return resolved_target_type(**item)
1339+
except Exception:
1340+
# If conversion fails, return original item
1341+
return item
1342+
12871343
return item

tests/test_relationships_set.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine
22

33

4-
def test_relationships_set():
4+
def test_relationships_set_pydantic():
55
class Book(SQLModel, table=True):
66
id: int = Field(default=None, primary_key=True)
77
title: str
@@ -48,9 +48,50 @@ class IAuthorCreate(SQLModel):
4848
assert author.books[1].id is not None
4949
assert author.books[2].id is not None
5050

51-
author.books[0].title = "Updated Book One"
51+
52+
def test_relationships_set_dict():
53+
class Book(SQLModel, table=True):
54+
id: int = Field(default=None, primary_key=True)
55+
title: str
56+
author_id: int = Field(foreign_key="author.id")
57+
author: "Author" = Relationship(back_populates="books")
58+
59+
class IBookCreate(SQLModel):
60+
title: str
61+
62+
class Author(SQLModel, table=True):
63+
id: int = Field(default=None, primary_key=True)
64+
name: str
65+
books: list[Book] = Relationship(back_populates="author")
66+
67+
class IAuthorCreate(SQLModel):
68+
name: str
69+
books: list[IBookCreate] = []
70+
71+
book1 = IBookCreate(title="Book One")
72+
book2 = IBookCreate(title="Book Two")
73+
book3 = IBookCreate(title="Book Three")
74+
75+
author_data = IAuthorCreate(name="Author Name", books=[book1, book2, book3])
76+
77+
author = Author.model_validate(author_data.model_dump(exclude={"id"}))
78+
79+
engine = create_engine("sqlite://")
80+
81+
SQLModel.metadata.create_all(engine)
82+
5283
with Session(engine) as session:
5384
session.add(author)
5485
session.commit()
5586
session.refresh(author)
56-
assert author.books[0].title == "Updated Book One"
87+
assert author.id is not None
88+
assert len(author.books) == 3
89+
assert author.books[0].title == "Book One"
90+
assert author.books[1].title == "Book Two"
91+
assert author.books[2].title == "Book Three"
92+
assert author.books[0].author_id == author.id
93+
assert author.books[1].author_id == author.id
94+
assert author.books[2].author_id == author.id
95+
assert author.books[0].id is not None
96+
assert author.books[1].id is not None
97+
assert author.books[2].id is not None

tests/test_relationships_update.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytest
1414

1515

16-
def test_relationships_update():
16+
def test_relationships_update_pydantic():
1717
"""Test conversion of single Pydantic model to SQLModel with forward reference."""
1818

1919
class IBookUpdate(BaseModel):
@@ -56,7 +56,7 @@ class Book(SQLModel, table=True):
5656

5757
# Prepare the update data Pydantic model
5858
author_update_dto = IAuthorUpdate(
59-
id=author_id, # This ID in DTO is informational
59+
id=author_id, # This ID in DTO is informational
6060
name="Updated Author",
6161
books=[IBookUpdate(id=book_id, title="Updated Book")],
6262
)
@@ -73,7 +73,7 @@ class Book(SQLModel, table=True):
7373
book_to_update = session.get(Book, book_update_data.id)
7474

7575
if book_to_update:
76-
if book_update_data.title is not None: # Check if title is provided
76+
if book_update_data.title is not None: # Check if title is provided
7777
book_to_update.title = book_update_data.title
7878
processed_books_list.append(book_to_update)
7979
# else:
@@ -82,9 +82,76 @@ class Book(SQLModel, table=True):
8282
# Assign the list of (potentially updated) persistent Book SQLModel objects
8383
db_author.books = processed_books_list
8484

85-
session.add(db_author) # Add the updated instance to the session (marks it as dirty)
85+
session.add(
86+
db_author
87+
) # Add the updated instance to the session (marks it as dirty)
8688
session.commit()
87-
session.refresh(db_author) # Refresh to get the latest state from DB
89+
session.refresh(db_author) # Refresh to get the latest state from DB
90+
91+
# Assertions on the original IDs and updated content
92+
assert db_author.id == author_id
93+
assert db_author.name == "Updated Author"
94+
assert len(db_author.books) == 1
95+
assert db_author.books[0].id == book_id
96+
assert db_author.books[0].title == "Updated Book"
97+
98+
99+
def test_relationships_update_dict():
100+
"""Test conversion of single Pydantic model to SQLModel with forward reference."""
101+
102+
class IBookUpdate(BaseModel):
103+
id: int
104+
title: str | None = None
105+
106+
class IAuthorUpdate(BaseModel):
107+
id: int
108+
name: str | None = None
109+
books: list[IBookUpdate] | None = None
110+
111+
class Author(SQLModel, table=True):
112+
id: Optional[int] = Field(default=None, primary_key=True)
113+
name: str
114+
books: List["Book"] = Relationship(back_populates="author")
115+
116+
class Book(SQLModel, table=True):
117+
id: Optional[int] = Field(default=None, primary_key=True)
118+
title: str
119+
author_id: Optional[int] = Field(default=None, foreign_key="author.id")
120+
author: Optional["Author"] = Relationship(back_populates="books")
121+
122+
engine = create_engine("sqlite://", echo=False)
123+
SQLModel.metadata.create_all(engine)
124+
125+
with Session(engine) as session:
126+
book = Book(title="Test Book")
127+
author = Author(name="Test Author", books=[book])
128+
session.add(author)
129+
session.commit()
130+
session.refresh(author)
131+
132+
author_id = author.id
133+
book_id = book.id
134+
135+
with Session(engine) as session:
136+
# Fetch the existing author
137+
db_author = session.get(Author, author_id)
138+
assert db_author is not None, "Author to update was not found in the database."
139+
140+
# Prepare the update data Pydantic model
141+
author_update_dto = IAuthorUpdate(
142+
id=author_id, # This ID in DTO is informational
143+
name="Updated Author",
144+
books=[IBookUpdate(id=book_id, title="Updated Book")],
145+
)
146+
147+
update_data = author_update_dto.model_dump()
148+
149+
for field in update_data:
150+
setattr(db_author, field, update_data[field])
151+
152+
session.add(db_author)
153+
session.commit()
154+
session.refresh(db_author)
88155

89156
# Assertions on the original IDs and updated content
90157
assert db_author.id == author_id

0 commit comments

Comments
 (0)