|
52 | 52 | from sqlalchemy.orm.instrumentation import is_instrumented |
53 | 53 | from sqlalchemy.sql.schema import MetaData |
54 | 54 | from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid |
55 | | -from typing_extensions import Literal, TypeAlias, deprecated, get_origin |
| 55 | +from typing_extensions import Annotated, Literal, TypeAlias, deprecated, get_args, get_origin |
56 | 56 |
|
57 | 57 | from ._compat import ( # type: ignore[attr-defined] |
58 | 58 | IS_PYDANTIC_V2, |
@@ -475,6 +475,16 @@ def Relationship( |
475 | 475 | return relationship_info |
476 | 476 |
|
477 | 477 |
|
| 478 | +def get_annotated_relationshipinfo(t: Type) -> RelationshipInfo | None: |
| 479 | + """Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo.""" |
| 480 | + if get_origin(t) is not Annotated: |
| 481 | + return None |
| 482 | + for a in get_args(t): |
| 483 | + if isinstance(a, RelationshipInfo): |
| 484 | + return a |
| 485 | + return None |
| 486 | + |
| 487 | + |
478 | 488 | @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) |
479 | 489 | class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): |
480 | 490 | __sqlmodel_relationships__: Dict[str, RelationshipInfo] |
@@ -515,7 +525,12 @@ def __new__( |
515 | 525 | else: |
516 | 526 | dict_for_pydantic[k] = v |
517 | 527 | for k, v in original_annotations.items(): |
518 | | - if k in relationships: |
| 528 | + # check for `field: Annotated[Any, Relationship()]` |
| 529 | + t = get_annotated_relationshipinfo(v) |
| 530 | + if t: |
| 531 | + relationships[k] = t |
| 532 | + relationship_annotations[k] = get_args(v)[0] |
| 533 | + elif k in relationships: |
519 | 534 | relationship_annotations[k] = v |
520 | 535 | else: |
521 | 536 | pydantic_annotations[k] = v |
|
0 commit comments