Skip to content

Commit 725e26a

Browse files
authored
feat: handle and compare numpy arrays (#550)
Direct equality comparisons (`!=`) with numpy arrays in repository update methods raised `ValueError: The truth value of an array with more than one element is ambiguous`. This commit adds a safe comparison utility that handles numpy arrays gracefully.
1 parent 11f5b68 commit 725e26a

File tree

9 files changed

+757
-179
lines changed

9 files changed

+757
-179
lines changed

advanced_alchemy/repository/_async.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
FilterableRepositoryProtocol,
5050
LoadSpec,
5151
column_has_defaults,
52+
compare_values,
5253
get_abstract_loader_options,
5354
get_instrumented_attr,
5455
)
@@ -1293,7 +1294,7 @@ async def get_or_upsert(
12931294
if upsert:
12941295
for field_name, new_field_value in kwargs.items():
12951296
field = getattr(existing, field_name, MISSING)
1296-
if field is not MISSING and field != new_field_value:
1297+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
12971298
setattr(existing, field_name, new_field_value)
12981299
existing = await self._attach_to_session(existing, strategy="merge")
12991300
await self._flush_or_commit(auto_commit=auto_commit)
@@ -1367,7 +1368,7 @@ async def get_and_update(
13671368
updated = False
13681369
for field_name, new_field_value in kwargs.items():
13691370
field = getattr(existing, field_name, MISSING)
1370-
if field is not MISSING and field != new_field_value:
1371+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
13711372
updated = True
13721373
setattr(existing, field_name, new_field_value)
13731374
existing = await self._attach_to_session(existing, strategy="merge")
@@ -1502,7 +1503,9 @@ async def update(
15021503
if new_field_value is None and column_has_defaults(column):
15031504
continue
15041505
existing_field_value = getattr(existing_instance, field_name, MISSING)
1505-
if existing_field_value is not MISSING and existing_field_value != new_field_value:
1506+
if existing_field_value is not MISSING and not compare_values(
1507+
existing_field_value, new_field_value
1508+
):
15061509
setattr(existing_instance, field_name, new_field_value)
15071510

15081511
# Handle relationships by merging objects into session first
@@ -1940,7 +1943,7 @@ async def upsert(
19401943
):
19411944
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
19421945
field = getattr(existing, field_name, MISSING)
1943-
if field is not MISSING and field != new_field_value:
1946+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
19441947
setattr(existing, field_name, new_field_value)
19451948
instance = await self._attach_to_session(existing, strategy="merge")
19461949
await self._flush_or_commit(auto_commit=auto_commit)

advanced_alchemy/repository/_sync.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
FilterableRepositoryProtocol,
5151
LoadSpec,
5252
column_has_defaults,
53+
compare_values,
5354
get_abstract_loader_options,
5455
get_instrumented_attr,
5556
)
@@ -1294,7 +1295,7 @@ def get_or_upsert(
12941295
if upsert:
12951296
for field_name, new_field_value in kwargs.items():
12961297
field = getattr(existing, field_name, MISSING)
1297-
if field is not MISSING and field != new_field_value:
1298+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
12981299
setattr(existing, field_name, new_field_value)
12991300
existing = self._attach_to_session(existing, strategy="merge")
13001301
self._flush_or_commit(auto_commit=auto_commit)
@@ -1368,7 +1369,7 @@ def get_and_update(
13681369
updated = False
13691370
for field_name, new_field_value in kwargs.items():
13701371
field = getattr(existing, field_name, MISSING)
1371-
if field is not MISSING and field != new_field_value:
1372+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
13721373
updated = True
13731374
setattr(existing, field_name, new_field_value)
13741375
existing = self._attach_to_session(existing, strategy="merge")
@@ -1503,7 +1504,9 @@ def update(
15031504
if new_field_value is None and column_has_defaults(column):
15041505
continue
15051506
existing_field_value = getattr(existing_instance, field_name, MISSING)
1506-
if existing_field_value is not MISSING and existing_field_value != new_field_value:
1507+
if existing_field_value is not MISSING and not compare_values(
1508+
existing_field_value, new_field_value
1509+
):
15071510
setattr(existing_instance, field_name, new_field_value)
15081511

15091512
# Handle relationships by merging objects into session first
@@ -1939,7 +1942,7 @@ def upsert(
19391942
):
19401943
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
19411944
field = getattr(existing, field_name, MISSING)
1942-
if field is not MISSING and field != new_field_value:
1945+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
19431946
setattr(existing, field_name, new_field_value)
19441947
instance = self._attach_to_session(existing, strategy="merge")
19451948
self._flush_or_commit(auto_commit=auto_commit)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Repository typing utilities for optional dependency support.
2+
3+
Provides stubs and detection functions for numpy arrays to support
4+
pgvector and other array-based types whether or not numpy is installed.
5+
"""
6+
7+
from typing import Any
8+
9+
# Always define stub functions for type checking and fallback behavior
10+
11+
12+
def is_numpy_array_stub(value: Any) -> bool:  # pragma: no cover
    """Duck-typed array detection used when numpy cannot be imported.

    Only the presence of the ``__array__`` and ``dtype`` attributes is checked;
    an object exposing both is treated as array-like so that it can be routed
    to array-aware comparison instead of plain ``==``.

    Args:
        value: Object to inspect.

    Returns:
        bool: True when ``value`` exposes both ``__array__`` and ``dtype``.
    """
    required_attrs = ("__array__", "dtype")
    return all(hasattr(value, attr) for attr in required_attrs)  # pragma: no cover
25+
26+
27+
def arrays_equal_stub(a: Any, b: Any) -> bool:
    """Placeholder array comparison used when numpy cannot be imported.

    Without numpy there is no reliable element-wise comparison available, so
    every pair is reported as different. Reporting "different" errs on the
    side of writing the new value: it may cause a redundant update, but it
    never hides a real change.

    Args:
        a: First value (ignored).
        b: Second value (ignored).

    Returns:
        bool: Always False.
    """
    del a, b  # Inputs are intentionally ignored  # pragma: no cover
    return False  # pragma: no cover
43+
44+
45+
# Try to import real numpy implementation at runtime
46+
try:
    import numpy as np  # type: ignore[import-not-found,unused-ignore]  # pyright: ignore[reportMissingImports]
except ImportError:  # pragma: no cover
    # numpy is unavailable: publish the duck-typed fallbacks under the public names.
    is_numpy_array = is_numpy_array_stub
    arrays_equal = arrays_equal_stub
    NUMPY_INSTALLED = False  # pyright: ignore[reportConstantRedefinition]
else:

    def is_numpy_array_real(value: Any) -> bool:
        """Report whether ``value`` is a numpy ``ndarray``.

        Args:
            value: Object to inspect.

        Returns:
            bool: True only for instances of ``numpy.ndarray``.
        """
        return isinstance(value, np.ndarray)  # pyright: ignore

    def arrays_equal_real(a: Any, b: Any) -> bool:
        """Compare two arrays with ``numpy.array_equal``.

        Args:
            a: First array.
            b: Second array.

        Returns:
            bool: The result of ``numpy.array_equal(a, b)`` coerced to a plain bool.
        """
        same = np.array_equal(a, b)  # pyright: ignore
        return bool(same)

    # numpy imported successfully: publish the real implementations.
    is_numpy_array = is_numpy_array_real
    arrays_equal = arrays_equal_real
    NUMPY_INSTALLED = True  # pyright: ignore[reportConstantRedefinition]
82+
83+
84+
__all__ = (
85+
"NUMPY_INSTALLED",
86+
"arrays_equal",
87+
"arrays_equal_stub",
88+
"is_numpy_array",
89+
"is_numpy_array_stub",
90+
)

advanced_alchemy/repository/_util.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
StatementFilter,
3535
StatementTypeT,
3636
)
37+
from advanced_alchemy.repository._typing import arrays_equal, is_numpy_array
3738
from advanced_alchemy.repository.typing import ModelT, OrderingPair
3839

3940
WhereClauseT = ColumnExpressionArgument[bool]
@@ -368,3 +369,39 @@ def column_has_defaults(column: Any) -> bool:
368369
or getattr(column, "onupdate", None) is not None
369370
or getattr(column, "server_onupdate", None) is not None
370371
)
372+
373+
374+
def compare_values(existing_value: Any, new_value: Any) -> bool:
    """Decide whether two field values are equal without tripping on numpy arrays.

    Plain ``==`` / ``!=`` on numpy arrays (for example values held by
    pgvector's Vector type) compares element-wise, and using that result as a
    boolean raises ``ValueError``. This helper routes arrays through
    ``arrays_equal`` and falls back to ordinary equality for everything else.

    Args:
        existing_value: The value currently stored.
        new_value: The candidate replacement value.

    Returns:
        bool: True if the values are considered equal, False otherwise.
    """
    existing_is_none = existing_value is None
    new_is_none = new_value is None
    if existing_is_none or new_is_none:
        # Equal only when both sides are None; a single None always differs.
        return existing_is_none and new_is_none

    existing_is_array = is_numpy_array(existing_value)
    new_is_array = is_numpy_array(new_value)
    if existing_is_array and new_is_array:
        return arrays_equal(existing_value, new_value)
    if existing_is_array or new_is_array:
        # An array never equals a non-array value.
        return False

    try:
        outcome = existing_value == new_value
        # bool() stays inside the try: coercing an ambiguous result can itself raise.
        return bool(outcome)
    except (ValueError, TypeError):
        # When the comparison cannot be evaluated, report "different" so the
        # caller writes the new value rather than silently skipping it.
        return False

advanced_alchemy/repository/memory/_async.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
StatementFilter,
3636
)
3737
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepositoryProtocol, SQLAlchemyAsyncSlugRepositoryProtocol
38-
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
38+
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec, compare_values
3939
from advanced_alchemy.repository.memory.base import (
4040
AnyObject,
4141
InMemoryStore,
@@ -490,7 +490,7 @@ async def get_or_upsert(
490490
if upsert:
491491
for field_name, new_field_value in kwargs_.items():
492492
field = getattr(existing, field_name, MISSING)
493-
if field is not MISSING and field != new_field_value:
493+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
494494
setattr(existing, field_name, new_field_value)
495495
existing = await self.update(existing)
496496
return existing, False
@@ -524,7 +524,7 @@ async def get_and_update(
524524
updated = False
525525
for field_name, new_field_value in kwargs_.items():
526526
field = getattr(existing, field_name, MISSING)
527-
if field is not MISSING and field != new_field_value:
527+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
528528
updated = True
529529
setattr(existing, field_name, new_field_value)
530530
existing = await self.update(existing)

advanced_alchemy/repository/memory/_sync.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
StatementFilter,
3737
)
3838
from advanced_alchemy.repository._sync import SQLAlchemySyncRepositoryProtocol, SQLAlchemySyncSlugRepositoryProtocol
39-
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
39+
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec, compare_values
4040
from advanced_alchemy.repository.memory.base import (
4141
AnyObject,
4242
InMemoryStore,
@@ -491,7 +491,7 @@ def get_or_upsert(
491491
if upsert:
492492
for field_name, new_field_value in kwargs_.items():
493493
field = getattr(existing, field_name, MISSING)
494-
if field is not MISSING and field != new_field_value:
494+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
495495
setattr(existing, field_name, new_field_value)
496496
existing = self.update(existing)
497497
return existing, False
@@ -525,7 +525,7 @@ def get_and_update(
525525
updated = False
526526
for field_name, new_field_value in kwargs_.items():
527527
field = getattr(existing, field_name, MISSING)
528-
if field is not MISSING and field != new_field_value:
528+
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
529529
updated = True
530530
setattr(existing, field_name, new_field_value)
531531
existing = self.update(existing)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ test = [
148148
"cattrs",
149149
"dishka ; python_version >= \"3.10\"",
150150
"pydantic-extra-types",
151+
"numpy",
152+
"pgvector",
151153
"rich-click",
152154
"coverage>=7.6.1",
153155
"fsspec[s3]",

0 commit comments

Comments
 (0)