Skip to content

Commit 6ff6cfa

Browse files
authored
SQLAlchemyDTO: column/relationship type inference. (#1879)
If type annotations aren't available for a given column/relationship we make an effort to infer a type annotation from the relevant SQLAlchemy object. Closes #1853
1 parent 48a9719 commit 6ff6cfa

File tree

3 files changed

+133
-16
lines changed

3 files changed

+133
-16
lines changed

litestar/contrib/sqlalchemy/dto.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import singledispatchmethod
4-
from typing import TYPE_CHECKING, Generic, TypeVar
4+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
55

66
from sqlalchemy import Column, inspect, orm, sql
77
from sqlalchemy.ext.associationproxy import AssociationProxy, AssociationProxyExtensionType
@@ -13,6 +13,7 @@
1313
Mapped,
1414
NotExtension,
1515
QueryableAttribute,
16+
RelationshipDirection,
1617
RelationshipProperty,
1718
)
1819

@@ -22,6 +23,7 @@
2223
from litestar.dto.factory.utils import get_model_type_hints
2324
from litestar.exceptions import ImproperlyConfiguredException
2425
from litestar.types.empty import Empty
26+
from litestar.typing import ParsedType
2527
from litestar.utils.helpers import get_fully_qualified_class_name
2628
from litestar.utils.signature import ParsedSignature
2729

@@ -30,8 +32,6 @@
3032

3133
from typing_extensions import TypeAlias
3234

33-
from litestar.typing import ParsedType
34-
3535
__all__ = ("SQLAlchemyDTO",)
3636

