diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..87a4f04a06 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -3,6 +3,7 @@ import ipaddress import uuid import weakref +from copy import copy from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import Enum @@ -166,6 +167,12 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: if ondelete is not Undefined: if foreign_key is Undefined: raise RuntimeError("ondelete can only be used with foreign_key") + if not isinstance(foreign_key, str): + raise RuntimeError( + "Passing ondelete to Field is not supported when foreign_key is " + "specified as sa_column_args=[ForeignKey(...)]. Pass ondelete as " + "a parameter to ForeignKey instead" + ) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.nullable = nullable @@ -733,14 +740,18 @@ def get_column_from_field(field: Any) -> Column: # type: ignore if unique is Undefined: unique = False if foreign_key: - if field_info.ondelete == "SET NULL" and not nullable: - raise RuntimeError('ondelete="SET NULL" requires nullable=True') - assert isinstance(foreign_key, str) - ondelete = getattr(field_info, "ondelete", Undefined) - if ondelete is Undefined: - ondelete = None - assert isinstance(ondelete, (str, type(None))) # for typing - args.append(ForeignKey(foreign_key, ondelete=ondelete)) + if isinstance(foreign_key, str): + if field_info.ondelete == "SET NULL" and not nullable: + raise RuntimeError('ondelete="SET NULL" requires nullable=True') + assert isinstance(foreign_key, str) + ondelete = getattr(field_info, "ondelete", Undefined) + if ondelete is Undefined: + ondelete = None + assert isinstance(ondelete, (str, type(None))) # for typing + args.append(ForeignKey(foreign_key, ondelete=ondelete)) + else: + assert isinstance(foreign_key, ForeignKey) + args.append(copy(foreign_key)) kwargs = { "primary_key": primary_key, "nullable": nullable, @@ -756,7 +767,11 @@ def get_column_from_field(field: Any) -> Column: # type: ignore kwargs["default"] = sa_default sa_column_args = getattr(field_info, "sa_column_args", Undefined) if sa_column_args is not Undefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) + for arg_v in list(cast(Sequence[Any], sa_column_args)): + if isinstance(arg_v, ForeignKey): + args.append(copy(arg_v)) + else: + args.append(arg_v) sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) if sa_column_kwargs is not Undefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) diff --git a/tests/test_field_sa_fk_args_kwargs.py b/tests/test_field_sa_fk_args_kwargs.py new file mode 100644 index 0000000000..3ac8aec92f --- /dev/null +++ b/tests/test_field_sa_fk_args_kwargs.py @@ -0,0 +1,51 @@ +from typing import Optional + +from sqlalchemy import ForeignKey, create_engine +from sqlmodel import Field, SQLModel + + +def test_base_model_fk(clear_sqlmodel, caplog) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),) + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + + assert ( + "FOREIGN KEY(owner_id) REFERENCES user (id) ON DELETE SET NULL" in caplog.text + ) + + +def test_base_model_fk_args(clear_sqlmodel, caplog) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, + foreign_key=ForeignKey("user.id", ondelete="SET NULL"), + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + + assert ( + "FOREIGN KEY(owner_id) REFERENCES user (id) ON DELETE SET NULL" in caplog.text + )