Skip to content

Commit 210058e

Browse files
committed
add polymorphic relationships to sql model
1 parent 5289532 commit 210058e

File tree

3 files changed

+396
-37
lines changed

3 files changed

+396
-37
lines changed

sqlmodel/_compat.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
from pydantic import VERSION as P_VERSION
2222
from pydantic import BaseModel
2323
from pydantic.fields import FieldInfo
24+
from sqlalchemy import inspect
25+
from sqlalchemy.orm import Mapper
2426
from typing_extensions import Annotated, get_args, get_origin
2527

2628
# Reassign variable to make it reexported for mypy
2729
PYDANTIC_VERSION = P_VERSION
28-
PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2])
29-
IS_PYDANTIC_V2 = PYDANTIC_MINOR_VERSION[0] == 2
30+
IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
3031

3132

3233
if TYPE_CHECKING:
@@ -65,6 +66,35 @@ def _is_union_type(t: Any) -> bool:
6566
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
6667

6768

69+
def set_polymorphic_default_value(
70+
self_instance: _TSQLModel,
71+
values: Dict[str, Any],
72+
) -> bool:
73+
"""By default, when init a model, pydantic will set the polymorphic_on
74+
value to field default value. But when inherit a model, the polymorphic_on
75+
should be set to polymorphic_identity value by default."""
76+
cls = type(self_instance)
77+
mapper = inspect(cls)
78+
ret = False
79+
if isinstance(mapper, Mapper):
80+
polymorphic_on = mapper.polymorphic_on
81+
if polymorphic_on is not None:
82+
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
83+
field_info = get_model_fields(cls).get(polymorphic_property.key)
84+
if field_info:
85+
v = values.get(polymorphic_property.key)
86+
# if model is inherited or polymorphic_on is not explicitly set
87+
# set the polymorphic_on by default
88+
if mapper.inherits or v is None:
89+
setattr(
90+
self_instance,
91+
polymorphic_property.key,
92+
mapper.polymorphic_identity,
93+
)
94+
ret = True
95+
return ret
96+
97+
6898
@contextmanager
6999
def partial_init() -> Generator[None, None, None]:
70100
token = finish_init.set(False)
@@ -103,14 +133,7 @@ def set_config_value(
103133
model.model_config[parameter] = value # type: ignore[literal-required]
104134

105135
def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
106-
# TODO: refactor the usage of this function to always pass the class
107-
# not the instance, and then remove this extra check
108-
# this is for compatibility with Pydantic v3
109-
if isinstance(model, type):
110-
use_model = model
111-
else:
112-
use_model = model.__class__
113-
return use_model.model_fields
136+
return model.model_fields
114137

115138
def get_fields_set(
116139
object: InstanceOrType["SQLModel"],
@@ -298,6 +321,8 @@ def sqlmodel_table_construct(
298321
if value is not Undefined:
299322
setattr(self_instance, key, value)
300323
# End SQLModel override
324+
# Override polymorphic_on default value
325+
set_polymorphic_default_value(self_instance, values)
301326
return self_instance
302327

303328
def sqlmodel_validate(

sqlmodel/main.py

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ipaddress
22
import uuid
3+
import warnings
34
import weakref
45
from datetime import date, datetime, time, timedelta
56
from decimal import Decimal
@@ -41,9 +42,10 @@
4142
)
4243
from sqlalchemy import Enum as sa_Enum
4344
from sqlalchemy.orm import (
45+
InstrumentedAttribute,
4446
Mapped,
47+
MappedColumn,
4548
RelationshipProperty,
46-
declared_attr,
4749
registry,
4850
relationship,
4951
)
@@ -56,7 +58,7 @@
5658

5759
from ._compat import ( # type: ignore[attr-defined]
5860
IS_PYDANTIC_V2,
59-
PYDANTIC_MINOR_VERSION,
61+
PYDANTIC_VERSION,
6062
BaseConfig,
6163
ModelField,
6264
ModelMetaclass,
@@ -93,8 +95,8 @@
9395
IncEx: TypeAlias = Union[
9496
Set[int],
9597
Set[str],
96-
Mapping[int, Union["IncEx", bool]],
97-
Mapping[str, Union["IncEx", bool]],
98+
Mapping[int, Union["IncEx", Literal[True]]],
99+
Mapping[str, Union["IncEx", Literal[True]]],
98100
]
99101
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]
100102

@@ -134,15 +136,17 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
134136
)
135137
if primary_key is not Undefined:
136138
raise RuntimeError(
137-
"Passing primary_key is not supported when also passing a sa_column"
139+
"Passing primary_key is not supported when "
140+
"also passing a sa_column"
138141
)
139142
if nullable is not Undefined:
140143
raise RuntimeError(
141144
"Passing nullable is not supported when also passing a sa_column"
142145
)
143146
if foreign_key is not Undefined:
144147
raise RuntimeError(
145-
"Passing foreign_key is not supported when also passing a sa_column"
148+
"Passing foreign_key is not supported when "
149+
"also passing a sa_column"
146150
)
147151
if ondelete is not Undefined:
148152
raise RuntimeError(
@@ -339,7 +343,7 @@ def Field(
339343
regex: Optional[str] = None,
340344
discriminator: Optional[str] = None,
341345
repr: bool = True,
342-
sa_column: Union[Column[Any], UndefinedType] = Undefined,
346+
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
343347
schema_extra: Optional[Dict[str, Any]] = None,
344348
) -> Any: ...
345349

@@ -477,7 +481,7 @@ def Relationship(
477481
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
478482
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
479483
model_config: SQLModelConfig
480-
model_fields: ClassVar[Dict[str, FieldInfo]]
484+
model_fields: Dict[str, FieldInfo]
481485
__config__: Type[SQLModelConfig]
482486
__fields__: Dict[str, ModelField] # type: ignore[assignment]
483487

@@ -536,7 +540,42 @@ def __new__(
536540
config_kwargs = {
537541
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
538542
}
539-
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
543+
is_polymorphic = False
544+
if IS_PYDANTIC_V2:
545+
base_fields = {}
546+
base_annotations = {}
547+
for base in bases[::-1]:
548+
if issubclass(base, BaseModel):
549+
base_fields.update(get_model_fields(base))
550+
base_annotations.update(base.__annotations__)
551+
if hasattr(base, "__sqlmodel_relationships__"):
552+
for k in base.__sqlmodel_relationships__:
553+
# create a dummy attribute to avoid inherit
554+
# pydantic will treat it as class variables, and will not become fields on model instances
555+
anno = base_annotations.get(k, Any)
556+
if get_origin(anno) is not ClassVar:
557+
dummy_anno = ClassVar[anno]
558+
dict_used["__annotations__"][k] = dummy_anno
559+
560+
if hasattr(base, "__tablename__"):
561+
is_polymorphic = True
562+
# use base_fields overwriting the ones from the class for inherit
563+
# if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
564+
# thus pydantic will use the value of the attribute as the default value
565+
base_annotations.update(dict_used["__annotations__"])
566+
dict_used["__annotations__"] = base_annotations
567+
base_fields.update(dict_used)
568+
dict_used = base_fields
569+
# if is_polymorphic, disable pydantic `shadows an attribute` warning
570+
if is_polymorphic:
571+
with warnings.catch_warnings():
572+
warnings.filterwarnings(
573+
"ignore",
574+
message="Field name .+ shadows an attribute in parent.+",
575+
)
576+
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
577+
else:
578+
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
540579
new_cls.__annotations__ = {
541580
**relationship_annotations,
542581
**pydantic_annotations,
@@ -556,9 +595,22 @@ def get_config(name: str) -> Any:
556595

557596
config_table = get_config("table")
558597
if config_table is True:
598+
# sqlalchemy mark a class as table by check if it has __tablename__ attribute
599+
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
600+
# a table model
601+
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
602+
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
559603
# If it was passed by kwargs, ensure it's also set in config
560604
set_config_value(model=new_cls, parameter="table", value=config_table)
561605
for k, v in get_model_fields(new_cls).items():
606+
original_v = getattr(new_cls, k, None)
607+
if (
608+
isinstance(original_v, InstrumentedAttribute)
609+
and k not in class_dict
610+
):
611+
# The attribute was already set by SQLAlchemy, don't override it
612+
# Needed for polymorphic models, see #36
613+
continue
562614
col = get_column_from_field(v)
563615
setattr(new_cls, k, col)
564616
# Set a config flag to tell FastAPI that this should be read with a field
@@ -592,7 +644,15 @@ def __init__(
592644
# trying to create a new SQLAlchemy, for a new table, with the same name, that
593645
# triggers an error
594646
base_is_table = any(is_table_model_class(base) for base in bases)
595-
if is_table_model_class(cls) and not base_is_table:
647+
_mapper_args = dict_.get("__mapper_args__", {})
648+
polymorphic_identity = _mapper_args.get("polymorphic_identity")
649+
polymorphic_abstract = _mapper_args.get("polymorphic_abstract")
650+
has_polymorphic = (
651+
polymorphic_identity is not None or polymorphic_abstract is not None
652+
)
653+
654+
# allow polymorphic models inherit from table models
655+
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
596656
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
597657
if rel_info.sa_relationship:
598658
# There's a SQLAlchemy relationship declared, that takes precedence
@@ -700,13 +760,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
700760
raise ValueError(f"{type_} has no matching SQLAlchemy type")
701761

702762

703-
def get_column_from_field(field: Any) -> Column: # type: ignore
763+
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
704764
if IS_PYDANTIC_V2:
705765
field_info = field
706766
else:
707767
field_info = field.field_info
708768
sa_column = getattr(field_info, "sa_column", Undefined)
709-
if isinstance(sa_column, Column):
769+
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
710770
return sa_column
711771
sa_type = get_sqlalchemy_type(field)
712772
primary_key = getattr(field_info, "primary_key", Undefined)
@@ -770,7 +830,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
770830
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
771831
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
772832
__slots__ = ("__weakref__",)
773-
__tablename__: ClassVar[Union[str, Callable[..., str]]]
774833
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
775834
__name__: ClassVar[str]
776835
metadata: ClassVar[MetaData]
@@ -834,12 +893,8 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
834893
if not (isinstance(k, str) and k.startswith("_sa_"))
835894
]
836895

837-
@declared_attr # type: ignore
838-
def __tablename__(cls) -> str:
839-
return cls.__name__.lower()
840-
841896
@classmethod
842-
def model_validate( # type: ignore[override]
897+
def model_validate(
843898
cls: Type[_TSQLModel],
844899
obj: Any,
845900
*,
@@ -863,25 +918,20 @@ def model_dump(
863918
mode: Union[Literal["json", "python"], str] = "python",
864919
include: Union[IncEx, None] = None,
865920
exclude: Union[IncEx, None] = None,
866-
context: Union[Any, None] = None,
867-
by_alias: Union[bool, None] = None,
921+
context: Union[Dict[str, Any], None] = None,
922+
by_alias: bool = False,
868923
exclude_unset: bool = False,
869924
exclude_defaults: bool = False,
870925
exclude_none: bool = False,
871926
round_trip: bool = False,
872927
warnings: Union[bool, Literal["none", "warn", "error"]] = True,
873-
fallback: Union[Callable[[Any], Any], None] = None,
874928
serialize_as_any: bool = False,
875929
) -> Dict[str, Any]:
876-
if PYDANTIC_MINOR_VERSION < (2, 11):
877-
by_alias = by_alias or False
878-
if PYDANTIC_MINOR_VERSION >= (2, 7):
930+
if PYDANTIC_VERSION >= "2.7.0":
879931
extra_kwargs: Dict[str, Any] = {
880932
"context": context,
881933
"serialize_as_any": serialize_as_any,
882934
}
883-
if PYDANTIC_MINOR_VERSION >= (2, 11):
884-
extra_kwargs["fallback"] = fallback
885935
else:
886936
extra_kwargs = {}
887937
if IS_PYDANTIC_V2:
@@ -901,7 +951,7 @@ def model_dump(
901951
return super().dict(
902952
include=include,
903953
exclude=exclude,
904-
by_alias=by_alias or False,
954+
by_alias=by_alias,
905955
exclude_unset=exclude_unset,
906956
exclude_defaults=exclude_defaults,
907957
exclude_none=exclude_none,

0 commit comments

Comments
 (0)