Skip to content

Commit 531ae30

Browse files
authored
Fix calls to functions with type vars and object arguments (#931)
Currently only the overloaded scalar cases seem to work for functions taking typevars. Handle them by having _typing_dispatch look at typevar bounds.
1 parent 22f84b3 commit 531ae30

File tree

2 files changed

+23
-33
lines changed

2 files changed

+23
-33
lines changed

gel/_internal/_typing_dispatch.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
_R_co = TypeVar("_R_co", covariant=True)
4343

4444

45-
def _isinstance(obj: Any, tp: Any) -> bool:
45+
def _isinstance(obj: Any, tp: Any, fn: Any) -> bool:
4646
# Handle Any type - matches everything
4747
if tp is Any:
4848
return True
@@ -52,7 +52,7 @@ def _isinstance(obj: Any, tp: Any) -> bool:
5252
return isinstance(obj, tp)
5353

5454
elif _typing_inspect.is_union_type(tp):
55-
return any(_isinstance(obj, el) for el in typing.get_args(tp))
55+
return any(_isinstance(obj, el, fn) for el in typing.get_args(tp))
5656

5757
elif _typing_inspect.is_literal(tp):
5858
# For Literal types, check if obj is one of the literal values
@@ -62,11 +62,17 @@ def _isinstance(obj: Any, tp: Any) -> bool:
6262
origin = typing.get_origin(tp)
6363
args = typing.get_args(tp)
6464
if origin is type:
65+
atype = args[0]
66+
if isinstance(atype, TypeVar):
67+
atype = atype.__bound__
68+
ns = _namespace.module_ns_of(fn)
69+
atype = _typing_eval.resolve_type(atype, globals=ns)
70+
6571
if isinstance(obj, type):
66-
return issubclass(obj, args[0])
72+
return issubclass(obj, atype)
6773
elif (mroent := getattr(obj, "__mro_entries__", None)) is not None:
6874
genalias_mro = mroent((obj,))
69-
return any(issubclass(c, args[0]) for c in genalias_mro)
75+
return any(issubclass(c, atype) for c in genalias_mro)
7076
else:
7177
return False
7278

@@ -93,7 +99,7 @@ def _isinstance(obj: Any, tp: Any) -> bool:
9399

94100
# For Mapping[K, V], check first key and first value
95101
k, v = next(iter(obj.items()))
96-
return _isinstance(k, args[0]) and _isinstance(v, args[1])
102+
return _isinstance(k, args[0], fn) and _isinstance(v, args[1], fn)
97103

98104
elif issubclass(origin, tuple):
99105
# Check the container type first
@@ -111,13 +117,13 @@ def _isinstance(obj: Any, tp: Any) -> bool:
111117
# heterogeneous tuple[*T]
112118
if num_args == 2 and args[1] is ...:
113119
# Homogeneous tuple like tuple[int, ...]
114-
return _isinstance(next(iter(obj)), args[0])
120+
return _isinstance(next(iter(obj)), args[0], fn)
115121
elif num_args != num_elems:
116122
# Shape of tuple value does not match type definition
117123
return False
118124
else:
119125
for el_type, el_val in zip(args, obj, strict=True):
120-
if not _isinstance(el_val, el_type):
126+
if not _isinstance(el_val, el_type, fn):
121127
return False
122128
return True
123129

@@ -130,7 +136,7 @@ def _isinstance(obj: Any, tp: Any) -> bool:
130136
if not args or len(obj) == 0:
131137
return True
132138

133-
return _isinstance(next(iter(obj)), args[0])
139+
return _isinstance(next(iter(obj)), args[0], fn)
134140

135141
else:
136142
# For other generic types, fall back to checking the origin
@@ -270,7 +276,7 @@ def _call(
270276
f"cannot dispatch to {self._qname}: an overload "
271277
f"is missing a type annotation on the {pn} parameter"
272278
)
273-
if not _isinstance(arg, pt):
279+
if not _isinstance(arg, pt, fn):
274280
break
275281
else:
276282
if bound_to is not None:

tests/test_qb.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -887,12 +887,6 @@ def test_qb_filter_12(self):
887887
res = sess_client.query(q)
888888
self.assertEqual(len(res), 2)
889889

890-
@tb.xfail('''
891-
std.union seems to fail in the filter
892-
893-
TypeError: issubclass() arg 2 must be a class, a tuple of classes,
894-
or a union
895-
''')
896890
def test_qb_filter_13(self):
897891
from models.orm import default, std
898892

@@ -1190,35 +1184,25 @@ def test_qb_update_02(self):
11901184
self.assertEqual(res.name, "blue")
11911185
self.assertEqual({u.name for u in res.users}, {"Zoe", "Dana"})
11921186

1193-
@tb.xfail('''
1194-
Runtime failure because assert_single doesn't work.
1195-
Bug #897.
1196-
1197-
Also fails at typecheck time because assert_single can return None?
1198-
''')
11991187
def test_qb_update_03(self):
12001188
from models.orm import default, std
12011189

12021190
# Combine update and select of the updated object
12031191
res = self.client.get(
12041192
default.Post.filter(body="Hello")
12051193
.update(
1206-
author=std.assert_single(default.User.filter(name="Billie"))
1194+
author=std.assert_single(default.User.filter(name="Billie")) # type: ignore
12071195
)
12081196
.select("*", author=lambda p: p.author.select("**"))
12091197
)
12101198

12111199
self.assertEqual(res.body, "Hello")
1212-
self.assertEqual(res.author.name, "Zoe")
1213-
self.assertEqual({g.name for g in res.author.groups}, {"redgreen"})
1200+
self.assertEqual(res.author.name, "Billie")
1201+
self.assertEqual({g.name for g in res.author.groups}, {"red", "green"})
12141202

1215-
@tb.xfail('''
1216-
Runtime failure because assert_single doesn't work.
1217-
Bug #897.
1218-
1219-
Also fails at typecheck time because update's *types* dont't
1220-
support callbacks, though runtime does.
1221-
''')
1203+
# Fails at typecheck time because update's *types* dont't
1204+
# support callbacks, though runtime does.
1205+
@tb.skip_typecheck
12221206
def test_qb_update_04(self):
12231207
from models.orm import default, std
12241208

@@ -1239,9 +1223,9 @@ def test_qb_update_04(self):
12391223
# Add Alice to the group
12401224
self.client.query(
12411225
default.UserGroup.filter(name="blue").update(
1242-
users=lambda g: std.union(
1226+
users=lambda g: std.assert_distinct(std.union(
12431227
g.users, default.User.filter(name="Alice")
1244-
)
1228+
))
12451229
)
12461230
)
12471231

0 commit comments

Comments
 (0)