Skip to content

Commit 16ee026

Browse files
author
Natarajan Krishnaswami
committed
Incorporate review feedback
1 parent a801312 commit 16ee026

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

sqlmodel/_compat.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ def get_relationship_to(
178178
# If a list, then also get the real field
179179
elif origin is list:
180180
use_annotation = get_args(annotation)[0]
181-
# If a dict or Mapping, then use the value (second) type argument
182-
elif origin is dict or origin is Mapping:
181+
# If a dict, then use the value (second) type argument
182+
elif origin is dict:
183183
args = get_args(annotation)
184-
if len(args) >= 2:
185-
use_annotation = args[1]
186-
else:
184+
if len(args) != 2:
187185
raise ValueError(
188-
f"Dict/Mapping relationship field '{name}' must have both "
189-
"key and value type arguments (e.g., dict[str, Model])"
186+
f"Dict/Mapping relationship field '{name}' has {len(args)} "
187+
"type arguments. Exactly two required (e.g., dict[str, "
188+
"Model])"
190189
)
190+
use_annotation = args[1]
191191

192192
return get_relationship_to(
193193
name=name, rel_info=rel_info, annotation=use_annotation

tests/test_attribute_keyed_dict.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlalchemy.orm.collections import attribute_keyed_dict
88
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine
99

10-
from tests.conftest import needs_pydanticv2
10+
from tests.conftest import needs_py310, needs_pydanticv2
1111

1212

1313
def test_attribute_keyed_dict_works(clear_sqlmodel):
@@ -53,8 +53,8 @@ class Parent(SQLModel, table=True):
5353
# typing.Dict throws if it receives the wrong number of type arguments, but dict
5454
# (3.10+) does not; and Pydantic v1 fails to process models with dicts with no
5555
# type arguments.
56-
@pytest.mark.skipif(sys.version_info < (3, 10), reason="dict is not subscriptable")
5756
@needs_pydanticv2
57+
@needs_py310
5858
def test_dict_relationship_throws_on_missing_annotation_arg(clear_sqlmodel):
5959
class Color(str, Enum):
6060
Orange = "Orange"
@@ -68,11 +68,10 @@ class Child(SQLModel, table=True):
6868
color: Color
6969
value: int
7070

71-
error_msg_re = re.escape(
72-
"Dict/Mapping relationship field 'children_by_color' must have both key and value type arguments (e.g., dict[str, Model])"
73-
)
71+
error_msg_fmt = "Dict/Mapping relationship field 'children_by_color' has {count} type arguments. Exactly two required (e.g., dict[str, Model])"
72+
7473
# No type args
75-
with pytest.raises(ValueError, match=error_msg_re):
74+
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=0))):
7675

7776
class Parent(SQLModel, table=True):
7877
__tablename__ = "parents"
@@ -85,7 +84,7 @@ class Parent(SQLModel, table=True):
8584
)
8685

8786
# One type arg
88-
with pytest.raises(ValueError, match=error_msg_re):
87+
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=1))):
8988

9089
class Parent(SQLModel, table=True):
9190
__tablename__ = "parents"
@@ -96,3 +95,16 @@ class Parent(SQLModel, table=True):
9695
"collection_class": attribute_keyed_dict("color")
9796
}
9897
)
98+
99+
# Three type args
100+
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=3))):
101+
102+
class Parent(SQLModel, table=True):
103+
__tablename__ = "parents"
104+
105+
id: Optional[int] = Field(primary_key=True, default=None)
106+
children_by_color: dict[Color, Child, str] = Relationship(
107+
sa_relationship_kwargs={
108+
"collection_class": attribute_keyed_dict("color")
109+
}
110+
)

0 commit comments

Comments
 (0)