3737
T = TypeVar("T", bound="DeclarativeBase | Collection[DeclarativeBase]")
@@ -89,10 +89,8 @@ def _(
8989
(parsed_type,) = parsed_type.inner_types
9090
else:
9191
raise NotImplementedError(f"Expected 'Mapped' origin, got: '{parsed_type.origin}'")
92-
except KeyError as e:
93-
raise ImproperlyConfiguredException(
94-
f"No type information found for '{orm_descriptor}'. Has a type annotation been added to the column?"
95-
) from e
92+
except KeyError:
93+
parsed_type = parse_type_from_element(elem)
9694

9795
return [
9896
FieldDefinition(
@@ -221,6 +219,59 @@ def default_factory(d: Any = sqla_default) -> Any:
221219
else:
222220
raise ValueError("Unexpected default type")
223221
else:
224-
if getattr(elem, "nullable", False):
222+
if (
223+
isinstance(elem, RelationshipProperty)
224+
and detect_nullable_relationship(elem)
225+
or getattr(elem, "nullable", False)
226+
):
225227
default = None
228+
226229
return default, default_factory
230+
231+
232+
def parse_type_from_element(elem: ElementType) -> ParsedType:
233+
"""Parses a type from a SQLAlchemy element.
234+
235+
Args:
236+
elem: The SQLAlchemy element to parse.
237+
238+
Returns:
239+
ParsedType: The parsed type.
240+
241+
Raises:
242+
ImproperlyConfiguredException: If the type cannot be parsed.
243+
"""
244+
245+
if isinstance(elem, Column):
246+
if elem.nullable:
247+
return ParsedType(Optional[elem.type.python_type])
248+
return ParsedType(elem.type.python_type)
249+
250+
if isinstance(elem, RelationshipProperty):
251+
if elem.direction in (RelationshipDirection.ONETOMANY, RelationshipDirection.MANYTOMANY):
252+
collection_type = ParsedType(elem.collection_class or list)
253+
return ParsedType(collection_type.safe_generic_origin[elem.mapper.class_])
254+
255+
if detect_nullable_relationship(elem):
256+
return ParsedType(Optional[elem.mapper.class_])
257+
258+
return ParsedType(elem.mapper.class_)
259+
260+
raise ImproperlyConfiguredException(
261+
f"Unable to parse type from element '{elem}'. Consider adding a type hint.",
262+
)
263+
264+
265+
def detect_nullable_relationship(elem: RelationshipProperty) -> bool:
266+
"""Detects if a relationship is nullable.
267+
268+
This attempts to decide if we should allow a ``None`` default value for a relationship by looking at the
269+
foreign key fields. If all foreign key fields are nullable, then we allow a ``None`` default value.
270+
271+
Args:
272+
elem: The relationship to check.
273+
274+
Returns:
275+
bool: ``True`` if the relationship is nullable, ``False`` otherwise.
276+
"""
277+
return elem.direction == RelationshipDirection.MANYTOONE and all(c.nullable for c in elem.local_columns)

litestar/dto/factory/_backends/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def transfer_type_data(
279279
if isinstance(transfer_type, CollectionType):
280280
if transfer_type.has_nested:
281281
return transfer_nested_collection_type_data(
282-
transfer_type.parsed_type.origin, transfer_type, dto_for, source_value
282+
transfer_type.parsed_type.instantiable_origin, transfer_type, dto_for, source_value
283283
)
284284
return transfer_type.parsed_type.instantiable_origin(source_value)
285285
return source_value

tests/unit/test_contrib/test_sqlalchemy/test_dto.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
import pytest
99
import sqlalchemy
10-
from sqlalchemy import func
11-
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
10+
from sqlalchemy import ForeignKey, func
11+
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column, relationship
1212
from typing_extensions import Annotated
1313

14-
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
14+
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO, parse_type_from_element
1515
from litestar.dto.factory import DTOConfig, DTOField, Mark
1616
from litestar.dto.factory.field import DTO_FIELD_META_KEY
1717
from litestar.dto.interface import ConnectionContext, HandlerContext
@@ -553,10 +553,76 @@ class A(Base):
553553
assert vars(model)["c"] == [1, 2, 3]
554554

555555

556-
async def test_no_type_hints(base: type[DeclarativeBase], connection_context: ConnectionContext) -> None:
556+
async def test_no_type_hint_column(base: type[DeclarativeBase], connection_context: ConnectionContext) -> None:
557557
class Model(base):
558-
field = mapped_column(sqlalchemy.String)
558+
nullable_field = mapped_column(sqlalchemy.String)
559+
not_nullable_field = mapped_column(sqlalchemy.String, nullable=False, default="")
559560

560561
dto_type = SQLAlchemyDTO[Annotated[Model, DTOConfig()]]
561-
with pytest.raises(ImproperlyConfiguredException, match="No type information found for 'Model.field'"):
562-
await get_model_from_dto(dto_type, Model, connection_context, b"")
562+
model = await get_model_from_dto(dto_type, Model, connection_context, b"{}")
563+
assert model.nullable_field is None
564+
assert model.not_nullable_field == ""
565+
566+
567+
async def test_no_type_hint_scalar_relationship_with_nullable_fk(
568+
base: type[DeclarativeBase], connection_context: ConnectionContext
569+
) -> None:
570+
class Child(base):
571+
...
572+
573+
class Model(base):
574+
child_id = mapped_column(ForeignKey("child.id"))
575+
child = relationship(Child)
576+
577+
dto_type = SQLAlchemyDTO[Annotated[Model, DTOConfig(exclude={"child_id"})]]
578+
model = await get_model_from_dto(dto_type, Model, connection_context, b"{}")
579+
assert model.child is None
580+
581+
582+
async def test_no_type_hint_scalar_relationship_with_not_nullable_fk(
583+
base: type[DeclarativeBase], connection_context: ConnectionContext
584+
) -> None:
585+
class Child(base):
586+
...
587+
588+
class Model(base):
589+
child_id = mapped_column(ForeignKey("child.id"), nullable=False)
590+
child = relationship(Child)
591+
592+
dto_type = SQLAlchemyDTO[Annotated[Model, DTOConfig(exclude={"child_id"})]]
593+
model = await get_model_from_dto(dto_type, Model, connection_context, b'{"child": {}}')
594+
assert isinstance(model.child, Child)
595+
596+
597+
async def test_no_type_hint_collection_relationship(
598+
base: type[DeclarativeBase], connection_context: ConnectionContext
599+
) -> None:
600+
class Child(base):
601+
model_id = mapped_column(ForeignKey("model.id"))
602+
603+
class Model(base):
604+
children = relationship(Child)
605+
606+
dto_type = SQLAlchemyDTO[Annotated[Model, DTOConfig()]]
607+
model = await get_model_from_dto(dto_type, Model, connection_context, b'{"children": []}')
608+
assert model.children == []
609+
610+
611+
async def test_no_type_hint_collection_relationship_alt_collection_class(
612+
base: type[DeclarativeBase], connection_context: ConnectionContext
613+
) -> None:
614+
class Child(base):
615+
model_id = mapped_column(ForeignKey("model.id"))
616+
617+
class Model(base):
618+
children = relationship(Child, collection_class=set)
619+
620+
dto_type = SQLAlchemyDTO[Annotated[Model, DTOConfig()]]
621+
model = await get_model_from_dto(dto_type, Model, connection_context, b'{"children": []}')
622+
assert model.children == set()
623+
624+
625+
def test_parse_type_from_element_failure() -> None:
626+
with pytest.raises(ImproperlyConfiguredException) as exc:
627+
parse_type_from_element(1)
628+
assert str(exc.value) == "500: Unable to parse type from element '1'. Consider adding a type hint."

0 commit comments

Comments
 (0)