Skip to content

Commit cf01f24

Browse files
authored
fix: additional update and update_many corrections. (#537)
Updates `update` and `update_many` to properly handle relationships and returning support.
1 parent 25a37b1 commit cf01f24

File tree

11 files changed

+865
-306
lines changed

11 files changed

+865
-306
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.12.10"
25+
rev: "v0.12.11"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/repository/_async.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
get_instrumented_attr,
5454
)
5555
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
56+
from advanced_alchemy.service.typing import schema_dump
5657
from advanced_alchemy.utils.dataclass import Empty, EmptyType
5758
from advanced_alchemy.utils.text import slugify
5859

@@ -1577,7 +1578,12 @@ async def update_many(
15771578
error_messages=error_messages,
15781579
default_messages=self.error_messages,
15791580
)
1580-
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
1581+
data_to_update: list[dict[str, Any]] = []
1582+
for v in data:
1583+
if isinstance(v, self.model_type) or (hasattr(v, "to_dict") and callable(v.to_dict)):
1584+
data_to_update.append(v.to_dict())
1585+
else:
1586+
data_to_update.append(cast("dict[str, Any]", schema_dump(v)))
15811587
with wrap_sqlalchemy_exception(
15821588
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
15831589
):
@@ -1606,7 +1612,17 @@ async def update_many(
16061612
return instances
16071613
await self.session.execute(statement, data_to_update, execution_options=execution_options)
16081614
await self._flush_or_commit(auto_commit=auto_commit)
1609-
return data
1615+
1616+
# For non-RETURNING backends, fetch updated instances from database
1617+
updated_ids: list[Any] = [item[self.id_attribute] for item in data_to_update]
1618+
updated_instances = await self.list(
1619+
getattr(self.model_type, self.id_attribute).in_(updated_ids),
1620+
load=loader_options,
1621+
execution_options=execution_options,
1622+
)
1623+
for instance in updated_instances:
1624+
self._expunge(instance, auto_expunge=auto_expunge)
1625+
return updated_instances
16101626

16111627
def _get_update_many_statement(
16121628
self,

advanced_alchemy/repository/_sync.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
get_instrumented_attr,
5555
)
5656
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
57+
from advanced_alchemy.service.typing import schema_dump
5758
from advanced_alchemy.utils.dataclass import Empty, EmptyType
5859
from advanced_alchemy.utils.text import slugify
5960

@@ -1578,7 +1579,12 @@ def update_many(
15781579
error_messages=error_messages,
15791580
default_messages=self.error_messages,
15801581
)
1581-
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
1582+
data_to_update: list[dict[str, Any]] = []
1583+
for v in data:
1584+
if isinstance(v, self.model_type) or (hasattr(v, "to_dict") and callable(v.to_dict)):
1585+
data_to_update.append(v.to_dict())
1586+
else:
1587+
data_to_update.append(cast("dict[str, Any]", schema_dump(v)))
15821588
with wrap_sqlalchemy_exception(
15831589
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
15841590
):
@@ -1607,7 +1613,17 @@ def update_many(
16071613
return instances
16081614
self.session.execute(statement, data_to_update, execution_options=execution_options)
16091615
self._flush_or_commit(auto_commit=auto_commit)
1610-
return data
1616+
1617+
# For non-RETURNING backends, fetch updated instances from database
1618+
updated_ids: list[Any] = [item[self.id_attribute] for item in data_to_update]
1619+
updated_instances = self.list(
1620+
getattr(self.model_type, self.id_attribute).in_(updated_ids),
1621+
load=loader_options,
1622+
execution_options=execution_options,
1623+
)
1624+
for instance in updated_instances:
1625+
self._expunge(instance, auto_expunge=auto_expunge)
1626+
return updated_instances
16111627

16121628
def _get_update_many_statement(
16131629
self,

advanced_alchemy/service/_async.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,18 @@
2727
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, SQLAlchemyAsyncRepositoryT
2828
from advanced_alchemy.service._util import ResultConverter
2929
from advanced_alchemy.service.typing import (
30+
UNSET,
3031
BulkModelDictT,
3132
ModelDictListT,
3233
ModelDictT,
3334
asdict,
35+
attrs_nothing,
3436
is_attrs_instance,
3537
is_dict,
3638
is_dto_data,
3739
is_msgspec_struct,
3840
is_pydantic_model,
41+
schema_dump,
3942
)
4043
from advanced_alchemy.utils.dataclass import Empty, EmptyType
4144

@@ -448,20 +451,26 @@ async def to_model(
448451
)
449452

450453
if is_msgspec_struct(data):
451-
from msgspec import UNSET
452-
453454
return model_from_dict(
454455
model=self.model_type,
455-
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
456+
**{
457+
f: getattr(data, f)
458+
for f in data.__struct_fields__
459+
if hasattr(data, f) and getattr(data, f) is not UNSET
460+
},
456461
)
457462

458463
if is_dto_data(data):
459464
return cast("ModelT", data.create_instance())
460465

461466
if is_attrs_instance(data):
467+
# Filter out attrs.NOTHING values for partial updates
468+
def filter_unset(attr: Any, value: Any) -> bool: # noqa: ARG001
469+
return value is not attrs_nothing
470+
462471
return model_from_dict(
463472
model=self.model_type,
464-
**asdict(data),
473+
**asdict(data, filter=filter_unset),
465474
)
466475

467476
# Fallback for objects with __dict__ (e.g., regular classes)
@@ -729,11 +738,16 @@ async def update(
729738
Returns:
730739
Updated representation.
731740
"""
732-
if is_dict(data) and item_id is not None:
741+
if (
742+
is_dict(data) or is_pydantic_model(data) or is_msgspec_struct(data) or is_attrs_instance(data)
743+
) and item_id is not None:
733744
existing_instance = await self.repository.get(
734745
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
735746
)
736-
update_data = await self.to_model_on_update(data)
747+
update_data = (
748+
await self.to_model_on_update(data) if is_dict(data) else schema_dump(data, exclude_unset=True)
749+
)
750+
737751
if is_dict(update_data):
738752
for key, value in update_data.items():
739753
if getattr(existing_instance, key, MISSING) is not MISSING:

advanced_alchemy/service/_sync.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@
2626
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, SQLAlchemySyncRepositoryT
2727
from advanced_alchemy.service._util import ResultConverter
2828
from advanced_alchemy.service.typing import (
29+
UNSET,
2930
BulkModelDictT,
3031
ModelDictListT,
3132
ModelDictT,
3233
asdict,
34+
attrs_nothing,
3335
is_attrs_instance,
3436
is_dict,
3537
is_dto_data,
3638
is_msgspec_struct,
3739
is_pydantic_model,
40+
schema_dump,
3841
)
3942
from advanced_alchemy.utils.dataclass import Empty, EmptyType
4043

@@ -447,20 +450,26 @@ def to_model(
447450
)
448451

449452
if is_msgspec_struct(data):
450-
from msgspec import UNSET
451-
452453
return model_from_dict(
453454
model=self.model_type,
454-
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
455+
**{
456+
f: getattr(data, f)
457+
for f in data.__struct_fields__
458+
if hasattr(data, f) and getattr(data, f) is not UNSET
459+
},
455460
)
456461

457462
if is_dto_data(data):
458463
return cast("ModelT", data.create_instance())
459464

460465
if is_attrs_instance(data):
466+
# Filter out attrs.NOTHING values for partial updates
467+
def filter_unset(attr: Any, value: Any) -> bool: # noqa: ARG001
468+
return value is not attrs_nothing
469+
461470
return model_from_dict(
462471
model=self.model_type,
463-
**asdict(data),
472+
**asdict(data, filter=filter_unset),
464473
)
465474

466475
# Fallback for objects with __dict__ (e.g., regular classes)
@@ -728,11 +737,14 @@ def update(
728737
Returns:
729738
Updated representation.
730739
"""
731-
if is_dict(data) and item_id is not None:
740+
if (
741+
is_dict(data) or is_pydantic_model(data) or is_msgspec_struct(data) or is_attrs_instance(data)
742+
) and item_id is not None:
732743
existing_instance = self.repository.get(
733744
item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
734745
)
735-
update_data = self.to_model_on_update(data)
746+
update_data = self.to_model_on_update(data) if is_dict(data) else schema_dump(data, exclude_unset=True)
747+
736748
if is_dict(update_data):
737749
for key, value in update_data.items():
738750
if getattr(existing_instance, key, MISSING) is not MISSING:

advanced_alchemy/service/_typing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,19 @@ def attrs_has_stub(*args: Any, **kwargs: Any) -> bool: # noqa: ARG001
255255
return False
256256

257257

258+
class AttrsNothingStub:
259+
"""Placeholder for attrs.NOTHING sentinel value"""
260+
261+
def __repr__(self) -> str:
262+
return "NOTHING"
263+
264+
265+
ATTRS_NOTHING_STUB = AttrsNothingStub()
266+
267+
258268
# Try to import real implementations at runtime
259269
try:
270+
from attrs import NOTHING as _real_attrs_nothing # noqa: N811
260271
from attrs import AttrsInstance as _RealAttrsInstance # pyright: ignore
261272
from attrs import asdict as _real_attrs_asdict
262273
from attrs import define as _real_attrs_define
@@ -270,6 +281,7 @@ def attrs_has_stub(*args: Any, **kwargs: Any) -> bool: # noqa: ARG001
270281
attrs_field = _real_attrs_field
271282
attrs_fields = _real_attrs_fields
272283
attrs_has = _real_attrs_has
284+
attrs_nothing = _real_attrs_nothing
273285
ATTRS_INSTALLED = True # pyright: ignore[reportConstantRedefinition]
274286
except ImportError:
275287
AttrsInstance = AttrsLike # type: ignore[misc]
@@ -278,6 +290,7 @@ def attrs_has_stub(*args: Any, **kwargs: Any) -> bool: # noqa: ARG001
278290
attrs_field = attrs_field_stub
279291
attrs_fields = attrs_fields_stub
280292
attrs_has = attrs_has_stub # type: ignore[assignment]
293+
attrs_nothing = ATTRS_NOTHING_STUB # type: ignore[assignment]
281294
ATTRS_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
282295

283296
try:
@@ -309,6 +322,7 @@ class EmptyEnum(enum.Enum):
309322

310323
__all__ = (
311324
"ATTRS_INSTALLED",
325+
"ATTRS_NOTHING_STUB",
312326
"CATTRS_INSTALLED",
313327
"LITESTAR_INSTALLED",
314328
"MSGSPEC_INSTALLED",
@@ -317,6 +331,7 @@ class EmptyEnum(enum.Enum):
317331
"UNSET_STUB",
318332
"AttrsInstance",
319333
"AttrsLike",
334+
"AttrsNothingStub",
320335
"BaseModel",
321336
"BaseModelLike",
322337
"DTOData",
@@ -346,6 +361,7 @@ class EmptyEnum(enum.Enum):
346361
"attrs_fields_stub",
347362
"attrs_has",
348363
"attrs_has_stub",
364+
"attrs_nothing",
349365
"cattrs_structure",
350366
"cattrs_unstructure",
351367
"convert",

advanced_alchemy/service/typing.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
T,
3939
TypeAdapter,
4040
UnsetType,
41+
attrs_nothing,
4142
convert,
4243
)
4344
from advanced_alchemy.service._typing import attrs_asdict as asdict
@@ -480,9 +481,20 @@ def schema_dump(
480481
return data.model_dump(exclude_unset=exclude_unset)
481482
if is_msgspec_struct(data):
482483
if exclude_unset:
483-
return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET}
484+
return {
485+
f: getattr(data, f)
486+
for f in data.__struct_fields__
487+
if hasattr(data, f) and getattr(data, f) is not UNSET
488+
}
484489
return {f: getattr(data, f, None) for f in data.__struct_fields__}
485490
if is_attrs_instance(data):
491+
if exclude_unset:
492+
# Filter out attrs.NOTHING values for partial updates
493+
def filter_unset_attrs(attr: Any, value: Any) -> bool: # noqa: ARG001
494+
return value is not attrs_nothing
495+
496+
return asdict(data, filter=filter_unset_attrs)
497+
486498
# Use cattrs for enhanced performance and type-aware serialization when available
487499
if CATTRS_INSTALLED:
488500
return unstructure(data) # type: ignore[no-any-return]
@@ -522,6 +534,7 @@ def schema_dump(
522534
"TypeAdapter",
523535
"UnsetType",
524536
"asdict",
537+
"attrs_nothing",
525538
"convert",
526539
"fields",
527540
"get_attrs_fields",

0 commit comments

Comments
 (0)