Skip to content
Open
26 changes: 24 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
from typing_extensions import (
Annotated,
Literal,
TypeAlias,
deprecated,
get_args,
get_origin,
)

from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
Expand Down Expand Up @@ -473,6 +480,16 @@ def Relationship(
return relationship_info


def get_annotated_relationshipinfo(t: Any) -> Optional[RelationshipInfo]:
"""Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo."""
if get_origin(t) is not Annotated:
return None
for a in get_args(t):
if isinstance(a, RelationshipInfo):
return a
return None


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
Expand Down Expand Up @@ -513,7 +530,12 @@ def __new__(
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
if k in relationships:
# check for `field: Annotated[Any, Relationship()]`
t = get_annotated_relationshipinfo(v)
if t:
relationships[k] = t
relationship_annotations[k] = get_args(v)[0]
elif k in relationships:
relationship_annotations[k] = v
else:
pydantic_annotations[k] = v
Expand Down