Skip to content

Commit 368851d

Browse files
committed
fix overload and implementation of get_joined
1 parent d1c6e16 commit 368851d

File tree

3 files changed

+442
-3
lines changed

3 files changed

+442
-3
lines changed

fastcrud/crud/fast_crud.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,10 +1401,91 @@ async def get_multi(
14011401

14021402
return response
14031403

1404+
@overload
1405+
async def get_joined(
1406+
self,
1407+
db: AsyncSession,
1408+
*,
1409+
schema_to_select: type[SelectSchemaType],
1410+
return_as_model: Literal[True],
1411+
join_model: Optional[ModelType] = None,
1412+
join_on: Optional[Union[Join, BinaryExpression]] = None,
1413+
join_prefix: Optional[str] = None,
1414+
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
1415+
join_type: str = "left",
1416+
alias: Optional[AliasedClass] = None,
1417+
join_filters: Optional[dict] = None,
1418+
joins_config: Optional[list[JoinConfig]] = None,
1419+
nest_joins: bool = False,
1420+
relationship_type: Optional[str] = None,
1421+
**kwargs: Any,
1422+
) -> Optional[SelectSchemaType]: ...
1423+
1424+
@overload
1425+
async def get_joined(
1426+
self,
1427+
db: AsyncSession,
1428+
*,
1429+
schema_to_select: None = None,
1430+
return_as_model: Literal[False] = False,
1431+
join_model: Optional[ModelType] = None,
1432+
join_on: Optional[Union[Join, BinaryExpression]] = None,
1433+
join_prefix: Optional[str] = None,
1434+
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
1435+
join_type: str = "left",
1436+
alias: Optional[AliasedClass] = None,
1437+
join_filters: Optional[dict] = None,
1438+
joins_config: Optional[list[JoinConfig]] = None,
1439+
nest_joins: bool = False,
1440+
relationship_type: Optional[str] = None,
1441+
**kwargs: Any,
1442+
) -> Optional[dict[str, Any]]: ...
1443+
1444+
@overload
1445+
async def get_joined(
1446+
self,
1447+
db: AsyncSession,
1448+
*,
1449+
schema_to_select: type[SelectSchemaType],
1450+
return_as_model: Literal[False] = False,
1451+
join_model: Optional[ModelType] = None,
1452+
join_on: Optional[Union[Join, BinaryExpression]] = None,
1453+
join_prefix: Optional[str] = None,
1454+
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
1455+
join_type: str = "left",
1456+
alias: Optional[AliasedClass] = None,
1457+
join_filters: Optional[dict] = None,
1458+
joins_config: Optional[list[JoinConfig]] = None,
1459+
nest_joins: bool = False,
1460+
relationship_type: Optional[str] = None,
1461+
**kwargs: Any,
1462+
) -> Optional[dict[str, Any]]: ...
1463+
1464+
@overload
1465+
async def get_joined(
1466+
self,
1467+
db: AsyncSession,
1468+
*,
1469+
schema_to_select: Optional[type[SelectSchemaType]] = None,
1470+
return_as_model: bool = False,
1471+
join_model: Optional[ModelType] = None,
1472+
join_on: Optional[Union[Join, BinaryExpression]] = None,
1473+
join_prefix: Optional[str] = None,
1474+
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
1475+
join_type: str = "left",
1476+
alias: Optional[AliasedClass] = None,
1477+
join_filters: Optional[dict] = None,
1478+
joins_config: Optional[list[JoinConfig]] = None,
1479+
nest_joins: bool = False,
1480+
relationship_type: Optional[str] = None,
1481+
**kwargs: Any,
1482+
) -> Optional[Union[dict[str, Any], SelectSchemaType]]: ...
1483+
14041484
async def get_joined(
14051485
self,
14061486
db: AsyncSession,
14071487
schema_to_select: Optional[type[SelectSchemaType]] = None,
1488+
return_as_model: bool = False,
14081489
join_model: Optional[ModelType] = None,
14091490
join_on: Optional[Union[Join, BinaryExpression]] = None,
14101491
join_prefix: Optional[str] = None,
@@ -1416,7 +1497,7 @@ async def get_joined(
14161497
nest_joins: bool = False,
14171498
relationship_type: Optional[str] = None,
14181499
**kwargs: Any,
1419-
) -> Optional[dict[str, Any]]:
1500+
) -> Optional[Union[dict[str, Any], SelectSchemaType]]:
14201501
"""
14211502
Fetches a single record with one or multiple joins on other models. If `join_on` is not provided, the method attempts
14221503
to automatically detect the join condition using foreign key relationships. For multiple joins, use `joins_config` to
@@ -1427,6 +1508,7 @@ async def get_joined(
14271508
Args:
14281509
db: The SQLAlchemy async session.
14291510
schema_to_select: Pydantic schema for selecting specific columns from the primary model. Required if `return_as_model` is True.
1511+
return_as_model: If `True`, returns data as a Pydantic model instance based on `schema_to_select`. Defaults to `False`.
14301512
join_model: The model to join with.
14311513
join_on: SQLAlchemy Join object for specifying the `ON` clause of the join. If `None`, the join condition is auto-detected based on foreign keys.
14321514
join_prefix: Optional prefix to be added to all columns of the joined model. If `None`, no prefix is added.
@@ -1440,7 +1522,10 @@ async def get_joined(
14401522
**kwargs: Filters to apply to the primary model query, supporting advanced comparison operators for refined searching.
14411523
14421524
Returns:
1443-
A dictionary representing the joined record, or `None` if no record matches the criteria.
1525+
A dictionary or Pydantic model instance representing the joined record, or `None` if no record matches the criteria:
1526+
1527+
- When `return_as_model=True` and `schema_to_select` is provided: `Optional[SelectSchemaType]`
1528+
- When `return_as_model=False`: `Optional[Dict[str, Any]]`
14441529
14451530
Raises:
14461531
ValueError: If both single join parameters and `joins_config` are used simultaneously.
@@ -1708,7 +1793,19 @@ async def get_joined(
17081793
else:
17091794
data_list = []
17101795

1711-
return process_joined_data(data_list, join_definitions, nest_joins, self.model)
1796+
processed_data = process_joined_data(
1797+
data_list, join_definitions, nest_joins, self.model
1798+
)
1799+
1800+
if processed_data is None or not return_as_model:
1801+
return processed_data
1802+
1803+
if not schema_to_select:
1804+
raise ValueError(
1805+
"schema_to_select must be provided when return_as_model is True."
1806+
)
1807+
1808+
return schema_to_select(**processed_data)
17121809

17131810
@overload
17141811
async def get_multi_joined(

tests/sqlalchemy/crud/test_get_joined.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
from typing import Optional
3+
from pydantic import BaseModel
24
from sqlalchemy import and_
35
from fastcrud import FastCRUD, JoinConfig, aliased
46
from ...sqlalchemy.conftest import (
@@ -24,6 +26,22 @@
2426
)
2527

2628

29+
# Schema that includes joined fields for testing
30+
class JoinedTestTier(BaseModel):
31+
name: str
32+
tier_id: int
33+
tier_name: str
34+
35+
36+
# Flattened schema for task with joined fields
37+
class TaskWithJoinedData(BaseModel):
38+
id: int
39+
name: str
40+
description: Optional[str] = None
41+
assignee_name: Optional[str] = None
42+
department_name: Optional[str] = None
43+
44+
2745
@pytest.mark.asyncio
2846
async def test_get_joined_basic(async_session, test_data, test_data_tier):
2947
for tier_item in test_data_tier:
@@ -713,3 +731,156 @@ async def test_get_joined_nested_data_none_dict(async_session):
713731
assert task3_result["client"] is None, "Task 3 should have no client."
714732
assert task3_result["department"] is None, "Task 3 should have no department."
715733
assert task3_result["assignee"] is None, "Task 3 should have no assignee."
734+
735+
736+
@pytest.mark.asyncio
737+
async def test_get_joined_return_as_model_true(
738+
async_session, test_data, test_data_tier
739+
):
740+
"""Test get_joined with return_as_model=True returns Pydantic model instance."""
741+
for tier_item in test_data_tier:
742+
async_session.add(TierModel(**tier_item))
743+
await async_session.commit()
744+
745+
for user_item in test_data:
746+
async_session.add(ModelTest(**user_item))
747+
await async_session.commit()
748+
749+
crud = FastCRUD(ModelTest)
750+
result = await crud.get_joined(
751+
db=async_session,
752+
join_model=TierModel,
753+
join_prefix="tier_",
754+
schema_to_select=JoinedTestTier,
755+
join_schema_to_select=TierSchemaTest,
756+
return_as_model=True,
757+
)
758+
759+
assert result is not None
760+
assert isinstance(
761+
result, JoinedTestTier
762+
), "Result should be a JoinedTestTier Pydantic model instance"
763+
assert hasattr(result, "name"), "Result should have name attribute"
764+
assert hasattr(result, "tier_name"), "Result should have tier_name attribute"
765+
assert result.tier_name is not None, "tier_name should have a value"
766+
767+
768+
@pytest.mark.asyncio
769+
async def test_get_joined_return_as_model_false(
770+
async_session, test_data, test_data_tier
771+
):
772+
"""Test get_joined with return_as_model=False returns dict (default behavior)."""
773+
for tier_item in test_data_tier:
774+
async_session.add(TierModel(**tier_item))
775+
await async_session.commit()
776+
777+
for user_item in test_data:
778+
async_session.add(ModelTest(**user_item))
779+
await async_session.commit()
780+
781+
crud = FastCRUD(ModelTest)
782+
result = await crud.get_joined(
783+
db=async_session,
784+
join_model=TierModel,
785+
join_prefix="tier_",
786+
schema_to_select=ReadSchemaTest,
787+
join_schema_to_select=TierSchemaTest,
788+
return_as_model=False,
789+
)
790+
791+
assert result is not None
792+
assert isinstance(result, dict), "Result should be a dictionary"
793+
assert "name" in result
794+
assert "tier_name" in result
795+
796+
797+
@pytest.mark.asyncio
798+
async def test_get_joined_return_as_model_without_schema_raises_error(
799+
async_session, test_data, test_data_tier
800+
):
801+
"""Test get_joined with return_as_model=True but no schema raises ValueError."""
802+
for tier_item in test_data_tier:
803+
async_session.add(TierModel(**tier_item))
804+
await async_session.commit()
805+
806+
for user_item in test_data:
807+
async_session.add(ModelTest(**user_item))
808+
await async_session.commit()
809+
810+
crud = FastCRUD(ModelTest)
811+
with pytest.raises(ValueError) as exc_info:
812+
await crud.get_joined(
813+
db=async_session,
814+
join_model=TierModel,
815+
join_prefix="tier_",
816+
return_as_model=True, # Missing schema_to_select
817+
)
818+
819+
assert "schema_to_select must be provided when return_as_model is True" in str(
820+
exc_info.value
821+
)
822+
823+
824+
@pytest.mark.asyncio
825+
async def test_get_joined_return_as_model_with_joins_config(async_session):
826+
"""Test get_joined with return_as_model=True using joins_config."""
827+
# Create test data
828+
department = Department(name="Engineering")
829+
async_session.add(department)
830+
await async_session.flush()
831+
832+
user = User(
833+
name="John Doe",
834+
username="john",
835+
836+
department_id=department.id,
837+
)
838+
async_session.add(user)
839+
await async_session.flush()
840+
841+
task = Task(
842+
name="Test Task",
843+
description="Test Task",
844+
assignee_id=user.id,
845+
department_id=department.id,
846+
)
847+
async_session.add(task)
848+
await async_session.commit()
849+
850+
# Test with joins_config and return_as_model=True
851+
task_crud = FastCRUD(Task)
852+
joins_config = [
853+
JoinConfig(
854+
model=User,
855+
join_on=Task.assignee_id == User.id,
856+
join_prefix="assignee_",
857+
schema_to_select=UserReadSub,
858+
join_type="left",
859+
),
860+
JoinConfig(
861+
model=Department,
862+
join_on=Task.department_id == Department.id,
863+
join_prefix="department_",
864+
schema_to_select=DepartmentRead,
865+
join_type="left",
866+
),
867+
]
868+
869+
result = await task_crud.get_joined(
870+
db=async_session,
871+
id=task.id,
872+
schema_to_select=TaskWithJoinedData,
873+
joins_config=joins_config,
874+
return_as_model=True,
875+
)
876+
877+
assert result is not None
878+
assert isinstance(
879+
result, TaskWithJoinedData
880+
), "Result should be a TaskWithJoinedData Pydantic model instance"
881+
assert hasattr(result, "description"), "Result should have description attribute"
882+
assert result.description == "Test Task"
883+
assert hasattr(
884+
result, "assignee_name"
885+
), "Result should have assignee_name attribute"
886+
assert result.assignee_name == "John Doe"

0 commit comments

Comments
 (0)