Skip to content

Commit 40dfb29

Browse files
steppilucascolley
authored andcommitted
MAINT: simplify check that target owns method
1 parent 9cf7895 commit 40dfb29

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

src/array_api_extra/testing.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_myfunc(xp):
219219
DeprecationWarning,
220220
stacklevel=2,
221221
)
222-
tags = {
222+
tags: dict[str, bool | int | type] = {
223223
"allow_dask_compute": allow_dask_compute,
224224
"jax_jit": jax_jit,
225225
}
@@ -238,10 +238,6 @@ def test_myfunc(xp):
238238
raw_attr = getattr_static(cls, method_name)
239239
method = getattr(cls, method_name)
240240
cloned_method = _clone_function(method)
241-
# Update the ``__qualname__`` because this will be used later to check
242-
# whether something is a method defined in the class of interest, or just
243-
# a reference to a function that's stored in a class.
244-
cloned_method.__qualname__ = f"{cls.__name__}.{method_name}"
245241

246242
method_to_set: Any
247243
if isinstance(raw_attr, staticmethod):
@@ -253,6 +249,8 @@ def test_myfunc(xp):
253249

254250
setattr(cls, method_name, method_to_set)
255251
f = getattr(cls, method_name)
252+
# Annotate that cls owns this method so we can check that later.
253+
tags["owner"] = cls
256254
else:
257255
f = func
258256

@@ -382,19 +380,17 @@ def iter_tagged() -> Iterator[
382380
with contextlib.suppress(KeyError, TypeError):
383381
tags = _ufuncs_tags[func]
384382
if tags is not None:
385-
if isinstance(target, type):
383+
if isinstance(target, type) and tags.get("owner") is not target:
386384
# There's a common pattern to wrap functions in namespace
387385
# classes to bypass lazy_xp_function like this:
388386
#
389387
# class naked:
390388
# myfunc = mymodule.myfunc
391389
#
392390
# To ensure this still works when checking for tags in
393-
# attributes of classes, use ``__qualname__`` to check whether
394-
# or not ``func`` was originally defined within ``target``.
395-
qn = getattr(func, "__qualname__", "")
396-
if not qn.startswith(f"{target.__name__}."):
397-
continue
391+
# attributes of classes, ensure that target is the actual
392+
# owning class where func was defined.
393+
continue
398394
# put attr, and func in the outputs so we can later tell
399395
# if this was a staticmethod or classmethod.
400396
yield target, name, attr, func, tags

0 commit comments

Comments
 (0)