Skip to content

Commit 91da939

Browse files
committed
Reduced code duplication in tests using parametrization
1 parent 1f5d33c commit 91da939

File tree

2 files changed

+32
-143
lines changed

2 files changed

+32
-143
lines changed

tests/test_field_sa_column.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,86 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22

33
import pytest
44
from sqlalchemy import Column, Integer, String
5+
from sqlalchemy.orm import mapped_column
56
from sqlmodel import Field, SQLModel
67

78

8-
def test_sa_column_takes_precedence(clear_sqlmodel) -> None:
9+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
10+
def test_sa_column_takes_precedence(clear_sqlmodel, column_class) -> None:
911
class Item(SQLModel, table=True):
1012
id: Optional[int] = Field(
1113
default=None,
12-
sa_column=Column(String, primary_key=True, nullable=False),
14+
sa_column=column_class(String, primary_key=True, nullable=False),
1315
)
1416

1517
# It would have been nullable with no sa_column
1618
assert Item.id.nullable is False # type: ignore
1719
assert isinstance(Item.id.type, String) # type: ignore
1820

1921

20-
def test_sa_column_no_sa_args() -> None:
22+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
23+
def test_sa_column_no_sa_args(column_class) -> None:
2124
with pytest.raises(RuntimeError):
2225

2326
class Item(SQLModel, table=True):
2427
id: Optional[int] = Field(
2528
default=None,
2629
sa_column_args=[Integer],
27-
sa_column=Column(Integer, primary_key=True),
30+
sa_column=column_class(Integer, primary_key=True),
2831
)
2932

3033

31-
def test_sa_column_no_sa_kargs() -> None:
34+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
35+
def test_sa_column_no_sa_kargs(column_class) -> None:
3236
with pytest.raises(RuntimeError):
3337

3438
class Item(SQLModel, table=True):
3539
id: Optional[int] = Field(
3640
default=None,
3741
sa_column_kwargs={"primary_key": True},
38-
sa_column=Column(Integer, primary_key=True),
42+
sa_column=column_class(Integer, primary_key=True),
3943
)
4044

4145

42-
def test_sa_column_no_type() -> None:
46+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
47+
def test_sa_column_no_type(column_class) -> None:
4348
with pytest.raises(RuntimeError):
4449

4550
class Item(SQLModel, table=True):
4651
id: Optional[int] = Field(
4752
default=None,
4853
sa_type=Integer,
49-
sa_column=Column(Integer, primary_key=True),
54+
sa_column=column_class(Integer, primary_key=True),
5055
)
5156

5257

53-
def test_sa_column_no_primary_key() -> None:
58+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
59+
def test_sa_column_no_primary_key(column_class) -> None:
5460
with pytest.raises(RuntimeError):
5561

5662
class Item(SQLModel, table=True):
5763
id: Optional[int] = Field(
5864
default=None,
5965
primary_key=True,
60-
sa_column=Column(Integer, primary_key=True),
66+
sa_column=column_class(Integer, primary_key=True),
6167
)
6268

6369

64-
def test_sa_column_no_nullable() -> None:
70+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
71+
def test_sa_column_no_nullable(column_class) -> None:
6572
with pytest.raises(RuntimeError):
6673

6774
class Item(SQLModel, table=True):
6875
id: Optional[int] = Field(
6976
default=None,
7077
nullable=True,
71-
sa_column=Column(Integer, primary_key=True),
78+
sa_column=column_class(Integer, primary_key=True),
7279
)
7380

7481

75-
def test_sa_column_no_foreign_key() -> None:
82+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
83+
def test_sa_column_no_foreign_key(clear_sqlmodel, column_class) -> None:
7684
with pytest.raises(RuntimeError):
7785

7886
class Team(SQLModel, table=True):
@@ -84,38 +92,41 @@ class Hero(SQLModel, table=True):
8492
team_id: Optional[int] = Field(
8593
default=None,
8694
foreign_key="team.id",
87-
sa_column=Column(Integer, primary_key=True),
95+
sa_column=column_class(Integer, primary_key=True),
8896
)
8997

9098

91-
def test_sa_column_no_unique() -> None:
99+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
100+
def test_sa_column_no_unique(column_class) -> None:
92101
with pytest.raises(RuntimeError):
93102

94103
class Item(SQLModel, table=True):
95104
id: Optional[int] = Field(
96105
default=None,
97106
unique=True,
98-
sa_column=Column(Integer, primary_key=True),
107+
sa_column=column_class(Integer, primary_key=True),
99108
)
100109

101110

102-
def test_sa_column_no_index() -> None:
111+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
112+
def test_sa_column_no_index(column_class) -> None:
103113
with pytest.raises(RuntimeError):
104114

105115
class Item(SQLModel, table=True):
106116
id: Optional[int] = Field(
107117
default=None,
108118
index=True,
109-
sa_column=Column(Integer, primary_key=True),
119+
sa_column=column_class(Integer, primary_key=True),
110120
)
111121

112122

113-
def test_sa_column_no_ondelete() -> None:
123+
@pytest.mark.parametrize("column_class", [Column, mapped_column])
124+
def test_sa_column_no_ondelete(column_class) -> None:
114125
with pytest.raises(RuntimeError):
115126

116127
class Item(SQLModel, table=True):
117128
id: Optional[int] = Field(
118129
default=None,
119-
sa_column=Column(Integer, primary_key=True),
130+
sa_column=column_class(Integer, primary_key=True),
120131
ondelete="CASCADE",
121132
)

tests/test_field_sa_column_mapped_column.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)