Skip to content

Commit 9bd6b36

Browse files
authored
feat: match against complex types (#501)
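Fixes filtering on match fields whose values are complex types (for example, unhashable JSONB dicts) in `upsert_many` and `_get_delete_many_statement`: matched values are now deduplicated without requiring them to be hashable, and `field.in_()` is used instead of `any_()` whenever the value types may not be compatible with `ANY()`.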
1 parent 489b052 commit 9bd6b36

File tree

4 files changed

+346
-17
lines changed

4 files changed

+346
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,4 @@ cython_debug/
176176
CLAUDE.md
177177
.gitignore
178178
.todos
179+
.tmp

advanced_alchemy/repository/_async.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import datetime
2+
import decimal
13
import random
24
import string
35
from collections.abc import Iterable, Sequence
@@ -55,6 +57,18 @@
5557

5658
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
5759
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
60+
# Exact Python types that _type_must_use_in_instead_of_any accepts for an
# ANY() comparison. Membership is tested with type(value), so a non-None value
# of any other exact type (including subclasses of the types listed here)
# makes the repository use field.in_() instead.
DEFAULT_SAFE_TYPES: Final[set[type[Any]]] = {
61+
int,
62+
float,
63+
str,
64+
bool,
65+
bytes,
66+
decimal.Decimal,
67+
datetime.date,
68+
datetime.datetime,
69+
datetime.time,
70+
datetime.timedelta,
71+
}
5872

5973

6074
@runtime_checkable
@@ -497,6 +511,65 @@ def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
497511
"""
498512
return bool(uniquify) if uniquify is not None else self._uniquify
499513

514+
def _type_must_use_in_instead_of_any(self, matched_values: "list[Any]", field_type: "Any" = None) -> bool:
515+
"""Determine if field.in_() should be used instead of any_() for compatibility.
516+
517+
Uses SQLAlchemy's type introspection to detect types that may have DBAPI
518+
serialization issues with the ANY() operator. Checks if actual values match
519+
the column's expected python_type - mismatches indicate complex types that
520+
need the safer IN() operator. Falls back to Python type checking when
521+
SQLAlchemy type information is unavailable.
522+
523+
Args:
524+
matched_values: Values to be used in the filter
525+
field_type: Optional SQLAlchemy TypeEngine from the column
526+
527+
Returns:
528+
bool: True if field.in_() should be used instead of any_()
529+
"""
530+
if not matched_values:
531+
return False
532+
533+
if field_type is not None:
534+
try:
535+
expected_python_type = getattr(field_type, "python_type", None)
536+
if expected_python_type is not None:
537+
for value in matched_values:
538+
if value is not None and not isinstance(value, expected_python_type):
539+
return True
540+
except (AttributeError, NotImplementedError):
541+
return True
542+
543+
return any(value is not None and type(value) not in DEFAULT_SAFE_TYPES for value in matched_values)
544+
545+
def _get_unique_values(self, values: "list[Any]") -> "list[Any]":
546+
"""Get unique values from a list, handling unhashable types safely.
547+
548+
Args:
549+
values: List of values to deduplicate
550+
551+
Returns:
552+
list[Any]: List of unique values preserving order
553+
"""
554+
if not values:
555+
return []
556+
557+
try:
558+
# Fast path for hashable types
559+
seen: set[Any] = set()
560+
unique_values: list[Any] = []
561+
for value in values:
562+
if value not in seen:
563+
unique_values.append(value)
564+
seen.add(value)
565+
except TypeError:
566+
# Fallback for unhashable types (e.g., dicts from JSONB)
567+
unique_values = []
568+
for value in values:
569+
if value not in unique_values:
570+
unique_values.append(value)
571+
return unique_values
572+
500573
@staticmethod
501574
def _get_error_messages(
502575
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
@@ -974,9 +1047,11 @@ def _get_delete_many_statement(
9741047
statement = statement.execution_options(**execution_options)
9751048
if supports_returning and statement_type != "select":
9761049
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
977-
if self._prefer_any:
978-
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
979-
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
1050+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1051+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(id_chunk, id_attribute.type)
1052+
if use_in:
1053+
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
1054+
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
9801055

9811056
async def get(
9821057
self,
@@ -1870,7 +1945,9 @@ async def upsert_many(
18701945
matched_values = [
18711946
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
18721947
]
1873-
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
1948+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1949+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
1950+
match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field) # type: ignore[arg-type]
18741951

18751952
with wrap_sqlalchemy_exception(
18761953
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
@@ -1883,10 +1960,12 @@ async def upsert_many(
18831960
)
18841961
for field_name in match_fields:
18851962
field = get_instrumented_attr(self.model_type, field_name)
1886-
matched_values = list(
1887-
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
1888-
)
1889-
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
1963+
# Safe deduplication that handles unhashable types (e.g., JSONB dicts)
1964+
all_values = [getattr(datum, field_name) for datum in existing_objs if datum]
1965+
matched_values = self._get_unique_values(all_values)
1966+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1967+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
1968+
match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field) # type: ignore[arg-type]
18901969
existing_ids = self._get_object_ids(existing_objs=existing_objs)
18911970
data = self._merge_on_match_fields(data, existing_objs, match_fields)
18921971
for datum in data:

advanced_alchemy/repository/_sync.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Do not edit this file directly. It has been autogenerated from
22
# advanced_alchemy/repository/_async.py
3+
import datetime
4+
import decimal
35
import random
46
import string
57
from collections.abc import Iterable, Sequence
@@ -56,6 +58,18 @@
5658

5759
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
5860
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
61+
# Exact Python types that _type_must_use_in_instead_of_any accepts for an
# ANY() comparison. Membership is tested with type(value), so a non-None value
# of any other exact type (including subclasses of the types listed here)
# makes the repository use field.in_() instead.
DEFAULT_SAFE_TYPES: Final[set[type[Any]]] = {
62+
int,
63+
float,
64+
str,
65+
bool,
66+
bytes,
67+
decimal.Decimal,
68+
datetime.date,
69+
datetime.datetime,
70+
datetime.time,
71+
datetime.timedelta,
72+
}
5973

6074

6175
@runtime_checkable
@@ -498,6 +512,65 @@ def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
498512
"""
499513
return bool(uniquify) if uniquify is not None else self._uniquify
500514

515+
def _type_must_use_in_instead_of_any(self, matched_values: "list[Any]", field_type: "Any" = None) -> bool:
516+
"""Determine if field.in_() should be used instead of any_() for compatibility.
517+
518+
Uses SQLAlchemy's type introspection to detect types that may have DBAPI
519+
serialization issues with the ANY() operator. Checks if actual values match
520+
the column's expected python_type - mismatches indicate complex types that
521+
need the safer IN() operator. Falls back to Python type checking when
522+
SQLAlchemy type information is unavailable.
523+
524+
Args:
525+
matched_values: Values to be used in the filter
526+
field_type: Optional SQLAlchemy TypeEngine from the column
527+
528+
Returns:
529+
bool: True if field.in_() should be used instead of any_()
530+
"""
531+
if not matched_values:
532+
return False
533+
534+
if field_type is not None:
535+
try:
536+
expected_python_type = getattr(field_type, "python_type", None)
537+
if expected_python_type is not None:
538+
for value in matched_values:
539+
if value is not None and not isinstance(value, expected_python_type):
540+
return True
541+
except (AttributeError, NotImplementedError):
542+
return True
543+
544+
return any(value is not None and type(value) not in DEFAULT_SAFE_TYPES for value in matched_values)
545+
546+
def _get_unique_values(self, values: "list[Any]") -> "list[Any]":
547+
"""Get unique values from a list, handling unhashable types safely.
548+
549+
Args:
550+
values: List of values to deduplicate
551+
552+
Returns:
553+
list[Any]: List of unique values preserving order
554+
"""
555+
if not values:
556+
return []
557+
558+
try:
559+
# Fast path for hashable types
560+
seen: set[Any] = set()
561+
unique_values: list[Any] = []
562+
for value in values:
563+
if value not in seen:
564+
unique_values.append(value)
565+
seen.add(value)
566+
except TypeError:
567+
# Fallback for unhashable types (e.g., dicts from JSONB)
568+
unique_values = []
569+
for value in values:
570+
if value not in unique_values:
571+
unique_values.append(value)
572+
return unique_values
573+
501574
@staticmethod
502575
def _get_error_messages(
503576
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
@@ -975,9 +1048,11 @@ def _get_delete_many_statement(
9751048
statement = statement.execution_options(**execution_options)
9761049
if supports_returning and statement_type != "select":
9771050
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
978-
if self._prefer_any:
979-
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
980-
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
1051+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1052+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(id_chunk, id_attribute.type)
1053+
if use_in:
1054+
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
1055+
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
9811056

9821057
def get(
9831058
self,
@@ -1869,7 +1944,9 @@ def upsert_many(
18691944
matched_values = [
18701945
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
18711946
]
1872-
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
1947+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1948+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
1949+
match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field) # type: ignore[arg-type]
18731950

18741951
with wrap_sqlalchemy_exception(
18751952
error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
@@ -1882,10 +1959,12 @@ def upsert_many(
18821959
)
18831960
for field_name in match_fields:
18841961
field = get_instrumented_attr(self.model_type, field_name)
1885-
matched_values = list(
1886-
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
1887-
)
1888-
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
1962+
# Safe deduplication that handles unhashable types (e.g., JSONB dicts)
1963+
all_values = [getattr(datum, field_name) for datum in existing_objs if datum]
1964+
matched_values = self._get_unique_values(all_values)
1965+
# Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
1966+
use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
1967+
match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field) # type: ignore[arg-type]
18891968
existing_ids = self._get_object_ids(existing_objs=existing_objs)
18901969
data = self._merge_on_match_fields(data, existing_objs, match_fields)
18911970
for datum in data:

0 commit comments

Comments
 (0)