Skip to content

Commit 0b2749e

Browse files
authored
fix: correctly handle id_attribute with update (#502)
Correctly merge attributes onto existing instance when using `id_attribute` and `update`
1 parent 9bd6b36 commit 0b2749e

File tree

10 files changed

+529
-156
lines changed

10 files changed

+529
-156
lines changed

advanced_alchemy/repository/_async.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import datetime
23
import decimal
34
import random
@@ -24,12 +25,14 @@
2425
Update,
2526
any_,
2627
delete,
28+
inspect,
2729
over,
2830
select,
2931
text,
3032
update,
3133
)
3234
from sqlalchemy import func as sql_func
35+
from sqlalchemy.exc import MissingGreenlet, NoInspectionAvailable
3336
from sqlalchemy.ext.asyncio import AsyncSession
3437
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
3538
from sqlalchemy.orm import InstrumentedAttribute
@@ -45,6 +48,7 @@
4548
FilterableRepository,
4649
FilterableRepositoryProtocol,
4750
LoadSpec,
51+
column_has_defaults,
4852
get_abstract_loader_options,
4953
get_instrumented_attr,
5054
)
@@ -1478,10 +1482,45 @@ async def update(
14781482
data,
14791483
id_attribute=id_attribute,
14801484
)
1481-
# this will raise for not found, and will put the item in the session
1482-
await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
1483-
# this will merge the inbound data to the instance we just put in the session
1484-
instance = await self._attach_to_session(data, strategy="merge")
1485+
existing_instance = await self.get(
1486+
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
1487+
)
1488+
mapper = None
1489+
with (
1490+
self.session.no_autoflush,
1491+
contextlib.suppress(MissingGreenlet, NoInspectionAvailable),
1492+
):
1493+
mapper = inspect(data)
1494+
if mapper is not None:
1495+
for column in mapper.mapper.columns:
1496+
field_name = column.key
1497+
new_field_value = getattr(data, field_name, MISSING)
1498+
if new_field_value is not MISSING:
1499+
# Skip setting columns with defaults/onupdate to None during updates
1500+
# This prevents overwriting columns that should use their defaults
1501+
if new_field_value is None and column_has_defaults(column):
1502+
continue
1503+
existing_field_value = getattr(existing_instance, field_name, MISSING)
1504+
if existing_field_value is not MISSING and existing_field_value != new_field_value:
1505+
setattr(existing_instance, field_name, new_field_value)
1506+
1507+
# Handle relationships by merging objects into session first
1508+
for relationship in mapper.mapper.relationships:
1509+
if (new_value := getattr(data, relationship.key, MISSING)) is not MISSING:
1510+
if isinstance(new_value, list):
1511+
merged_values = [ # pyright: ignore
1512+
await self.session.merge(item, load=False) # pyright: ignore
1513+
for item in new_value # pyright: ignore
1514+
]
1515+
setattr(existing_instance, relationship.key, merged_values)
1516+
elif new_value is not None:
1517+
merged_value = await self.session.merge(new_value, load=False)
1518+
setattr(existing_instance, relationship.key, merged_value)
1519+
else:
1520+
setattr(existing_instance, relationship.key, new_value)
1521+
1522+
instance = await self._attach_to_session(existing_instance, strategy="merge")
1523+
14851524
await self._flush_or_commit(auto_commit=auto_commit)
14861525
await self._refresh(
14871526
instance,

advanced_alchemy/repository/_sync.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Do not edit this file directly. It has been autogenerated from
22
# advanced_alchemy/repository/_async.py
3+
import contextlib
34
import datetime
45
import decimal
56
import random
@@ -26,12 +27,14 @@
2627
Update,
2728
any_,
2829
delete,
30+
inspect,
2931
over,
3032
select,
3133
text,
3234
update,
3335
)
3436
from sqlalchemy import func as sql_func
37+
from sqlalchemy.exc import MissingGreenlet, NoInspectionAvailable
3538
from sqlalchemy.orm import InstrumentedAttribute, Session
3639
from sqlalchemy.orm.scoping import scoped_session
3740
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
@@ -46,6 +49,7 @@
4649
FilterableRepository,
4750
FilterableRepositoryProtocol,
4851
LoadSpec,
52+
column_has_defaults,
4953
get_abstract_loader_options,
5054
get_instrumented_attr,
5155
)
@@ -1479,10 +1483,45 @@ def update(
14791483
data,
14801484
id_attribute=id_attribute,
14811485
)
1482-
# this will raise for not found, and will put the item in the session
1483-
self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
1484-
# this will merge the inbound data to the instance we just put in the session
1485-
instance = self._attach_to_session(data, strategy="merge")
1486+
existing_instance = self.get(
1487+
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
1488+
)
1489+
mapper = None
1490+
with (
1491+
self.session.no_autoflush,
1492+
contextlib.suppress(MissingGreenlet, NoInspectionAvailable),
1493+
):
1494+
mapper = inspect(data)
1495+
if mapper is not None:
1496+
for column in mapper.mapper.columns:
1497+
field_name = column.key
1498+
new_field_value = getattr(data, field_name, MISSING)
1499+
if new_field_value is not MISSING:
1500+
# Skip setting columns with defaults/onupdate to None during updates
1501+
# This prevents overwriting columns that should use their defaults
1502+
if new_field_value is None and column_has_defaults(column):
1503+
continue
1504+
existing_field_value = getattr(existing_instance, field_name, MISSING)
1505+
if existing_field_value is not MISSING and existing_field_value != new_field_value:
1506+
setattr(existing_instance, field_name, new_field_value)
1507+
1508+
# Handle relationships by merging objects into session first
1509+
for relationship in mapper.mapper.relationships:
1510+
if (new_value := getattr(data, relationship.key, MISSING)) is not MISSING:
1511+
if isinstance(new_value, list):
1512+
merged_values = [ # pyright: ignore
1513+
self.session.merge(item, load=False) # pyright: ignore
1514+
for item in new_value # pyright: ignore
1515+
]
1516+
setattr(existing_instance, relationship.key, merged_values)
1517+
elif new_value is not None:
1518+
merged_value = self.session.merge(new_value, load=False)
1519+
setattr(existing_instance, relationship.key, merged_value)
1520+
else:
1521+
setattr(existing_instance, relationship.key, new_value)
1522+
1523+
instance = self._attach_to_session(existing_instance, strategy="merge")
1524+
14861525
self._flush_or_commit(auto_commit=auto_commit)
14871526
self._refresh(
14881527
instance,

advanced_alchemy/repository/_util.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,26 @@ def _order_by_attribute(
339339
if not isinstance(statement, Select):
340340
return statement
341341
return statement.order_by(field.desc() if is_desc else field.asc())
342+
343+
344+
def column_has_defaults(column: Any) -> bool:
345+
"""Check if a column has any type of default value or update handler.
346+
347+
This includes:
348+
- Python-side defaults (column.default)
349+
- Server-side defaults (column.server_default)
350+
- Python-side onupdate handlers (column.onupdate)
351+
- Server-side onupdate handlers (column.server_onupdate)
352+
353+
Args:
354+
column: SQLAlchemy column object to check
355+
356+
Returns:
357+
bool: True if the column has any type of default or update handler
358+
"""
359+
return (
360+
column.default is not None
361+
or column.server_default is not None
362+
or column.onupdate is not None
363+
or column.server_onupdate is not None
364+
)

advanced_alchemy/service/_async.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
SQLAlchemyAsyncQueryRepository,
2525
)
2626
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
27-
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemyAsyncRepositoryT
27+
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, SQLAlchemyAsyncRepositoryT
2828
from advanced_alchemy.service._util import ResultConverter
2929
from advanced_alchemy.service.typing import (
3030
BulkModelDictT,
@@ -713,22 +713,33 @@ async def update(
713713
Returns:
714714
Updated representation.
715715
"""
716-
data = await self.to_model(data, "update")
717-
if (
718-
item_id is None
719-
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
720-
item=data,
721-
id_attribute=id_attribute,
722-
)
723-
is None
724-
):
725-
msg = (
726-
"Could not identify ID attribute value. One of the following is required: "
727-
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
716+
if is_dict(data) and item_id is not None:
717+
existing_instance = await self.repository.get(
718+
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
728719
)
729-
raise RepositoryError(msg)
730-
if item_id is not None:
731-
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
720+
update_data = await self.to_model_on_update(data)
721+
if is_dict(update_data):
722+
for key, value in update_data.items():
723+
if getattr(existing_instance, key, MISSING) is not MISSING:
724+
setattr(existing_instance, key, value)
725+
data = existing_instance
726+
else:
727+
data = await self.to_model(data, "update")
728+
if (
729+
item_id is None
730+
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
731+
item=data,
732+
id_attribute=id_attribute,
733+
)
734+
is None
735+
):
736+
msg = (
737+
"Could not identify ID attribute value. One of the following is required: "
738+
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
739+
)
740+
raise RepositoryError(msg)
741+
if item_id is not None:
742+
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
732743
return cast(
733744
"ModelT",
734745
await self.repository.update(

advanced_alchemy/service/_sync.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from advanced_alchemy.filters import StatementFilter
2424
from advanced_alchemy.repository import SQLAlchemySyncQueryRepository
2525
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
26-
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemySyncRepositoryT
26+
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, SQLAlchemySyncRepositoryT
2727
from advanced_alchemy.service._util import ResultConverter
2828
from advanced_alchemy.service.typing import (
2929
BulkModelDictT,
@@ -712,22 +712,33 @@ def update(
712712
Returns:
713713
Updated representation.
714714
"""
715-
data = self.to_model(data, "update")
716-
if (
717-
item_id is None
718-
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
719-
item=data,
720-
id_attribute=id_attribute,
721-
)
722-
is None
723-
):
724-
msg = (
725-
"Could not identify ID attribute value. One of the following is required: "
726-
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
715+
if is_dict(data) and item_id is not None:
716+
existing_instance = self.repository.get(
717+
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
727718
)
728-
raise RepositoryError(msg)
729-
if item_id is not None:
730-
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
719+
update_data = self.to_model_on_update(data)
720+
if is_dict(update_data):
721+
for key, value in update_data.items():
722+
if getattr(existing_instance, key, MISSING) is not MISSING:
723+
setattr(existing_instance, key, value)
724+
data = existing_instance
725+
else:
726+
data = self.to_model(data, "update")
727+
if (
728+
item_id is None
729+
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
730+
item=data,
731+
id_attribute=id_attribute,
732+
)
733+
is None
734+
):
735+
msg = (
736+
"Could not identify ID attribute value. One of the following is required: "
737+
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
738+
)
739+
raise RepositoryError(msg)
740+
if item_id is not None:
741+
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
731742
return cast(
732743
"ModelT",
733744
self.repository.update(

tests/integration/test_repository.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pytest_lazy_fixtures import lf
1919
from sqlalchemy import Engine, Table, and_, insert, select, text
2020
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
21-
from sqlalchemy.orm import Session, sessionmaker
21+
from sqlalchemy.orm import Session, selectinload, sessionmaker
2222
from time_machine import travel
2323

2424
from advanced_alchemy import base
@@ -2032,10 +2032,12 @@ async def test_lazy_load(
20322032
assert len(tags_to_add) > 0 # pyright: ignore
20332033
assert tags_to_add[0].id is not None # pyright: ignore
20342034
update_data["tags"] = tags_to_add # type: ignore[assignment]
2035-
updated_obj = await maybe_async(item_repo.update(item_model(**update_data), auto_refresh=False))
2035+
await maybe_async(item_repo.update(item_model(**update_data), load=[selectinload(item_repo.model_type.tags)]))
2036+
# Refresh the object to ensure tags are loaded before assertions
2037+
refreshed_obj = await maybe_async(item_repo.get(first_item_id, load=[selectinload(item_repo.model_type.tags)]))
20362038
await maybe_async(item_repo.session.commit())
2037-
assert len(updated_obj.tags) > 0
2038-
assert updated_obj.tags[0].name == "A new tag"
2039+
assert len(refreshed_obj.tags) > 0
2040+
assert refreshed_obj.tags[0].name == "A new tag"
20392041

20402042

20412043
async def test_repo_health_check(author_repo: AnyAuthorRepository) -> None:

0 commit comments

Comments
 (0)