11import ipaddress
22import uuid
3+ import warnings
34import weakref
45from datetime import date , datetime , time , timedelta
56from decimal import Decimal
4142)
4243from sqlalchemy import Enum as sa_Enum
4344from sqlalchemy .orm import (
45+ InstrumentedAttribute ,
4446 Mapped ,
47+ MappedColumn ,
4548 RelationshipProperty ,
46- declared_attr ,
4749 registry ,
4850 relationship ,
4951)
5658
5759from ._compat import ( # type: ignore[attr-defined]
5860 IS_PYDANTIC_V2 ,
59- PYDANTIC_MINOR_VERSION ,
61+ PYDANTIC_VERSION ,
6062 BaseConfig ,
6163 ModelField ,
6264 ModelMetaclass ,
9395IncEx : TypeAlias = Union [
9496 Set [int ],
9597 Set [str ],
96- Mapping [int , Union ["IncEx" , bool ]],
97- Mapping [str , Union ["IncEx" , bool ]],
98+ Mapping [int , Union ["IncEx" , Literal [ True ] ]],
99+ Mapping [str , Union ["IncEx" , Literal [ True ] ]],
98100]
99101OnDeleteType = Literal ["CASCADE" , "SET NULL" , "RESTRICT" ]
100102
@@ -134,15 +136,17 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
134136 )
135137 if primary_key is not Undefined :
136138 raise RuntimeError (
137- "Passing primary_key is not supported when also passing a sa_column"
139+ "Passing primary_key is not supported when "
140+ "also passing a sa_column"
138141 )
139142 if nullable is not Undefined :
140143 raise RuntimeError (
141144 "Passing nullable is not supported when also passing a sa_column"
142145 )
143146 if foreign_key is not Undefined :
144147 raise RuntimeError (
145- "Passing foreign_key is not supported when also passing a sa_column"
148+ "Passing foreign_key is not supported when "
149+ "also passing a sa_column"
146150 )
147151 if ondelete is not Undefined :
148152 raise RuntimeError (
@@ -339,7 +343,7 @@ def Field(
339343 regex : Optional [str ] = None ,
340344 discriminator : Optional [str ] = None ,
341345 repr : bool = True ,
342- sa_column : Union [Column [ Any ] , UndefinedType ] = Undefined ,
346+ sa_column : Union [Column , UndefinedType ] = Undefined , # type: ignore
343347 schema_extra : Optional [Dict [str , Any ]] = None ,
344348) -> Any : ...
345349
@@ -477,7 +481,7 @@ def Relationship(
477481class SQLModelMetaclass (ModelMetaclass , DeclarativeMeta ):
478482 __sqlmodel_relationships__ : Dict [str , RelationshipInfo ]
479483 model_config : SQLModelConfig
480- model_fields : ClassVar [ Dict [str , FieldInfo ] ]
484+ model_fields : Dict [str , FieldInfo ]
481485 __config__ : Type [SQLModelConfig ]
482486 __fields__ : Dict [str , ModelField ] # type: ignore[assignment]
483487
@@ -536,7 +540,42 @@ def __new__(
536540 config_kwargs = {
537541 key : kwargs [key ] for key in kwargs .keys () & allowed_config_kwargs
538542 }
539- new_cls = super ().__new__ (cls , name , bases , dict_used , ** config_kwargs )
543+ is_polymorphic = False
544+ if IS_PYDANTIC_V2 :
545+ base_fields = {}
546+ base_annotations = {}
547+ for base in bases [::- 1 ]:
548+ if issubclass (base , BaseModel ):
549+ base_fields .update (get_model_fields (base ))
550+ base_annotations .update (base .__annotations__ )
551+ if hasattr (base , "__sqlmodel_relationships__" ):
552+ for k in base .__sqlmodel_relationships__ :
553+ # create a dummy attribute to avoid inherit
554+ # pydantic will treat it as class variables, and will not become fields on model instances
555+ anno = base_annotations .get (k , Any )
556+ if get_origin (anno ) is not ClassVar :
557+ dummy_anno = ClassVar [anno ]
558+ dict_used ["__annotations__" ][k ] = dummy_anno
559+
560+ if hasattr (base , "__tablename__" ):
561+ is_polymorphic = True
562+ # use base_fields overwriting the ones from the class for inherit
563+ # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
564+ # thus pydantic will use the value of the attribute as the default value
565+ base_annotations .update (dict_used ["__annotations__" ])
566+ dict_used ["__annotations__" ] = base_annotations
567+ base_fields .update (dict_used )
568+ dict_used = base_fields
569+ # if is_polymorphic, disable pydantic `shadows an attribute` warning
570+ if is_polymorphic :
571+ with warnings .catch_warnings ():
572+ warnings .filterwarnings (
573+ "ignore" ,
574+ message = "Field name .+ shadows an attribute in parent.+" ,
575+ )
576+ new_cls = super ().__new__ (cls , name , bases , dict_used , ** config_kwargs )
577+ else :
578+ new_cls = super ().__new__ (cls , name , bases , dict_used , ** config_kwargs )
540579 new_cls .__annotations__ = {
541580 ** relationship_annotations ,
542581 ** pydantic_annotations ,
@@ -556,9 +595,22 @@ def get_config(name: str) -> Any:
556595
557596 config_table = get_config ("table" )
558597 if config_table is True :
598+ # sqlalchemy mark a class as table by check if it has __tablename__ attribute
599+ # or if __tablename__ is in __annotations__. Only set __tablename__ if it's
600+ # a table model
601+ if new_cls .__name__ != "SQLModel" and not hasattr (new_cls , "__tablename__" ):
602+ setattr (new_cls , "__tablename__" , new_cls .__name__ .lower ()) # noqa: B010
559603 # If it was passed by kwargs, ensure it's also set in config
560604 set_config_value (model = new_cls , parameter = "table" , value = config_table )
561605 for k , v in get_model_fields (new_cls ).items ():
606+ original_v = getattr (new_cls , k , None )
607+ if (
608+ isinstance (original_v , InstrumentedAttribute )
609+ and k not in class_dict
610+ ):
611+ # The attribute was already set by SQLAlchemy, don't override it
612+ # Needed for polymorphic models, see #36
613+ continue
562614 col = get_column_from_field (v )
563615 setattr (new_cls , k , col )
564616 # Set a config flag to tell FastAPI that this should be read with a field
@@ -592,7 +644,15 @@ def __init__(
592644 # trying to create a new SQLAlchemy, for a new table, with the same name, that
593645 # triggers an error
594646 base_is_table = any (is_table_model_class (base ) for base in bases )
595- if is_table_model_class (cls ) and not base_is_table :
647+ _mapper_args = dict_ .get ("__mapper_args__" , {})
648+ polymorphic_identity = _mapper_args .get ("polymorphic_identity" )
649+ polymorphic_abstract = _mapper_args .get ("polymorphic_abstract" )
650+ has_polymorphic = (
651+ polymorphic_identity is not None or polymorphic_abstract is not None
652+ )
653+
654+ # allow polymorphic models inherit from table models
655+ if is_table_model_class (cls ) and (not base_is_table or has_polymorphic ):
596656 for rel_name , rel_info in cls .__sqlmodel_relationships__ .items ():
597657 if rel_info .sa_relationship :
598658 # There's a SQLAlchemy relationship declared, that takes precedence
@@ -700,13 +760,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
700760 raise ValueError (f"{ type_ } has no matching SQLAlchemy type" )
701761
702762
703- def get_column_from_field (field : Any ) -> Column : # type: ignore
763+ def get_column_from_field (field : Any ) -> Union [ Column , MappedColumn ] : # type: ignore
704764 if IS_PYDANTIC_V2 :
705765 field_info = field
706766 else :
707767 field_info = field .field_info
708768 sa_column = getattr (field_info , "sa_column" , Undefined )
709- if isinstance (sa_column , Column ):
769+ if isinstance (sa_column , Column ) or isinstance ( sa_column , MappedColumn ) :
710770 return sa_column
711771 sa_type = get_sqlalchemy_type (field )
712772 primary_key = getattr (field_info , "primary_key" , Undefined )
@@ -770,7 +830,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
770830class SQLModel (BaseModel , metaclass = SQLModelMetaclass , registry = default_registry ):
771831 # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
772832 __slots__ = ("__weakref__" ,)
773- __tablename__ : ClassVar [Union [str , Callable [..., str ]]]
774833 __sqlmodel_relationships__ : ClassVar [Dict [str , RelationshipProperty [Any ]]]
775834 __name__ : ClassVar [str ]
776835 metadata : ClassVar [MetaData ]
@@ -834,12 +893,8 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
834893 if not (isinstance (k , str ) and k .startswith ("_sa_" ))
835894 ]
836895
837- @declared_attr # type: ignore
838- def __tablename__ (cls ) -> str :
839- return cls .__name__ .lower ()
840-
841896 @classmethod
842- def model_validate ( # type: ignore[override]
897+ def model_validate (
843898 cls : Type [_TSQLModel ],
844899 obj : Any ,
845900 * ,
@@ -863,25 +918,20 @@ def model_dump(
863918 mode : Union [Literal ["json" , "python" ], str ] = "python" ,
864919 include : Union [IncEx , None ] = None ,
865920 exclude : Union [IncEx , None ] = None ,
866- context : Union [Any , None ] = None ,
867- by_alias : Union [ bool , None ] = None ,
921+ context : Union [Dict [ str , Any ] , None ] = None ,
922+ by_alias : bool = False ,
868923 exclude_unset : bool = False ,
869924 exclude_defaults : bool = False ,
870925 exclude_none : bool = False ,
871926 round_trip : bool = False ,
872927 warnings : Union [bool , Literal ["none" , "warn" , "error" ]] = True ,
873- fallback : Union [Callable [[Any ], Any ], None ] = None ,
874928 serialize_as_any : bool = False ,
875929 ) -> Dict [str , Any ]:
876- if PYDANTIC_MINOR_VERSION < (2 , 11 ):
877- by_alias = by_alias or False
878- if PYDANTIC_MINOR_VERSION >= (2 , 7 ):
930+ if PYDANTIC_VERSION >= "2.7.0" :
879931 extra_kwargs : Dict [str , Any ] = {
880932 "context" : context ,
881933 "serialize_as_any" : serialize_as_any ,
882934 }
883- if PYDANTIC_MINOR_VERSION >= (2 , 11 ):
884- extra_kwargs ["fallback" ] = fallback
885935 else :
886936 extra_kwargs = {}
887937 if IS_PYDANTIC_V2 :
@@ -901,7 +951,7 @@ def model_dump(
901951 return super ().dict (
902952 include = include ,
903953 exclude = exclude ,
904- by_alias = by_alias or False ,
954+ by_alias = by_alias ,
905955 exclude_unset = exclude_unset ,
906956 exclude_defaults = exclude_defaults ,
907957 exclude_none = exclude_none ,
0 commit comments