Skip to content

Commit 717594e

Browse files
authored
✨ Do not allow invalid combinations of field parameters for columns and relationships, sa_column excludes sa_column_args, primary_key, nullable, etc. (#681)
* ♻️ Make sa_column exclusive, do not allow incompatible arguments, sa_column_args, primary_key, etc * ✅ Add tests for new errors when incorrectly using sa_column * ✅ Add tests for sa_column_args and sa_column_kwargs * ♻️ Do not allow sa_relationship with sa_relationship_args or sa_relationship_kwargs * ✅ Add tests for relationship errors * ✅ Fix test for sa_column_args
1 parent e4e1385 commit 717594e

File tree

4 files changed

+332
-10
lines changed

4 files changed

+332
-10
lines changed

sqlmodel/main.py

Lines changed: 141 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TypeVar,
2323
Union,
2424
cast,
25+
overload,
2526
)
2627

2728
from pydantic import BaseConfig, BaseModel
@@ -87,6 +88,28 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
8788
"Passing sa_column_kwargs is not supported when "
8889
"also passing a sa_column"
8990
)
91+
if primary_key is not Undefined:
92+
raise RuntimeError(
93+
"Passing primary_key is not supported when "
94+
"also passing a sa_column"
95+
)
96+
if nullable is not Undefined:
97+
raise RuntimeError(
98+
"Passing nullable is not supported when " "also passing a sa_column"
99+
)
100+
if foreign_key is not Undefined:
101+
raise RuntimeError(
102+
"Passing foreign_key is not supported when "
103+
"also passing a sa_column"
104+
)
105+
if unique is not Undefined:
106+
raise RuntimeError(
107+
"Passing unique is not supported when " "also passing a sa_column"
108+
)
109+
if index is not Undefined:
110+
raise RuntimeError(
111+
"Passing index is not supported when " "also passing a sa_column"
112+
)
90113
super().__init__(default=default, **kwargs)
91114
self.primary_key = primary_key
92115
self.nullable = nullable
@@ -126,6 +149,86 @@ def __init__(
126149
self.sa_relationship_kwargs = sa_relationship_kwargs
127150

128151

152+
@overload
153+
def Field(
154+
default: Any = Undefined,
155+
*,
156+
default_factory: Optional[NoArgAnyCallable] = None,
157+
alias: Optional[str] = None,
158+
title: Optional[str] = None,
159+
description: Optional[str] = None,
160+
exclude: Union[
161+
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
162+
] = None,
163+
include: Union[
164+
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
165+
] = None,
166+
const: Optional[bool] = None,
167+
gt: Optional[float] = None,
168+
ge: Optional[float] = None,
169+
lt: Optional[float] = None,
170+
le: Optional[float] = None,
171+
multiple_of: Optional[float] = None,
172+
max_digits: Optional[int] = None,
173+
decimal_places: Optional[int] = None,
174+
min_items: Optional[int] = None,
175+
max_items: Optional[int] = None,
176+
unique_items: Optional[bool] = None,
177+
min_length: Optional[int] = None,
178+
max_length: Optional[int] = None,
179+
allow_mutation: bool = True,
180+
regex: Optional[str] = None,
181+
discriminator: Optional[str] = None,
182+
repr: bool = True,
183+
primary_key: Union[bool, UndefinedType] = Undefined,
184+
foreign_key: Any = Undefined,
185+
unique: Union[bool, UndefinedType] = Undefined,
186+
nullable: Union[bool, UndefinedType] = Undefined,
187+
index: Union[bool, UndefinedType] = Undefined,
188+
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
189+
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
190+
schema_extra: Optional[Dict[str, Any]] = None,
191+
) -> Any:
192+
...
193+
194+
195+
@overload
196+
def Field(
197+
default: Any = Undefined,
198+
*,
199+
default_factory: Optional[NoArgAnyCallable] = None,
200+
alias: Optional[str] = None,
201+
title: Optional[str] = None,
202+
description: Optional[str] = None,
203+
exclude: Union[
204+
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
205+
] = None,
206+
include: Union[
207+
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
208+
] = None,
209+
const: Optional[bool] = None,
210+
gt: Optional[float] = None,
211+
ge: Optional[float] = None,
212+
lt: Optional[float] = None,
213+
le: Optional[float] = None,
214+
multiple_of: Optional[float] = None,
215+
max_digits: Optional[int] = None,
216+
decimal_places: Optional[int] = None,
217+
min_items: Optional[int] = None,
218+
max_items: Optional[int] = None,
219+
unique_items: Optional[bool] = None,
220+
min_length: Optional[int] = None,
221+
max_length: Optional[int] = None,
222+
allow_mutation: bool = True,
223+
regex: Optional[str] = None,
224+
discriminator: Optional[str] = None,
225+
repr: bool = True,
226+
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
227+
schema_extra: Optional[Dict[str, Any]] = None,
228+
) -> Any:
229+
...
230+
231+
129232
def Field(
130233
default: Any = Undefined,
131234
*,
@@ -156,9 +259,9 @@ def Field(
156259
regex: Optional[str] = None,
157260
discriminator: Optional[str] = None,
158261
repr: bool = True,
159-
primary_key: bool = False,
160-
foreign_key: Optional[Any] = None,
161-
unique: bool = False,
262+
primary_key: Union[bool, UndefinedType] = Undefined,
263+
foreign_key: Any = Undefined,
264+
unique: Union[bool, UndefinedType] = Undefined,
162265
nullable: Union[bool, UndefinedType] = Undefined,
163266
index: Union[bool, UndefinedType] = Undefined,
164267
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
@@ -206,6 +309,27 @@ def Field(
206309
return field_info
207310

208311

312+
@overload
313+
def Relationship(
314+
*,
315+
back_populates: Optional[str] = None,
316+
link_model: Optional[Any] = None,
317+
sa_relationship_args: Optional[Sequence[Any]] = None,
318+
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
319+
) -> Any:
320+
...
321+
322+
323+
@overload
324+
def Relationship(
325+
*,
326+
back_populates: Optional[str] = None,
327+
link_model: Optional[Any] = None,
328+
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
329+
) -> Any:
330+
...
331+
332+
209333
def Relationship(
210334
*,
211335
back_populates: Optional[str] = None,
@@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
440564
if isinstance(sa_column, Column):
441565
return sa_column
442566
sa_type = get_sqlalchemy_type(field)
443-
primary_key = getattr(field.field_info, "primary_key", False)
567+
primary_key = getattr(field.field_info, "primary_key", Undefined)
568+
if primary_key is Undefined:
569+
primary_key = False
444570
index = getattr(field.field_info, "index", Undefined)
445571
if index is Undefined:
446572
index = False
447573
nullable = not primary_key and _is_field_noneable(field)
448574
# Override derived nullability if the nullable property is set explicitly
449575
# on the field
450-
if hasattr(field.field_info, "nullable"):
451-
field_nullable = getattr(field.field_info, "nullable") # noqa: B009
452-
if field_nullable != Undefined:
453-
nullable = field_nullable
576+
field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
577+
if field_nullable != Undefined:
578+
assert not isinstance(field_nullable, UndefinedType)
579+
nullable = field_nullable
454580
args = []
455-
foreign_key = getattr(field.field_info, "foreign_key", None)
456-
unique = getattr(field.field_info, "unique", False)
581+
foreign_key = getattr(field.field_info, "foreign_key", Undefined)
582+
if foreign_key is Undefined:
583+
foreign_key = None
584+
unique = getattr(field.field_info, "unique", Undefined)
585+
if unique is Undefined:
586+
unique = False
457587
if foreign_key:
588+
assert isinstance(foreign_key, str)
458589
args.append(ForeignKey(foreign_key))
459590
kwargs = {
460591
"primary_key": primary_key,

tests/test_field_sa_args_kwargs.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Optional
2+
3+
from sqlalchemy import ForeignKey
4+
from sqlmodel import Field, SQLModel, create_engine
5+
6+
7+
def test_sa_column_args(clear_sqlmodel, caplog) -> None:
8+
class Team(SQLModel, table=True):
9+
id: Optional[int] = Field(default=None, primary_key=True)
10+
name: str
11+
12+
class Hero(SQLModel, table=True):
13+
id: Optional[int] = Field(default=None, primary_key=True)
14+
team_id: Optional[int] = Field(
15+
default=None,
16+
sa_column_args=[ForeignKey("team.id")],
17+
)
18+
19+
engine = create_engine("sqlite://", echo=True)
20+
SQLModel.metadata.create_all(engine)
21+
create_table_log = [
22+
message for message in caplog.messages if "CREATE TABLE hero" in message
23+
][0]
24+
assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log
25+
26+
27+
def test_sa_column_kargs(clear_sqlmodel, caplog) -> None:
28+
class Item(SQLModel, table=True):
29+
id: Optional[int] = Field(
30+
default=None,
31+
sa_column_kwargs={"primary_key": True},
32+
)
33+
34+
engine = create_engine("sqlite://", echo=True)
35+
SQLModel.metadata.create_all(engine)
36+
create_table_log = [
37+
message for message in caplog.messages if "CREATE TABLE item" in message
38+
][0]
39+
assert "PRIMARY KEY (id)" in create_table_log

tests/test_field_sa_column.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import Optional
2+
3+
import pytest
4+
from sqlalchemy import Column, Integer, String
5+
from sqlmodel import Field, SQLModel
6+
7+
8+
def test_sa_column_takes_precedence() -> None:
9+
class Item(SQLModel, table=True):
10+
id: Optional[int] = Field(
11+
default=None,
12+
sa_column=Column(String, primary_key=True, nullable=False),
13+
)
14+
15+
# It would have been nullable with no sa_column
16+
assert Item.id.nullable is False # type: ignore
17+
assert isinstance(Item.id.type, String) # type: ignore
18+
19+
20+
def test_sa_column_no_sa_args() -> None:
21+
with pytest.raises(RuntimeError):
22+
23+
class Item(SQLModel, table=True):
24+
id: Optional[int] = Field(
25+
default=None,
26+
sa_column_args=[Integer],
27+
sa_column=Column(Integer, primary_key=True),
28+
)
29+
30+
31+
def test_sa_column_no_sa_kargs() -> None:
32+
with pytest.raises(RuntimeError):
33+
34+
class Item(SQLModel, table=True):
35+
id: Optional[int] = Field(
36+
default=None,
37+
sa_column_kwargs={"primary_key": True},
38+
sa_column=Column(Integer, primary_key=True),
39+
)
40+
41+
42+
def test_sa_column_no_primary_key() -> None:
43+
with pytest.raises(RuntimeError):
44+
45+
class Item(SQLModel, table=True):
46+
id: Optional[int] = Field(
47+
default=None,
48+
primary_key=True,
49+
sa_column=Column(Integer, primary_key=True),
50+
)
51+
52+
53+
def test_sa_column_no_nullable() -> None:
54+
with pytest.raises(RuntimeError):
55+
56+
class Item(SQLModel, table=True):
57+
id: Optional[int] = Field(
58+
default=None,
59+
nullable=True,
60+
sa_column=Column(Integer, primary_key=True),
61+
)
62+
63+
64+
def test_sa_column_no_foreign_key() -> None:
65+
with pytest.raises(RuntimeError):
66+
67+
class Team(SQLModel, table=True):
68+
id: Optional[int] = Field(default=None, primary_key=True)
69+
name: str
70+
71+
class Hero(SQLModel, table=True):
72+
id: Optional[int] = Field(default=None, primary_key=True)
73+
team_id: Optional[int] = Field(
74+
default=None,
75+
foreign_key="team.id",
76+
sa_column=Column(Integer, primary_key=True),
77+
)
78+
79+
80+
def test_sa_column_no_unique() -> None:
81+
with pytest.raises(RuntimeError):
82+
83+
class Item(SQLModel, table=True):
84+
id: Optional[int] = Field(
85+
default=None,
86+
unique=True,
87+
sa_column=Column(Integer, primary_key=True),
88+
)
89+
90+
91+
def test_sa_column_no_index() -> None:
92+
with pytest.raises(RuntimeError):
93+
94+
class Item(SQLModel, table=True):
95+
id: Optional[int] = Field(
96+
default=None,
97+
index=True,
98+
sa_column=Column(Integer, primary_key=True),
99+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import List, Optional
2+
3+
import pytest
4+
from sqlalchemy.orm import relationship
5+
from sqlmodel import Field, Relationship, SQLModel
6+
7+
8+
def test_sa_relationship_no_args() -> None:
9+
with pytest.raises(RuntimeError):
10+
11+
class Team(SQLModel, table=True):
12+
id: Optional[int] = Field(default=None, primary_key=True)
13+
name: str = Field(index=True)
14+
headquarters: str
15+
16+
heroes: List["Hero"] = Relationship(
17+
back_populates="team",
18+
sa_relationship_args=["Hero"],
19+
sa_relationship=relationship("Hero", back_populates="team"),
20+
)
21+
22+
class Hero(SQLModel, table=True):
23+
id: Optional[int] = Field(default=None, primary_key=True)
24+
name: str = Field(index=True)
25+
secret_name: str
26+
age: Optional[int] = Field(default=None, index=True)
27+
28+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
29+
team: Optional[Team] = Relationship(back_populates="heroes")
30+
31+
32+
def test_sa_relationship_no_kwargs() -> None:
33+
with pytest.raises(RuntimeError):
34+
35+
class Team(SQLModel, table=True):
36+
id: Optional[int] = Field(default=None, primary_key=True)
37+
name: str = Field(index=True)
38+
headquarters: str
39+
40+
heroes: List["Hero"] = Relationship(
41+
back_populates="team",
42+
sa_relationship_kwargs={"lazy": "selectin"},
43+
sa_relationship=relationship("Hero", back_populates="team"),
44+
)
45+
46+
class Hero(SQLModel, table=True):
47+
id: Optional[int] = Field(default=None, primary_key=True)
48+
name: str = Field(index=True)
49+
secret_name: str
50+
age: Optional[int] = Field(default=None, index=True)
51+
52+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
53+
team: Optional[Team] = Relationship(back_populates="heroes")

0 commit comments

Comments
 (0)