Skip to content

Commit d36c654

Browse files
jeertmanspatrick-kidger
authored andcommitted
fix(lib): allow DocTest to discover methods from subclasses of eqx.Module
Closes #1016. In the end, there were two issues: (1) `DocTestFinder._from_module` was returning `False` because the `__module__` attribute was not correct, and (2) the documentation was not copied to the wrapper class so doctests could not be discovered. My solution is inspired from what Python is doing with their own decorators, see https://github.com/python/cpython/blob/4617d68d73409e83d6ab31106d10421d44048787/Lib/functools.py#L1033-L1049
1 parent 76d697c commit d36c654

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

equinox/_module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,15 @@ def __init__(self, method):
681681
if getattr(self.method, "__isabstractmethod__", False):
682682
self.__isabstractmethod__ = self.method.__isabstractmethod__
683683

684+
# Fixes https://github.com/patrick-kidger/equinox/issues/1016
685+
@property
686+
def __module__(self):
687+
return self.method.__module__
688+
689+
@property
690+
def __doc__(self):
691+
return self.method.__doc__
692+
684693
def __get__(self, instance, owner):
685694
del owner
686695
if instance is None:

tests/test_module.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import dataclasses
3+
import doctest
34
import functools as ft
45
import inspect
56
import pickle
@@ -1257,3 +1258,61 @@ def test_tree_roundtrip_serialise_roundtrip():
12571258
x = pickle.loads(pickle.dumps(x))
12581259
tree2 = jtu.tree_structure(x)
12591260
assert tree == tree2
1261+
1262+
1263+
# https://github.com/patrick-kidger/equinox/issues/1016
1264+
def test_doctest():
1265+
class Example(eqx.Module):
1266+
def foo(self) -> int:
1267+
"""
1268+
>>> example = Example()
1269+
>>> example.foo()
1270+
1
1271+
"""
1272+
return 1
1273+
1274+
@property
1275+
def bar(self) -> int:
1276+
"""
1277+
>>> example = Example()
1278+
>>> example.bar
1279+
1
1280+
"""
1281+
return 1
1282+
1283+
@classmethod
1284+
def baz(self) -> int:
1285+
"""
1286+
>>> Example.baz()
1287+
1
1288+
"""
1289+
return 1
1290+
1291+
@eqx.filter_jit
1292+
def biz(self) -> int:
1293+
"""
1294+
>>> example = Example()
1295+
>>> example.biz()
1296+
1
1297+
"""
1298+
return 1
1299+
1300+
@staticmethod
1301+
def buz() -> int:
1302+
"""
1303+
>>> example = Example()
1304+
>>> example.buz()
1305+
1
1306+
"""
1307+
return 1
1308+
1309+
tests = doctest.DocTestFinder().find(Example)
1310+
tests = [test for test in tests if test.examples]
1311+
1312+
assert len(tests) == 5
1313+
1314+
assert any("foo" in str(test) for test in tests)
1315+
assert any("bar" in str(test) for test in tests)
1316+
assert any("baz" in str(test) for test in tests)
1317+
assert any("biz" in str(test) for test in tests)
1318+
assert any("buz" in str(test) for test in tests)

0 commit comments

Comments
 (0)