|
55 | 55 | registry, |
56 | 56 | relationship, |
57 | 57 | ) |
58 | | -from sqlalchemy.orm.properties import MappedSQLExpression |
59 | 58 | from sqlalchemy.orm.attributes import set_attribute |
60 | 59 | from sqlalchemy.orm.decl_api import DeclarativeMeta |
61 | 60 | from sqlalchemy.orm.instrumentation import is_instrumented |
| 61 | +from sqlalchemy.orm.properties import MappedSQLExpression |
62 | 62 | from sqlalchemy.sql.schema import MetaData |
63 | 63 | from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid |
64 | 64 | from typing_extensions import Literal, TypeAlias, deprecated, get_origin |
@@ -918,6 +918,14 @@ def __setattr__(self, name: str, value: Any) -> None: |
918 | 918 | self.__dict__[name] = value |
919 | 919 | return |
920 | 920 | else: |
| 921 | + # Convert Pydantic objects to table models for relationships |
| 922 | + if ( |
| 923 | + is_table_model_class(self.__class__) |
| 924 | + and name in self.__sqlmodel_relationships__ |
| 925 | + and value is not None |
| 926 | + ): |
| 927 | + value = _convert_pydantic_to_table_model(value, name, self.__class__) |
| 928 | + |
921 | 929 | # Set in SQLAlchemy, before Pydantic to trigger events and updates |
922 | 930 | if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call] |
923 | 931 | set_attribute(self, name, value) |
@@ -1116,3 +1124,163 @@ def sqlmodel_update( |
1116 | 1124 | f"is not a dict or SQLModel or Pydantic model: {obj}" |
1117 | 1125 | ) |
1118 | 1126 | return self |
| 1127 | + |
| 1128 | + |
| 1129 | +def _convert_pydantic_to_table_model( |
| 1130 | + value: Any, relationship_name: str, owner_class: Type["SQLModel"] |
| 1131 | +) -> Any: |
| 1132 | + """ |
| 1133 | + Convert Pydantic objects to table models for relationship assignments. |
| 1134 | +
|
| 1135 | + Args: |
| 1136 | + value: The value being assigned to the relationship |
| 1137 | + relationship_name: Name of the relationship attribute |
| 1138 | + owner_class: The class that owns the relationship |
| 1139 | +
|
| 1140 | + Returns: |
| 1141 | + Converted value(s) - table model instances instead of Pydantic objects |
| 1142 | + """ |
| 1143 | + from typing import get_args, get_origin |
| 1144 | + |
| 1145 | + # Get the relationship annotation to determine target type |
| 1146 | + if relationship_name not in owner_class.__annotations__: |
| 1147 | + return value |
| 1148 | + |
| 1149 | + raw_ann = owner_class.__annotations__[relationship_name] |
| 1150 | + origin = get_origin(raw_ann) |
| 1151 | + |
| 1152 | + # Handle Mapped[...] annotations |
| 1153 | + if origin is Mapped: |
| 1154 | + ann = raw_ann.__args__[0] |
| 1155 | + else: |
| 1156 | + ann = raw_ann |
| 1157 | + |
| 1158 | + # Get the target relationship type |
| 1159 | + try: |
| 1160 | + rel_info = owner_class.__sqlmodel_relationships__[relationship_name] |
| 1161 | + relationship_to = get_relationship_to( |
| 1162 | + name=relationship_name, rel_info=rel_info, annotation=ann |
| 1163 | + ) |
| 1164 | + except (KeyError, AttributeError): |
| 1165 | + return value |
| 1166 | + |
| 1167 | + # Handle list/sequence relationships |
| 1168 | + list_origin = get_origin(ann) |
| 1169 | + if list_origin is list: |
| 1170 | + target_type = get_args(ann)[0] |
| 1171 | + if isinstance(target_type, str): |
| 1172 | + # Forward reference - try to resolve from SQLAlchemy's registry |
| 1173 | + try: |
| 1174 | + resolved_type = default_registry._class_registry.get(target_type) |
| 1175 | + if resolved_type is not None: |
| 1176 | + target_type = resolved_type |
| 1177 | + else: |
| 1178 | + target_type = relationship_to |
| 1179 | + except Exception: |
| 1180 | + target_type = relationship_to |
| 1181 | + else: |
| 1182 | + target_type = relationship_to |
| 1183 | + |
| 1184 | + if isinstance(value, (list, tuple)): |
| 1185 | + converted_items = [] |
| 1186 | + for item in value: |
| 1187 | + converted_item = _convert_single_pydantic_to_table_model( |
| 1188 | + item, target_type |
| 1189 | + ) |
| 1190 | + converted_items.append(converted_item) |
| 1191 | + return converted_items |
| 1192 | + else: |
| 1193 | + # Single relationship |
| 1194 | + target_type = relationship_to |
| 1195 | + if isinstance(target_type, str): |
| 1196 | + # Forward reference - try to resolve from SQLAlchemy's registry |
| 1197 | + try: |
| 1198 | + resolved_type = default_registry._class_registry.get(target_type) |
| 1199 | + if resolved_type is not None: |
| 1200 | + target_type = resolved_type |
| 1201 | + except: |
| 1202 | + pass |
| 1203 | + |
| 1204 | + return _convert_single_pydantic_to_table_model(value, target_type) |
| 1205 | + |
| 1206 | + return value |
| 1207 | + |
| 1208 | + |
| 1209 | +def _convert_single_pydantic_to_table_model(item: Any, target_type: Any) -> Any: |
| 1210 | + """ |
| 1211 | + Convert a single Pydantic object to a table model. |
| 1212 | +
|
| 1213 | + Args: |
| 1214 | + item: The Pydantic object to convert |
| 1215 | + target_type: The target table model type |
| 1216 | +
|
| 1217 | + Returns: |
| 1218 | + Converted table model instance or original item if no conversion needed |
| 1219 | + """ |
| 1220 | + # If item is None, return as-is |
| 1221 | + if item is None: |
| 1222 | + return item |
| 1223 | + |
| 1224 | + # If target_type is a string (forward reference), try to resolve it |
| 1225 | + if isinstance(target_type, str): |
| 1226 | + try: |
| 1227 | + resolved_type = default_registry._class_registry.get(target_type) |
| 1228 | + if resolved_type is not None: |
| 1229 | + target_type = resolved_type |
| 1230 | + except Exception: |
| 1231 | + pass |
| 1232 | + |
| 1233 | + # If target_type is still a string after resolution attempt, |
| 1234 | + # we can't perform type checks or conversions |
| 1235 | + if isinstance(target_type, str): |
| 1236 | + # If item is a BaseModel but not a table model, try conversion |
| 1237 | + if ( |
| 1238 | + isinstance(item, BaseModel) |
| 1239 | + and hasattr(item, "__class__") |
| 1240 | + and not is_table_model_class(item.__class__) |
| 1241 | + ): |
| 1242 | + # Can't convert without knowing the actual target type |
| 1243 | + return item |
| 1244 | + else: |
| 1245 | + return item |
| 1246 | + |
| 1247 | + # If item is already the correct type, return as-is |
| 1248 | + if isinstance(item, target_type): |
| 1249 | + return item |
| 1250 | + |
| 1251 | + # Check if target_type is a SQLModel table class |
| 1252 | + if not ( |
| 1253 | + hasattr(target_type, "__mro__") |
| 1254 | + and any( |
| 1255 | + hasattr(cls, "__sqlmodel_relationships__") for cls in target_type.__mro__ |
| 1256 | + ) |
| 1257 | + ): |
| 1258 | + return item |
| 1259 | + |
| 1260 | + # Check if target is a table model |
| 1261 | + if not is_table_model_class(target_type): |
| 1262 | + return item |
| 1263 | + |
| 1264 | + # Check if item is a BaseModel (Pydantic model) but not a table model |
| 1265 | + if ( |
| 1266 | + isinstance(item, BaseModel) |
| 1267 | + and hasattr(item, "__class__") |
| 1268 | + and not is_table_model_class(item.__class__) |
| 1269 | + ): |
| 1270 | + # Convert Pydantic model to table model |
| 1271 | + try: |
| 1272 | + # Get the data from the Pydantic model |
| 1273 | + if hasattr(item, "model_dump"): |
| 1274 | + # Pydantic v2 |
| 1275 | + data = item.model_dump() |
| 1276 | + else: |
| 1277 | + # Pydantic v1 |
| 1278 | + data = item.dict() |
| 1279 | + |
| 1280 | + # Create new table model instance |
| 1281 | + return target_type(**data) |
| 1282 | + except Exception: |
| 1283 | + # If conversion fails, return original item |
| 1284 | + return item |
| 1285 | + |
| 1286 | + return item |
0 commit comments