Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def get_relationship_to(
# If a list, then also get the real field
elif origin is list:
use_annotation = get_args(annotation)[0]
# If a dict, then use the value type
elif origin is dict:
use_annotation = get_args(annotation)[1]

return get_relationship_to(
name=name, rel_info=rel_info, annotation=use_annotation
Expand Down
48 changes: 48 additions & 0 deletions tests/test_attribute_keyed_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from enum import Enum
from typing import Dict, Optional

from sqlalchemy.orm.collections import attribute_keyed_dict
from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine


def test_attribute_keyed_dict_works(clear_sqlmodel):
class Color(str, Enum):
Orange = "Orange"
Blue = "Blue"

class Child(SQLModel, table=True):
__tablename__ = "children"
__table_args__ = (
Index("ix_children_parent_id_color", "parent_id", "color", unique=True),
)

id: Optional[int] = Field(primary_key=True, default=None)
parent_id: int = Field(foreign_key="parents.id")
color: Color
value: int

class Parent(SQLModel, table=True):
__tablename__ = "parents"

id: Optional[int] = Field(primary_key=True, default=None)
children_by_color: Dict[Color, Child] = Relationship(
sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")}
)

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
parent = Parent()
session.add(parent)
session.commit()
session.refresh(parent)
session.add(Child(parent_id=parent.id, color=Color.Orange, value=1))
session.add(Child(parent_id=parent.id, color=Color.Blue, value=2))
session.commit()
session.refresh(parent)
assert parent.children_by_color[Color.Orange].parent_id == parent.id
assert parent.children_by_color[Color.Orange].color == Color.Orange
assert parent.children_by_color[Color.Orange].value == 1
assert parent.children_by_color[Color.Blue].parent_id == parent.id
assert parent.children_by_color[Color.Blue].color == Color.Blue
assert parent.children_by_color[Color.Blue].value == 2