Skip to content

Commit d6ed383

Browse files
committed
✨Add foreign_key_args and foreign_key_kwargs arguments to Field(...) to let the user define additional sqlalchemy.orm.ForeignKey attributes, such as ondelete and onupdate, for foreign keys defined in a base model.
1 parent c75743d commit d6ed383

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

sqlmodel/main.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
106106
sa_column = kwargs.pop("sa_column", Undefined)
107107
sa_column_args = kwargs.pop("sa_column_args", Undefined)
108108
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
109+
sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined)
110+
sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined)
109111
if sa_column is not Undefined:
110112
if sa_column_args is not Undefined:
111113
raise RuntimeError(
@@ -153,6 +155,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
153155
self.sa_column = sa_column
154156
self.sa_column_args = sa_column_args
155157
self.sa_column_kwargs = sa_column_kwargs
158+
self.sa_foreign_key_args = sa_foreign_key_args
159+
self.sa_foreign_key_kwargs = sa_foreign_key_kwargs
156160

157161

158162
class RelationshipInfo(Representation):
@@ -222,6 +226,8 @@ def Field(
222226
sa_type: Union[Type[Any], UndefinedType] = Undefined,
223227
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
224228
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
229+
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
230+
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
225231
schema_extra: Optional[Dict[str, Any]] = None,
226232
) -> Any:
227233
...
@@ -303,6 +309,8 @@ def Field(
303309
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
304310
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
305311
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
312+
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
313+
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
306314
schema_extra: Optional[Dict[str, Any]] = None,
307315
) -> Any:
308316
current_schema_extra = schema_extra or {}
@@ -340,6 +348,8 @@ def Field(
340348
sa_column=sa_column,
341349
sa_column_args=sa_column_args,
342350
sa_column_kwargs=sa_column_kwargs,
351+
sa_foreign_key_args=sa_foreign_key_args,
352+
sa_foreign_key_kwargs=sa_foreign_key_kwargs,
343353
**current_schema_extra,
344354
)
345355
post_init_field_info(field_info)
@@ -638,7 +648,19 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
638648
unique = False
639649
if foreign_key:
640650
assert isinstance(foreign_key, str)
641-
args.append(ForeignKey(foreign_key))
651+
sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined)
652+
fk_args = (
653+
[]
654+
if sa_foreign_key_args is Undefined
655+
else list(cast(Sequence[Any], sa_foreign_key_args))
656+
)
657+
sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined)
658+
fk_kwargs = (
659+
{}
660+
if sa_foreign_key_kwargs is Undefined
661+
else cast(Dict[Any, Any], sa_foreign_key_kwargs)
662+
)
663+
args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs))
642664
kwargs = {
643665
"primary_key": primary_key,
644666
"nullable": nullable,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import contextlib
2+
import re
3+
from typing import Optional
4+
5+
import pytest
6+
import sqlalchemy.exc
7+
from sqlalchemy import ForeignKey, create_engine
8+
from sqlmodel import Field, SQLModel
9+
from sqlmodel._compat import IS_PYDANTIC_V2
10+
11+
12+
def test_base_model_fk(clear_sqlmodel, caplog) -> None:
13+
class User(SQLModel, table=True):
14+
id: Optional[int] = Field(default=None, primary_key=True)
15+
16+
class Base(SQLModel):
17+
owner_id: Optional[int] = Field(
18+
default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),)
19+
)
20+
21+
class Asset(Base, table=True):
22+
id: Optional[int] = Field(default=None, primary_key=True)
23+
24+
# Fails in Pydantic v2, but not v1
25+
with pytest.raises(
26+
sqlalchemy.exc.InvalidRequestError
27+
) if IS_PYDANTIC_V2 else contextlib.nullcontext() as e:
28+
29+
class Document(Base, table=True):
30+
id: Optional[int] = Field(default=None, primary_key=True)
31+
32+
if e:
33+
assert "This ForeignKey already has a parent" in str(e.errisinstance)
34+
35+
engine = create_engine("sqlite://", echo=True)
36+
SQLModel.metadata.create_all(engine)
37+
38+
fk_log = [
39+
message
40+
for message in caplog.messages
41+
if re.search(
42+
r"FOREIGN KEY\s*\(owner_id\)\s*REFERENCES\s*user\s*\(id\)", message
43+
)
44+
][0]
45+
assert "ON DELETE SET NULL" in fk_log
46+
47+
48+
def test_base_model_fk_args(clear_sqlmodel, caplog) -> None:
49+
class User(SQLModel, table=True):
50+
id: Optional[int] = Field(default=None, primary_key=True)
51+
52+
class Base(SQLModel):
53+
owner_id: Optional[int] = Field(
54+
default=None,
55+
foreign_key="user.id",
56+
sa_foreign_key_kwargs={"ondelete": "SET NULL"},
57+
)
58+
59+
class Asset(Base, table=True):
60+
id: Optional[int] = Field(default=None, primary_key=True)
61+
62+
class Document(Base, table=True):
63+
id: Optional[int] = Field(default=None, primary_key=True)
64+
65+
engine = create_engine("sqlite://", echo=True)
66+
SQLModel.metadata.create_all(engine)
67+
68+
fk_log = [
69+
message
70+
for message in caplog.messages
71+
if re.search(
72+
r"FOREIGN KEY\s*\(owner_id\)\s*REFERENCES\s*user\s*\(id\)", message
73+
)
74+
][0]
75+
assert "ON DELETE SET NULL" in fk_log

0 commit comments

Comments
 (0)