Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.14.2"
rev: "v0.14.5"
hooks:
# Run the linter.
- id: ruff
Expand Down
37 changes: 35 additions & 2 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from sqlalchemy.sql.selectable import ForUpdateParameter
from sqlalchemy.sql.selectable import ForUpdateArg, ForUpdateParameter

from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
Expand Down Expand Up @@ -202,6 +202,7 @@ async def get(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT: ...

async def get_one(
Expand Down Expand Up @@ -1033,6 +1034,31 @@ def _get_base_stmt(
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement

def _apply_for_update_options(
self,
statement: Select[tuple[ModelT]],
with_for_update: ForUpdateParameter,
) -> Select[tuple[ModelT]]:
"""Apply FOR UPDATE options to a SELECT statement when requested."""

if with_for_update in (None, False):
return statement
if with_for_update is True:
return statement.with_for_update()
if isinstance(with_for_update, ForUpdateArg):
with_for_update_kwargs: dict[str, Any] = {
"nowait": with_for_update.nowait,
"read": with_for_update.read,
"skip_locked": with_for_update.skip_locked,
"key_share": with_for_update.key_share,
}
if getattr(with_for_update, "of", None):
with_for_update_kwargs["of"] = with_for_update.of
return statement.with_for_update(**with_for_update_kwargs)
if isinstance(with_for_update, dict): # pyright: ignore
return statement.with_for_update(**with_for_update)
return statement

def _get_delete_many_statement(
self,
*,
Expand Down Expand Up @@ -1071,6 +1097,7 @@ async def get(
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT:
"""Get instance identified by `item_id`.

Expand All @@ -1085,6 +1112,7 @@ async def get(
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
with_for_update: Optional FOR UPDATE clause / parameters to apply to the SELECT statement.

Returns:
The retrieved instance.
Expand All @@ -1107,6 +1135,7 @@ async def get(
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
statement = self._apply_for_update_options(statement, with_for_update)
instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
Expand Down Expand Up @@ -1486,7 +1515,11 @@ async def update(
id_attribute=id_attribute,
)
existing_instance = await self.get(
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
with_for_update=with_for_update,
)
mapper = None
with (
Expand Down
37 changes: 35 additions & 2 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from sqlalchemy.sql.selectable import ForUpdateParameter
from sqlalchemy.sql.selectable import ForUpdateArg, ForUpdateParameter

from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
Expand Down Expand Up @@ -203,6 +203,7 @@ def get(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT: ...

def get_one(
Expand Down Expand Up @@ -1034,6 +1035,31 @@ def _get_base_stmt(
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement

def _apply_for_update_options(
self,
statement: Select[tuple[ModelT]],
with_for_update: ForUpdateParameter,
) -> Select[tuple[ModelT]]:
"""Apply FOR UPDATE options to a SELECT statement when requested."""

if with_for_update in (None, False):
return statement
if with_for_update is True:
return statement.with_for_update()
if isinstance(with_for_update, ForUpdateArg):
with_for_update_kwargs: dict[str, Any] = {
"nowait": with_for_update.nowait,
"read": with_for_update.read,
"skip_locked": with_for_update.skip_locked,
"key_share": with_for_update.key_share,
}
if getattr(with_for_update, "of", None):
with_for_update_kwargs["of"] = with_for_update.of
return statement.with_for_update(**with_for_update_kwargs)
if isinstance(with_for_update, dict): # pyright: ignore
return statement.with_for_update(**with_for_update)
return statement

def _get_delete_many_statement(
self,
*,
Expand Down Expand Up @@ -1072,6 +1098,7 @@ def get(
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT:
"""Get instance identified by `item_id`.

Expand All @@ -1086,6 +1113,7 @@ def get(
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
with_for_update: Optional FOR UPDATE clause / parameters to apply to the SELECT statement.

Returns:
The retrieved instance.
Expand All @@ -1108,6 +1136,7 @@ def get(
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
statement = self._apply_for_update_options(statement, with_for_update)
instance = (self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
Expand Down Expand Up @@ -1487,7 +1516,11 @@ def update(
id_attribute=id_attribute,
)
existing_instance = self.get(
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
with_for_update=with_for_update,
)
mapper = None
with (
Expand Down
1 change: 1 addition & 0 deletions advanced_alchemy/repository/memory/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ async def get(
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)

Expand Down
1 change: 1 addition & 0 deletions advanced_alchemy/repository/memory/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def get(
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
with_for_update: ForUpdateParameter = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)

Expand Down
8 changes: 6 additions & 2 deletions advanced_alchemy/service/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,12 @@ async def update(
if item_id is not None:
# When item_id is provided, update existing instance rather than replacing it
# This preserves relationships and database-managed fields
existing_instance = await self.repository.get(
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
existing_instance: ModelT = await self.repository.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
with_for_update=with_for_update,
)

# Extract attributes from converted model to update existing instance
Expand Down
8 changes: 6 additions & 2 deletions advanced_alchemy/service/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,12 @@ def update(
if item_id is not None:
# When item_id is provided, update existing instance rather than replacing it
# This preserves relationships and database-managed fields
existing_instance = self.repository.get(
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
existing_instance: ModelT = self.repository.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
with_for_update=with_for_update,
)

# Extract attributes from converted model to update existing instance
Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,16 @@ module = [
"dishka.*",
]

[[tool.mypy.overrides]]
follow_imports = "skip"
ignore_missing_imports = true
module = [
"pytest",
"pytest.*",
"_pytest",
"_pytest.*",
]

[[tool.mypy.overrides]]
module = "advanced_alchemy._serialization"
warn_unused_ignores = false
Expand Down Expand Up @@ -506,6 +516,10 @@ module = "examples.flask.*"
disable_error_code = "unreachable"
module = "tests.integration.test_repository"

[[tool.mypy.overrides]]
module = "tests.unit.test_exceptions"
warn_unreachable = false


[tool.pyright]
disableBytesTypePromotions = true
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_extensions/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Generator, Sequence
from pathlib import Path
from typing import cast

import pytest
from flask import Flask, Response
Expand Down Expand Up @@ -74,7 +75,7 @@ class Repo(SQLAlchemyAsyncRepository[User]):

@pytest.fixture(scope="session")
def tmp_path_session(tmp_path_factory: pytest.TempPathFactory) -> Path:
return tmp_path_factory.mktemp("test_extensions_flask")
return cast("Path", tmp_path_factory.mktemp("test_extensions_flask"))


@pytest.fixture(scope="session")
Expand Down
84 changes: 84 additions & 0 deletions tests/unit/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sqlalchemy.exc import InvalidRequestError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute, Mapped, Session, mapped_column
from sqlalchemy.sql.selectable import ForUpdateArg
from sqlalchemy.types import TypeEngine

from advanced_alchemy import base
Expand Down Expand Up @@ -378,6 +379,89 @@ async def test_sqlalchemy_repo_get_member(
mock_repo.session.commit.assert_not_called() # pyright: ignore[reportFunctionMemberAccess]


async def test_sqlalchemy_repo_get_with_for_update(
mock_repo: SQLAlchemyAsyncRepository[Any],
mocker: MockerFixture,
) -> None:
"""Ensure FOR UPDATE options are applied when requested."""

statement = MagicMock()
statement.options.return_value = statement
statement.execution_options.return_value = statement
statement.with_for_update.return_value = statement
mock_repo.statement = statement

mocker.patch.object(mock_repo, "_get_loader_options", return_value=([], False))
mocker.patch.object(mock_repo, "_get_base_stmt", return_value=statement)
mocker.patch.object(mock_repo, "_apply_filters", return_value=statement)
mocker.patch.object(mock_repo, "_filter_select_by_kwargs", return_value=statement)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = MagicMock()
execute = mocker.patch.object(mock_repo, "_execute", return_value=execute_result)

instance = await maybe_async(mock_repo.get("instance-id", with_for_update=True))

assert instance is execute_result.scalar_one_or_none.return_value
statement.with_for_update.assert_called_once_with()
execute.assert_called_once_with(statement, uniquify=False)


async def test_sqlalchemy_repo_get_with_for_update_dict(
mock_repo: SQLAlchemyAsyncRepository[Any],
mocker: MockerFixture,
) -> None:
statement = MagicMock()
statement.options.return_value = statement
statement.execution_options.return_value = statement
statement.with_for_update.return_value = statement
mock_repo.statement = statement

mocker.patch.object(mock_repo, "_get_loader_options", return_value=([], False))
mocker.patch.object(mock_repo, "_get_base_stmt", return_value=statement)
mocker.patch.object(mock_repo, "_apply_filters", return_value=statement)
mocker.patch.object(mock_repo, "_filter_select_by_kwargs", return_value=statement)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = MagicMock()
mocker.patch.object(mock_repo, "_execute", return_value=execute_result)

await maybe_async(
mock_repo.get(
"instance-id",
with_for_update={"nowait": True, "read": False},
)
)

statement.with_for_update.assert_called_once_with(nowait=True, read=False)


async def test_sqlalchemy_repo_get_with_for_update_arg(
mock_repo: SQLAlchemyAsyncRepository[Any],
mocker: MockerFixture,
) -> None:
statement = MagicMock()
statement.options.return_value = statement
statement.execution_options.return_value = statement
statement.with_for_update.return_value = statement
mock_repo.statement = statement

mocker.patch.object(mock_repo, "_get_loader_options", return_value=([], False))
mocker.patch.object(mock_repo, "_get_base_stmt", return_value=statement)
mocker.patch.object(mock_repo, "_apply_filters", return_value=statement)
mocker.patch.object(mock_repo, "_filter_select_by_kwargs", return_value=statement)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = MagicMock()
mocker.patch.object(mock_repo, "_execute", return_value=execute_result)

await maybe_async(
mock_repo.get(
"instance-id",
with_for_update=ForUpdateArg(nowait=True, key_share=True),
)
)

statement.with_for_update.assert_called_once_with(nowait=True, read=False, skip_locked=False, key_share=True)


async def test_sqlalchemy_repo_get_one_member(
mock_repo: SQLAlchemyAsyncRepository[Any],
monkeypatch: MonkeyPatch,
Expand Down
Loading