@@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
66
66
finish_init : ContextVar [bool ] = ContextVar ("finish_init" , default = True )
67
67
68
68
69
+ def set_polymorphic_default_value (self_instance , values ):
70
+ """By defalut, when init a model, pydantic will set the polymorphic_on
71
+ value to field default value. But when inherit a model, the polymorphic_on
72
+ should be set to polymorphic_identity value by default."""
73
+ cls = type (self_instance )
74
+ mapper = inspect (cls )
75
+ if isinstance (mapper , Mapper ):
76
+ polymorphic_on = mapper .polymorphic_on
77
+ if polymorphic_on is not None :
78
+ polymorphic_property = mapper .get_property_by_column (polymorphic_on )
79
+ field_info = get_model_fields (cls ).get (polymorphic_property .key )
80
+ if field_info :
81
+ v = values .get (polymorphic_property .key )
82
+ # if model is inherited or polymorphic_on is not explicitly set
83
+ # set the polymorphic_on by default
84
+ if mapper .inherits or v is None :
85
+ setattr (
86
+ self_instance ,
87
+ polymorphic_property .key ,
88
+ mapper .polymorphic_identity ,
89
+ )
90
+
91
+
69
92
@contextmanager
70
93
def partial_init () -> Generator [None , None , None ]:
71
94
token = finish_init .set (False )
@@ -293,22 +316,7 @@ def sqlmodel_table_construct(
293
316
setattr (self_instance , key , value )
294
317
# End SQLModel override
295
318
# Override polymorphic_on default value
296
- mapper = inspect (cls )
297
- if isinstance (mapper , Mapper ):
298
- polymorphic_on = mapper .polymorphic_on
299
- if polymorphic_on is not None :
300
- polymorphic_property = mapper .get_property_by_column (polymorphic_on )
301
- field_info = cls .model_fields .get (polymorphic_property .key )
302
- if field_info :
303
- v = values .get (polymorphic_property .key )
304
- # if model is inherited or polymorphic_on is not explicitly set
305
- # set the polymorphic_on by default
306
- if mapper .inherits or v is None :
307
- setattr (
308
- self_instance ,
309
- polymorphic_property .key ,
310
- mapper .polymorphic_identity ,
311
- )
319
+ set_polymorphic_default_value (self_instance , values )
312
320
return self_instance
313
321
314
322
def sqlmodel_validate (
@@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
592
600
for key in non_pydantic_keys :
593
601
if key in self .__sqlmodel_relationships__ :
594
602
setattr (self , key , data [key ])
603
+ # Override polymorphic_on default value
604
+ set_polymorphic_default_value (self , values )
0 commit comments