Skip to content

Commit 0c661f7

Browse files
committed
Fix self refernce case
1 parent f8e15b2 commit 0c661f7

File tree

2 files changed

+66
-15
lines changed

2 files changed

+66
-15
lines changed

fquery/sqlmodel.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ def sqlmodel(self) -> SQLModel:
9797
attrs = {name: getattr(self, name) for name in self.__sqlmodel__.__fields__}
9898
return self.__sqlmodel__(**attrs)
9999

100+
def check_self_reference(clsname: str, field):
101+
# Check if the field is a self-referential relationship
102+
if (
103+
field.type == ForwardRef(clsname)
104+
or field.type == Optional[ForwardRef(clsname)]
105+
):
106+
return True
107+
return False
108+
100109
def get_field_def(cls, field) -> Union[Field, Relationship]:
101110
sql_meta = field.metadata.get("SQL", {})
102111
has_foreign_key = bool(sql_meta.get("foreign_key", None))
@@ -129,7 +138,16 @@ def get_field_def(cls, field) -> Union[Field, Relationship]:
129138
back_populates = inflection.underscore(cls.__name__)
130139
if sql_meta.get("many_to_one", False):
131140
back_populates = inflection.pluralize(back_populates)
132-
return Relationship(back_populates=back_populates)
141+
142+
key_column = sql_meta.get("key_column", None)
143+
self_reference = check_self_reference(cls.__name__, field)
144+
sa_relationship_kwargs = (
145+
dict(remote_side=key_column) if key_column and self_reference else None
146+
)
147+
return Relationship(
148+
back_populates=back_populates,
149+
sa_relationship_kwargs=sa_relationship_kwargs,
150+
)
133151
if has_foreign_key:
134152
return Field(default=None, foreign_key=sql_meta["foreign_key"])
135153
raise "Unsupported case"
@@ -169,20 +187,18 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
169187
# TODO: log exception?
170188
pass
171189
inner = type_class.__args__[0]
172-
if isinstance(inner, ForwardRef):
173-
# can't patch right now. Try at a later time via back_populates
174-
return
175-
other_class = inner.__sqlmodel__
176-
old = other_class.__annotations__[back_populates]
177-
# Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]]
178-
# replace it with Mapped[List[sqlmodel_cls]]
179-
origin = get_origin(old)
180-
inner = get_args(old)
181-
if origin == Mapped and len(inner) and get_origin(inner[0]) is list:
182-
other_class.__annotations__[back_populates] = Mapped[
183-
List[sqlmodel_cls]
184-
]
185-
other_class.sqlmodel_rebuild()
190+
if not isinstance(inner, ForwardRef):
191+
other_class = inner.__sqlmodel__
192+
old = other_class.__annotations__[back_populates]
193+
# Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]]
194+
# replace it with Mapped[List[sqlmodel_cls]]
195+
origin = get_origin(old)
196+
inner = get_args(old)
197+
if origin == Mapped and len(inner) and get_origin(inner[0]) is list:
198+
other_class.__annotations__[back_populates] = Mapped[
199+
List[sqlmodel_cls]
200+
]
201+
other_class.sqlmodel_rebuild()
186202

187203
# Replace Optional['T'] with Optional[TSQLModel]
188204
old = field.type
@@ -192,6 +208,9 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
192208
if origin == Union and len(inner) and inner[0] == ForwardRef(cls.__name__):
193209
sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls]
194210
needs_rebuild = True
211+
if origin == list and len(inner) and inner[0] == ForwardRef(cls.__name__):
212+
sqlmodel_cls.__annotations__[field.name] = List[sqlmodel_cls]
213+
needs_rebuild = True
195214

196215
# Replace Optional[T] with Optional[TSQLModel] if T is a dataclass
197216
if origin == Union and len(inner) and is_dataclass(inner[0]):

tests/test_self_reference.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from dataclasses import field
2+
from typing import List, Optional
3+
4+
from sqlalchemy import create_engine
5+
from sqlalchemy.orm import sessionmaker
6+
from sqlmodel import SQLModel
7+
8+
from fquery.sqlmodel import SQL_PK, many_to_one, one_to_many, sqlmodel
9+
10+
11+
@sqlmodel
12+
class Topic:
13+
id: Optional[int] = field(default=None, **SQL_PK)
14+
name: str
15+
description: Optional[str] = None
16+
wikidata_id: Optional[str] = None
17+
probability: Optional[float] = None
18+
level: int
19+
combined_prob: Optional[float] = None
20+
parent: Optional["Topic"] = many_to_one(
21+
"TopicSQLModel.id", back_populates="children"
22+
)
23+
children: List["Topic"] = one_to_many(back_populates="parent")
24+
25+
26+
def test_self_reference():
27+
engine = create_engine("duckdb:///:memory:", echo=False)
28+
SQLModel.metadata.create_all(engine)
29+
Session = sessionmaker(bind=engine)
30+
31+
with Session() as session:
32+
session.query(Topic.__sqlmodel__).all()

0 commit comments

Comments
 (0)