diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..3853d528d9 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,12 +21,13 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo +from sqlalchemy import inspect +from sqlalchemy.orm import Mapper from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION -PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2]) -IS_PYDANTIC_V2 = PYDANTIC_MINOR_VERSION[0] == 2 +IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if TYPE_CHECKING: @@ -65,6 +66,35 @@ def _is_union_type(t: Any) -> bool: finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) +def set_polymorphic_default_value( + self_instance: _TSQLModel, + values: Dict[str, Any], +) -> bool: + """By default, when init a model, pydantic will set the polymorphic_on + value to field default value. But when inherit a model, the polymorphic_on + should be set to polymorphic_identity value by default.""" + cls = type(self_instance) + mapper = inspect(cls) + ret = False + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = get_model_fields(cls).get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) + ret = True + return ret + + @contextmanager def partial_init() -> Generator[None, None, None]: token = finish_init.set(False) @@ -103,14 +133,7 @@ def set_config_value( model.model_config[parameter] = value # type: ignore[literal-required] def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]: - # TODO: refactor the usage of this function to always pass the class - # not the instance, and then remove this extra check - # this is for compatibility with Pydantic v3 - if isinstance(model, type): - use_model = model - else: - use_model = model.__class__ - return use_model.model_fields + return model.model_fields def get_fields_set( object: InstanceOrType["SQLModel"], @@ -298,6 +321,8 @@ def sqlmodel_table_construct( if value is not Undefined: setattr(self_instance, key, value) # End SQLModel override + # Override polymorphic_on default value + set_polymorphic_default_value(self_instance, values) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..816ec41571 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,5 +1,6 @@ import ipaddress import uuid +import warnings import weakref from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -41,9 +42,10 @@ ) from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( + InstrumentedAttribute, Mapped, + MappedColumn, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -56,7 +58,7 @@ from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, - PYDANTIC_MINOR_VERSION, + PYDANTIC_VERSION, BaseConfig, ModelField, ModelMetaclass, @@ -93,8 +95,8 @@ IncEx: TypeAlias = Union[ Set[int], Set[str], - Mapping[int, Union["IncEx", bool]], - Mapping[str, Union["IncEx", bool]], + Mapping[int, Union["IncEx", Literal[True]]], + Mapping[str, Union["IncEx", Literal[True]]], ] OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] @@ -134,7 +136,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: ) if primary_key is not Undefined: raise RuntimeError( - "Passing primary_key is not supported when also passing a sa_column" + "Passing primary_key is not supported when " + "also passing a sa_column" ) if nullable is not Undefined: raise RuntimeError( @@ -142,7 +145,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: ) if foreign_key is not Undefined: raise RuntimeError( - "Passing foreign_key is not supported when also passing a sa_column" + "Passing foreign_key is not supported when " + "also passing a sa_column" ) if ondelete is not Undefined: raise RuntimeError( @@ -339,7 +343,7 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - sa_column: Union[Column[Any], UndefinedType] = Undefined, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -477,7 +481,7 @@ def Relationship( class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] model_config: SQLModelConfig - model_fields: ClassVar[Dict[str, FieldInfo]] + model_fields: Dict[str, FieldInfo] __config__: Type[SQLModelConfig] __fields__: Dict[str, ModelField] # type: ignore[assignment] @@ -536,7 +540,42 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + is_polymorphic = False + if IS_PYDANTIC_V2: + base_fields = {} + base_annotations = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(get_model_fields(base)) + base_annotations.update(base.__annotations__) + if hasattr(base, "__sqlmodel_relationships__"): + for k in base.__sqlmodel_relationships__: + # create a dummy attribute to avoid inherit + # pydantic will treat it as class variables, and will not become fields on model instances + anno = base_annotations.get(k, Any) + if get_origin(anno) is not ClassVar: + dummy_anno = ClassVar[anno] + dict_used["__annotations__"][k] = dummy_anno + + if hasattr(base, "__tablename__"): + is_polymorphic = True + # use base_fields overwriting the ones from the class for inherit + # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute + # thus pydantic will use the value of the attribute as the default value + base_annotations.update(dict_used["__annotations__"]) + dict_used["__annotations__"] = base_annotations + base_fields.update(dict_used) + dict_used = base_fields + # if is_polymorphic, disable pydantic `shadows an attribute` warning + if is_polymorphic: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Field name .+ shadows an attribute in parent.+", + ) + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + else: + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -556,9 +595,22 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: + # sqlalchemy mark a class as table by check if it has __tablename__ attribute + # or if __tablename__ is in __annotations__. Only set __tablename__ if it's + # a table model + if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): + setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010 # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + original_v = getattr(new_cls, k, None) + if ( + isinstance(original_v, InstrumentedAttribute) + and k not in class_dict + ): + # The attribute was already set by SQLAlchemy, don't override it + # Needed for polymorphic models, see #36 + continue col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -592,7 +644,15 @@ def __init__( # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error base_is_table = any(is_table_model_class(base) for base in bases) - if is_table_model_class(cls) and not base_is_table: + _mapper_args = dict_.get("__mapper_args__", {}) + polymorphic_identity = _mapper_args.get("polymorphic_identity") + polymorphic_abstract = _mapper_args.get("polymorphic_abstract") + has_polymorphic = ( + polymorphic_identity is not None or polymorphic_abstract is not None + ) + + # allow polymorphic models inherit from table models + if is_table_model_class(cls) and (not base_is_table or has_polymorphic): for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -700,13 +760,13 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) @@ -770,7 +830,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) - __tablename__: ClassVar[Union[str, Callable[..., str]]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] @@ -834,12 +893,8 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod - def model_validate( # type: ignore[override] + def model_validate( cls: Type[_TSQLModel], obj: Any, *, @@ -863,25 +918,20 @@ def model_dump( mode: Union[Literal["json", "python"], str] = "python", include: Union[IncEx, None] = None, exclude: Union[IncEx, None] = None, - context: Union[Any, None] = None, - by_alias: Union[bool, None] = None, + context: Union[Dict[str, Any], None] = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: Union[bool, Literal["none", "warn", "error"]] = True, - fallback: Union[Callable[[Any], Any], None] = None, serialize_as_any: bool = False, ) -> Dict[str, Any]: - if PYDANTIC_MINOR_VERSION < (2, 11): - by_alias = by_alias or False - if PYDANTIC_MINOR_VERSION >= (2, 7): + if PYDANTIC_VERSION >= "2.7.0": extra_kwargs: Dict[str, Any] = { "context": context, "serialize_as_any": serialize_as_any, } - if PYDANTIC_MINOR_VERSION >= (2, 11): - extra_kwargs["fallback"] = fallback else: extra_kwargs = {} if IS_PYDANTIC_V2: @@ -901,7 +951,7 @@ def model_dump( return super().dict( include=include, exclude=exclude, - by_alias=by_alias or False, + by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py new file mode 100644 index 0000000000..8d9a88661c --- /dev/null +++ b/tests/test_polymorphic_model.py @@ -0,0 +1,284 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import mapped_column +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select + +from tests.conftest import needs_pydanticv2 + + +@needs_pydanticv2 +def test_polymorphic_joined_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "normal_hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True), + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +@needs_pydanticv2 +def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "normal_hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + primary_key=True, + foreign_key="hero.id", + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +@needs_pydanticv2 +def test_polymorphic_single_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "normal_hero", + } + + class DarkHero(Hero): + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero(dark_power="pokey") + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +@needs_pydanticv2 +def test_polymorphic_relationship(clear_sqlmodel) -> None: + class Tool(SQLModel, table=True): + __tablename__ = "tool_table" + + id: int = Field(primary_key=True) + + name: str + + class Person(SQLModel, table=True): + __tablename__ = "person_table" + + id: int = Field(primary_key=True) + + discriminator: str + name: str + + tool_id: int = Field(foreign_key="tool_table.id") + tool: Tool = Relationship() + + __mapper_args__ = { + "polymorphic_on": "discriminator", + "polymorphic_identity": "simple_person", + } + + class Worker(Person): + __mapper_args__ = { + "polymorphic_identity": "worker", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + tool = Tool(id=1, name="Hammer") + db.add(tool) + worker = Worker(id=2, name="Bob", tool_id=1) + db.add(worker) + db.commit() + + statement = select(Worker).where(Worker.tool_id == 1) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].tool, Tool) + + +@needs_pydanticv2 +def test_polymorphic_deeper(clear_sqlmodel) -> None: + class Employee(SQLModel, table=True): + __tablename__ = "employee" + + id: Optional[int] = Field(default=None, primary_key=True) + name: str + type: str = Field(default="employee") + + __mapper_args__ = { + "polymorphic_identity": "employee", + "polymorphic_on": "type", + } + + class Executive(Employee): + """An executive of the company""" + + executive_background: Optional[str] = Field( + sa_column=mapped_column(nullable=True), default=None + ) + + __mapper_args__ = {"polymorphic_abstract": True} + + class Technologist(Employee): + """An employee who works with technology""" + + competencies: Optional[str] = Field( + sa_column=mapped_column(nullable=True), default=None + ) + + __mapper_args__ = {"polymorphic_abstract": True} + + class Manager(Executive): + """A manager""" + + __mapper_args__ = {"polymorphic_identity": "manager"} + + class Principal(Executive): + """A principal of the company""" + + __mapper_args__ = {"polymorphic_identity": "principal"} + + class Engineer(Technologist): + """An engineer""" + + __mapper_args__ = {"polymorphic_identity": "engineer"} + + class SysAdmin(Technologist): + """A systems administrator""" + + __mapper_args__ = {"polymorphic_identity": "sysadmin"} + + # Create database and session + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + + with Session(engine) as db: + # Add different employee types + manager = Manager(name="Alice", executive_background="MBA") + principal = Principal(name="Bob", executive_background="Founder") + engineer = Engineer(name="Charlie", competencies="Python, SQL") + sysadmin = SysAdmin(name="Diana", competencies="Linux, Networking") + + db.add(manager) + db.add(principal) + db.add(engineer) + db.add(sysadmin) + db.commit() + + # Query each type to verify they persist correctly + managers = db.exec(select(Manager)).all() + principals = db.exec(select(Principal)).all() + engineers = db.exec(select(Engineer)).all() + sysadmins = db.exec(select(SysAdmin)).all() + + # Query abstract classes to verify they return appropriate concrete classes + executives = db.exec(select(Executive)).all() + technologists = db.exec(select(Technologist)).all() + + # All employees + all_employees = db.exec(select(Employee)).all() + + # Assert individual type counts + assert len(managers) == 1 + assert len(principals) == 1 + assert len(engineers) == 1 + assert len(sysadmins) == 1 + + # Check that abstract classes can't be instantiated directly + # but their subclasses are correctly returned when querying + assert len(executives) == 2 + assert len(technologists) == 2 + assert len(all_employees) == 4 + + # Check that properties of abstract classes are accessible from concrete instances + assert managers[0].executive_background == "MBA" + assert principals[0].executive_background == "Founder" + assert engineers[0].competencies == "Python, SQL" + assert sysadmins[0].competencies == "Linux, Networking" + + # Check polymorphic identities + assert managers[0].type == "manager" + assert principals[0].type == "principal" + assert engineers[0].type == "engineer" + assert sysadmins[0].type == "sysadmin"