Skip to content

Commit 37a9240

Browse files
committed
Alternative implementation: remove sa_column_fk_{args,kwargs}, allow the use to pass foreign_key=ForeignKey(...)
1 parent 0e03bac commit 37a9240

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

sqlmodel/main.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
121121
sa_column = kwargs.pop("sa_column", Undefined)
122122
sa_column_args = kwargs.pop("sa_column_args", Undefined)
123123
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
124-
sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined)
125-
sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined)
126124
if sa_column is not Undefined:
127125
if sa_column_args is not Undefined:
128126
raise RuntimeError(
@@ -165,6 +163,10 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
165163
if ondelete is not Undefined:
166164
if foreign_key is Undefined:
167165
raise RuntimeError("ondelete can only be used with foreign_key")
166+
if not isinstance(foreign_key, str):
167+
raise RuntimeError(
168+
"ondelete can only be used with foreign_key given as a string"
169+
)
168170
super().__init__(default=default, **kwargs)
169171
self.primary_key = primary_key
170172
self.nullable = nullable
@@ -176,8 +178,6 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
176178
self.sa_column = sa_column
177179
self.sa_column_args = sa_column_args
178180
self.sa_column_kwargs = sa_column_kwargs
179-
self.sa_foreign_key_args = sa_foreign_key_args
180-
self.sa_foreign_key_kwargs = sa_foreign_key_kwargs
181181

182182

183183
class RelationshipInfo(Representation):
@@ -252,8 +252,6 @@ def Field(
252252
sa_type: Union[Type[Any], UndefinedType] = Undefined,
253253
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
254254
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
255-
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
256-
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
257255
schema_extra: Optional[Dict[str, Any]] = None,
258256
) -> Any: ...
259257

@@ -390,8 +388,6 @@ def Field(
390388
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
391389
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
392390
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
393-
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
394-
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
395391
schema_extra: Optional[Dict[str, Any]] = None,
396392
) -> Any:
397393
current_schema_extra = schema_extra or {}
@@ -430,8 +426,6 @@ def Field(
430426
sa_column=sa_column,
431427
sa_column_args=sa_column_args,
432428
sa_column_kwargs=sa_column_kwargs,
433-
sa_foreign_key_args=sa_foreign_key_args,
434-
sa_foreign_key_kwargs=sa_foreign_key_kwargs,
435429
**current_schema_extra,
436430
)
437431
post_init_field_info(field_info)
@@ -740,22 +734,18 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
740734
if unique is Undefined:
741735
unique = False
742736
if foreign_key:
743-
if field_info.ondelete == "SET NULL" and not nullable:
744-
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
745-
assert isinstance(foreign_key, str)
746-
fk_args = []
747-
fk_kwargs = {}
748-
ondelete = getattr(field_info, "ondelete", Undefined)
749-
if ondelete is not Undefined:
737+
if isinstance(foreign_key, str):
738+
if field_info.ondelete == "SET NULL" and not nullable:
739+
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
740+
assert isinstance(foreign_key, str)
741+
ondelete = getattr(field_info, "ondelete", Undefined)
742+
if ondelete is Undefined:
743+
ondelete = None
750744
assert isinstance(ondelete, (str, type(None)))
751-
fk_kwargs["ondelete"] = ondelete
752-
sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined)
753-
if sa_foreign_key_args is not Undefined:
754-
fk_args.extend(cast(Sequence[Any], sa_foreign_key_args))
755-
sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined)
756-
if sa_foreign_key_kwargs is not Undefined:
757-
fk_kwargs.update(cast(Dict[Any, Any], sa_foreign_key_kwargs))
758-
args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs))
745+
args.append(ForeignKey(foreign_key, ondelete=ondelete))
746+
else:
747+
assert isinstance(foreign_key, ForeignKey)
748+
args.append(foreign_key.copy())
759749
kwargs = {
760750
"primary_key": primary_key,
761751
"nullable": nullable,

tests/test_field_sa_fk_args_kwargs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ class Asset(Base, table=True):
2222
id: Optional[int] = Field(default=None, primary_key=True)
2323

2424
# Fails in Pydantic v2, but not v1
25-
with pytest.raises(
26-
sqlalchemy.exc.InvalidRequestError
27-
) if IS_PYDANTIC_V2 else contextlib.nullcontext() as e:
25+
with (
26+
pytest.raises(sqlalchemy.exc.InvalidRequestError)
27+
if IS_PYDANTIC_V2
28+
else contextlib.nullcontext()
29+
) as e:
2830

2931
class Document(Base, table=True):
3032
id: Optional[int] = Field(default=None, primary_key=True)
@@ -52,8 +54,7 @@ class User(SQLModel, table=True):
5254
class Base(SQLModel):
5355
owner_id: Optional[int] = Field(
5456
default=None,
55-
foreign_key="user.id",
56-
sa_foreign_key_kwargs={"ondelete": "SET NULL"},
57+
foreign_key=ForeignKey("user.id", ondelete="SET NULL"),
5758
)
5859

5960
class Asset(Base, table=True):

0 commit comments

Comments
 (0)