diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 2eb7eeb..9dce0e7 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -240,7 +240,7 @@ async def select_model( load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs: Any, - ) -> Model | None: + ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None: """ Query by primary key(s) with optional relationship loading and joins. @@ -276,8 +276,7 @@ async def select_model( if join_conditions: if has_join_fill_result(join_conditions): - result = query.first() - return result[0] if result else None + return query.first() return query.scalars().first() @@ -289,7 +288,7 @@ async def select_model_by_column( load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs: Any, - ) -> Model | None: + ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None: """ Query by column with optional relationship loading and joins. @@ -313,8 +312,7 @@ async def select_model_by_column( if join_conditions: if has_join_fill_result(join_conditions): - result = query.first() - return result[0] if result else None + return query.first() return query.scalars().first() diff --git a/sqlalchemy_crud_plus/types.py b/sqlalchemy_crud_plus/types.py index 3d40b03..7829d2f 100644 --- a/sqlalchemy_crud_plus/types.py +++ b/sqlalchemy_crud_plus/types.py @@ -5,6 +5,7 @@ from typing import Any, Literal, TypeVar from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import Alias, Table from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.base import ExecutableOption @@ -56,7 +57,9 @@ class JoinConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - model: type[Model] | AliasedClass = Field(description='The target model or aliased class to join with') + model: type[Model] | AliasedClass | Alias | Table = Field( + description='The target model, aliased class, alias, or table to join with' + ) join_on: Any = Field(description='The join condition expression (e.g., model.id == other_model.id)') join_type: JoinType = Field(default='left', description='The type of join to perform') fill_result: bool = Field(default=False, description='Whether to populate this model to the query result') diff --git a/tests/test_no_relationship.py b/tests/test_no_relationship.py index d288fdd..c8c2ebe 100644 --- a/tests/test_no_relationship.py +++ b/tests/test_no_relationship.py @@ -530,7 +530,10 @@ async def test_join_fill_result_single_model( ) if result is not None: - assert isinstance(result, NoRelUser) + assert isinstance(result, (tuple, Row)) + assert isinstance(result[0], NoRelUser) + if result[1]: + assert isinstance(result[1], NoRelProfile) @pytest.mark.asyncio