4141)
4242from sqlalchemy import Enum as sa_Enum
4343from sqlalchemy .orm import (
44+ InstrumentedAttribute ,
4445 Mapped ,
46+ MappedColumn ,
4547 RelationshipProperty ,
46- declared_attr ,
4748 registry ,
4849 relationship ,
4950)
@@ -544,6 +545,15 @@ def __new__(
544545 ** pydantic_annotations ,
545546 ** new_cls .__annotations__ ,
546547 }
548+ # pydantic will set class attribute value inherited from parent as field
549+ # default value, reset it back
550+ base_fields = {}
551+ for base in bases [::- 1 ]:
552+ if issubclass (base , BaseModel ):
553+ base_fields .update (base .model_fields )
554+ for k , v in new_cls .model_fields .items ():
555+ if isinstance (v .default , InstrumentedAttribute ):
556+ new_cls .model_fields [k ] = base_fields .get (k )
547557
548558 def get_config (name : str ) -> Any :
549559 config_class_value = get_config_value (
@@ -558,9 +568,19 @@ def get_config(name: str) -> Any:
558568
559569 config_table = get_config ("table" )
560570 if config_table is True :
571+ if new_cls .__name__ != "SQLModel" and not hasattr (new_cls , "__tablename__" ):
572+ new_cls .__tablename__ = new_cls .__name__ .lower ()
561573 # If it was passed by kwargs, ensure it's also set in config
562574 set_config_value (model = new_cls , parameter = "table" , value = config_table )
563575 for k , v in get_model_fields (new_cls ).items ():
576+ original_v = getattr (new_cls , k , None )
577+ if (
578+ isinstance (original_v , InstrumentedAttribute )
579+ and k not in class_dict
580+ ):
581+ # The attribute was already set by SQLAlchemy, don't override it
582+ # Needed for polymorphic models, see #36
583+ continue
564584 col = get_column_from_field (v )
565585 setattr (new_cls , k , col )
566586 # Set a config flag to tell FastAPI that this should be read with a field
@@ -594,7 +614,13 @@ def __init__(
594614 # trying to create a new SQLAlchemy, for a new table, with the same name, that
595615 # triggers an error
596616 base_is_table = any (is_table_model_class (base ) for base in bases )
597- if is_table_model_class (cls ) and not base_is_table :
617+ polymorphic_identity = dict_ .get ("__mapper_args__" , {}).get (
618+ "polymorphic_identity"
619+ )
620+ has_polymorphic = polymorphic_identity is not None
621+
622+ # allow polymorphic models inherit from table models
623+ if is_table_model_class (cls ) and (not base_is_table or has_polymorphic ):
598624 for rel_name , rel_info in cls .__sqlmodel_relationships__ .items ():
599625 if rel_info .sa_relationship :
600626 # There's a SQLAlchemy relationship declared, that takes precedence
@@ -641,6 +667,16 @@ def __init__(
641667 # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
642668 # Tag: 1.4.36
643669 DeclarativeMeta .__init__ (cls , classname , bases , dict_ , ** kw )
670+ # # patch sqlmodel field's default value to polymorphic_identity
671+ # if has_polymorphic:
672+ # mapper = inspect(cls)
673+ # polymorphic_on = mapper.polymorphic_on
674+ # polymorphic_property = mapper.get_property_by_column(polymorphic_on)
675+ # field = cls.model_fields.get(polymorphic_property.key)
676+ # def get__polymorphic_identity__(kw):
677+ # return polymorphic_identity
678+ # if field:
679+ # field.default_factory = get__polymorphic_identity__
644680 else :
645681 ModelMetaclass .__init__ (cls , classname , bases , dict_ , ** kw )
646682
@@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
708744 else :
709745 field_info = field .field_info
710746 sa_column = getattr (field_info , "sa_column" , Undefined )
711- if isinstance (sa_column , Column ):
747+ if isinstance (sa_column , Column ) or isinstance ( sa_column , MappedColumn ) :
712748 return sa_column
713749 sa_type = get_sqlalchemy_type (field )
714750 primary_key = getattr (field_info , "primary_key" , Undefined )
@@ -772,7 +808,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
772808class SQLModel (BaseModel , metaclass = SQLModelMetaclass , registry = default_registry ):
773809 # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
774810 __slots__ = ("__weakref__" ,)
775- __tablename__ : ClassVar [Union [str , Callable [..., str ]]]
776811 __sqlmodel_relationships__ : ClassVar [Dict [str , RelationshipProperty [Any ]]]
777812 __name__ : ClassVar [str ]
778813 metadata : ClassVar [MetaData ]
@@ -836,10 +871,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
836871 if not (isinstance (k , str ) and k .startswith ("_sa_" ))
837872 ]
838873
839- @declared_attr # type: ignore
840- def __tablename__ (cls ) -> str :
841- return cls .__name__ .lower ()
842-
843874 @classmethod
844875 def model_validate (
845876 cls : Type [_TSQLModel ],
0 commit comments