@@ -97,6 +97,15 @@ def sqlmodel(self) -> SQLModel:
9797 attrs = {name : getattr (self , name ) for name in self .__sqlmodel__ .__fields__ }
9898 return self .__sqlmodel__ (** attrs )
9999
100+ def check_self_reference (clsname : str , field ):
101+ # Check if the field is a self-referential relationship
102+ if (
103+ field .type == ForwardRef (clsname )
104+ or field .type == Optional [ForwardRef (clsname )]
105+ ):
106+ return True
107+ return False
108+
100109 def get_field_def (cls , field ) -> Union [Field , Relationship ]:
101110 sql_meta = field .metadata .get ("SQL" , {})
102111 has_foreign_key = bool (sql_meta .get ("foreign_key" , None ))
@@ -129,7 +138,16 @@ def get_field_def(cls, field) -> Union[Field, Relationship]:
129138 back_populates = inflection .underscore (cls .__name__ )
130139 if sql_meta .get ("many_to_one" , False ):
131140 back_populates = inflection .pluralize (back_populates )
132- return Relationship (back_populates = back_populates )
141+
142+ key_column = sql_meta .get ("key_column" , None )
143+ self_reference = check_self_reference (cls .__name__ , field )
144+ sa_relationship_kwargs = (
145+ dict (remote_side = key_column ) if key_column and self_reference else None
146+ )
147+ return Relationship (
148+ back_populates = back_populates ,
149+ sa_relationship_kwargs = sa_relationship_kwargs ,
150+ )
133151 if has_foreign_key :
134152 return Field (default = None , foreign_key = sql_meta ["foreign_key" ])
135153 raise "Unsupported case"
@@ -169,20 +187,18 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
169187 # TODO: log exception?
170188 pass
171189 inner = type_class .__args__ [0 ]
172- if isinstance (inner , ForwardRef ):
173- # can't patch right now. Try at a later time via back_populates
174- return
175- other_class = inner .__sqlmodel__
176- old = other_class .__annotations__ [back_populates ]
177- # Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]]
178- # replace it with Mapped[List[sqlmodel_cls]]
179- origin = get_origin (old )
180- inner = get_args (old )
181- if origin == Mapped and len (inner ) and get_origin (inner [0 ]) is list :
182- other_class .__annotations__ [back_populates ] = Mapped [
183- List [sqlmodel_cls ]
184- ]
185- other_class .sqlmodel_rebuild ()
190+ if not isinstance (inner , ForwardRef ):
191+ other_class = inner .__sqlmodel__
192+ old = other_class .__annotations__ [back_populates ]
193+ # Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]]
194+ # replace it with Mapped[List[sqlmodel_cls]]
195+ origin = get_origin (old )
196+ inner = get_args (old )
197+ if origin == Mapped and len (inner ) and get_origin (inner [0 ]) is list :
198+ other_class .__annotations__ [back_populates ] = Mapped [
199+ List [sqlmodel_cls ]
200+ ]
201+ other_class .sqlmodel_rebuild ()
186202
187203 # Replace Optional['T'] with Optional[TSQLModel]
188204 old = field .type
@@ -192,6 +208,9 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
192208 if origin == Union and len (inner ) and inner [0 ] == ForwardRef (cls .__name__ ):
193209 sqlmodel_cls .__annotations__ [field .name ] = Optional [sqlmodel_cls ]
194210 needs_rebuild = True
211+ if origin == list and len (inner ) and inner [0 ] == ForwardRef (cls .__name__ ):
212+ sqlmodel_cls .__annotations__ [field .name ] = List [sqlmodel_cls ]
213+ needs_rebuild = True
195214
196215 # Replace Optional[T] with Optional[TSQLModel] if T is a dataclass
197216 if origin == Union and len (inner ) and is_dataclass (inner [0 ]):
0 commit comments