Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2])
IS_PYDANTIC_V2 = PYDANTIC_MINOR_VERSION[0] == 2
IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


if TYPE_CHECKING:
Expand Down Expand Up @@ -65,6 +66,35 @@ def _is_union_type(t: Any) -> bool:
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


def set_polymorphic_default_value(
self_instance: _TSQLModel,
values: Dict[str, Any],
) -> bool:
"""By default, when init a model, pydantic will set the polymorphic_on
value to field default value. But when inherit a model, the polymorphic_on
should be set to polymorphic_identity value by default."""
cls = type(self_instance)
mapper = inspect(cls)
ret = False
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = get_model_fields(cls).get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
ret = True
return ret


@contextmanager
def partial_init() -> Generator[None, None, None]:
token = finish_init.set(False)
Expand Down Expand Up @@ -103,14 +133,7 @@ def set_config_value(
model.model_config[parameter] = value # type: ignore[literal-required]

def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
# TODO: refactor the usage of this function to always pass the class
# not the instance, and then remove this extra check
# this is for compatibility with Pydantic v3
if isinstance(model, type):
use_model = model
else:
use_model = model.__class__
return use_model.model_fields
return model.model_fields

def get_fields_set(
object: InstanceOrType["SQLModel"],
Expand Down Expand Up @@ -298,6 +321,8 @@ def sqlmodel_table_construct(
if value is not Undefined:
setattr(self_instance, key, value)
# End SQLModel override
# Override polymorphic_on default value
set_polymorphic_default_value(self_instance, values)
return self_instance

def sqlmodel_validate(
Expand Down
104 changes: 77 additions & 27 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
import warnings
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand Down Expand Up @@ -41,9 +42,10 @@
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
InstrumentedAttribute,
Mapped,
MappedColumn,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
Expand All @@ -56,7 +58,7 @@

from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
PYDANTIC_MINOR_VERSION,
PYDANTIC_VERSION,
BaseConfig,
ModelField,
ModelMetaclass,
Expand Down Expand Up @@ -93,8 +95,8 @@
IncEx: TypeAlias = Union[
Set[int],
Set[str],
Mapping[int, Union["IncEx", bool]],
Mapping[str, Union["IncEx", bool]],
Mapping[int, Union["IncEx", Literal[True]]],
Mapping[str, Union["IncEx", Literal[True]]],
]
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]

Expand Down Expand Up @@ -134,15 +136,17 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
)
if primary_key is not Undefined:
raise RuntimeError(
"Passing primary_key is not supported when also passing a sa_column"
"Passing primary_key is not supported when "
"also passing a sa_column"
)
if nullable is not Undefined:
raise RuntimeError(
"Passing nullable is not supported when also passing a sa_column"
)
if foreign_key is not Undefined:
raise RuntimeError(
"Passing foreign_key is not supported when also passing a sa_column"
"Passing foreign_key is not supported when "
"also passing a sa_column"
)
if ondelete is not Undefined:
raise RuntimeError(
Expand Down Expand Up @@ -339,7 +343,7 @@ def Field(
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
sa_column: Union[Column[Any], UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...

Expand Down Expand Up @@ -477,7 +481,7 @@ def Relationship(
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
model_config: SQLModelConfig
model_fields: ClassVar[Dict[str, FieldInfo]]
model_fields: Dict[str, FieldInfo]
__config__: Type[SQLModelConfig]
__fields__: Dict[str, ModelField] # type: ignore[assignment]

Expand Down Expand Up @@ -536,7 +540,42 @@ def __new__(
config_kwargs = {
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
is_polymorphic = False
if IS_PYDANTIC_V2:
base_fields = {}
base_annotations = {}
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(get_model_fields(base))
base_annotations.update(base.__annotations__)
if hasattr(base, "__sqlmodel_relationships__"):
for k in base.__sqlmodel_relationships__:
# create a dummy attribute to avoid inherit
# pydantic will treat it as class variables, and will not become fields on model instances
anno = base_annotations.get(k, Any)
if get_origin(anno) is not ClassVar:
dummy_anno = ClassVar[anno]
dict_used["__annotations__"][k] = dummy_anno

if hasattr(base, "__tablename__"):
is_polymorphic = True
# use base_fields overwriting the ones from the class for inherit
# if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
# thus pydantic will use the value of the attribute as the default value
base_annotations.update(dict_used["__annotations__"])
dict_used["__annotations__"] = base_annotations
base_fields.update(dict_used)
dict_used = base_fields
# if is_polymorphic, disable pydantic `shadows an attribute` warning
if is_polymorphic:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Field name .+ shadows an attribute in parent.+",
)
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
else:
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
Expand All @@ -556,9 +595,22 @@ def get_config(name: str) -> Any:

config_table = get_config("table")
if config_table is True:
# sqlalchemy mark a class as table by check if it has __tablename__ attribute
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
original_v = getattr(new_cls, k, None)
if (
isinstance(original_v, InstrumentedAttribute)
and k not in class_dict
):
# The attribute was already set by SQLAlchemy, don't override it
# Needed for polymorphic models, see #36
continue
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand Down Expand Up @@ -592,7 +644,15 @@ def __init__(
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table:
_mapper_args = dict_.get("__mapper_args__", {})
polymorphic_identity = _mapper_args.get("polymorphic_identity")
polymorphic_abstract = _mapper_args.get("polymorphic_abstract")
has_polymorphic = (
polymorphic_identity is not None or polymorphic_abstract is not None
)

# allow polymorphic models inherit from table models
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
Expand Down Expand Up @@ -700,13 +760,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
if IS_PYDANTIC_V2:
field_info = field
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
Expand Down Expand Up @@ -770,7 +830,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
Expand Down Expand Up @@ -834,12 +893,8 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
if not (isinstance(k, str) and k.startswith("_sa_"))
]

@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()

@classmethod
def model_validate( # type: ignore[override]
def model_validate(
cls: Type[_TSQLModel],
obj: Any,
*,
Expand All @@ -863,25 +918,20 @@ def model_dump(
mode: Union[Literal["json", "python"], str] = "python",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
context: Union[Any, None] = None,
by_alias: Union[bool, None] = None,
context: Union[Dict[str, Any], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: Union[bool, Literal["none", "warn", "error"]] = True,
fallback: Union[Callable[[Any], Any], None] = None,
serialize_as_any: bool = False,
) -> Dict[str, Any]:
if PYDANTIC_MINOR_VERSION < (2, 11):
by_alias = by_alias or False
if PYDANTIC_MINOR_VERSION >= (2, 7):
if PYDANTIC_VERSION >= "2.7.0":
extra_kwargs: Dict[str, Any] = {
"context": context,
"serialize_as_any": serialize_as_any,
}
if PYDANTIC_MINOR_VERSION >= (2, 11):
extra_kwargs["fallback"] = fallback
else:
extra_kwargs = {}
if IS_PYDANTIC_V2:
Expand All @@ -901,7 +951,7 @@ def model_dump(
return super().dict(
include=include,
exclude=exclude,
by_alias=by_alias or False,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
Expand Down
Loading
Loading