Skip to content

Commit 42b0e6e

Browse files
raphaelgibsonRaphael Gibsontiangolo
authored
✨ Allow setting unique in Field() for a column (#83)
Co-authored-by: Raphael Gibson <[email protected]> Co-authored-by: Sebastián Ramírez <[email protected]>
1 parent 1ca2880 commit 42b0e6e

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

sqlmodel/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
6161
primary_key = kwargs.pop("primary_key", False)
6262
nullable = kwargs.pop("nullable", Undefined)
6363
foreign_key = kwargs.pop("foreign_key", Undefined)
64+
unique = kwargs.pop("unique", False)
6465
index = kwargs.pop("index", Undefined)
6566
sa_column = kwargs.pop("sa_column", Undefined)
6667
sa_column_args = kwargs.pop("sa_column_args", Undefined)
@@ -80,6 +81,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
8081
self.primary_key = primary_key
8182
self.nullable = nullable
8283
self.foreign_key = foreign_key
84+
self.unique = unique
8385
self.index = index
8486
self.sa_column = sa_column
8587
self.sa_column_args = sa_column_args
@@ -141,6 +143,7 @@ def Field(
141143
regex: Optional[str] = None,
142144
primary_key: bool = False,
143145
foreign_key: Optional[Any] = None,
146+
unique: bool = False,
144147
nullable: Union[bool, UndefinedType] = Undefined,
145148
index: Union[bool, UndefinedType] = Undefined,
146149
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
@@ -171,6 +174,7 @@ def Field(
171174
regex=regex,
172175
primary_key=primary_key,
173176
foreign_key=foreign_key,
177+
unique=unique,
174178
nullable=nullable,
175179
index=index,
176180
sa_column=sa_column,
@@ -426,12 +430,14 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
426430
nullable = not primary_key and _is_field_nullable(field)
427431
args = []
428432
foreign_key = getattr(field.field_info, "foreign_key", None)
433+
unique = getattr(field.field_info, "unique", False)
429434
if foreign_key:
430435
args.append(ForeignKey(foreign_key))
431436
kwargs = {
432437
"primary_key": primary_key,
433438
"nullable": nullable,
434439
"index": index,
440+
"unique": unique,
435441
}
436442
sa_default = Undefined
437443
if field.field_info.default_factory:

tests/test_main.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Optional
2+
3+
import pytest
4+
from sqlalchemy.exc import IntegrityError
5+
from sqlmodel import Field, Session, SQLModel, create_engine
6+
7+
8+
def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
9+
class Hero(SQLModel, table=True):
10+
id: Optional[int] = Field(default=None, primary_key=True)
11+
name: str
12+
secret_name: str
13+
age: Optional[int] = None
14+
15+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
16+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
17+
18+
engine = create_engine("sqlite://")
19+
20+
SQLModel.metadata.create_all(engine)
21+
22+
with Session(engine) as session:
23+
session.add(hero_1)
24+
session.commit()
25+
session.refresh(hero_1)
26+
27+
with Session(engine) as session:
28+
session.add(hero_2)
29+
session.commit()
30+
session.refresh(hero_2)
31+
32+
with Session(engine) as session:
33+
heroes = session.query(Hero).all()
34+
assert len(heroes) == 2
35+
assert heroes[0].name == heroes[1].name
36+
37+
38+
def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel):
39+
class Hero(SQLModel, table=True):
40+
id: Optional[int] = Field(default=None, primary_key=True)
41+
name: str
42+
secret_name: str = Field(unique=False)
43+
age: Optional[int] = None
44+
45+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
46+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
47+
48+
engine = create_engine("sqlite://")
49+
50+
SQLModel.metadata.create_all(engine)
51+
52+
with Session(engine) as session:
53+
session.add(hero_1)
54+
session.commit()
55+
session.refresh(hero_1)
56+
57+
with Session(engine) as session:
58+
session.add(hero_2)
59+
session.commit()
60+
session.refresh(hero_2)
61+
62+
with Session(engine) as session:
63+
heroes = session.query(Hero).all()
64+
assert len(heroes) == 2
65+
assert heroes[0].name == heroes[1].name
66+
67+
68+
def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(
69+
clear_sqlmodel,
70+
):
71+
class Hero(SQLModel, table=True):
72+
id: Optional[int] = Field(default=None, primary_key=True)
73+
name: str
74+
secret_name: str = Field(unique=True)
75+
age: Optional[int] = None
76+
77+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
78+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
79+
80+
engine = create_engine("sqlite://")
81+
82+
SQLModel.metadata.create_all(engine)
83+
84+
with Session(engine) as session:
85+
session.add(hero_1)
86+
session.commit()
87+
session.refresh(hero_1)
88+
89+
with pytest.raises(IntegrityError):
90+
with Session(engine) as session:
91+
session.add(hero_2)
92+
session.commit()
93+
session.refresh(hero_2)

0 commit comments

Comments
 (0)