diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0934f36b..806ae013 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install -r ./tests/requirements.txt - name: Checks with pre-commit - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 - name: Test with pytest run: | diff --git a/equinox/_module.py b/equinox/_module.py index 69a1bc79..a67c8e3b 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -681,6 +681,15 @@ def __init__(self, method): if getattr(self.method, "__isabstractmethod__", False): self.__isabstractmethod__ = self.method.__isabstractmethod__ + # Fixes https://github.com/patrick-kidger/equinox/issues/1016 + @property + def __module__(self): + return self.method.__module__ + + @property + def __doc__(self): + return self.method.__doc__ + def __get__(self, instance, owner): del owner if instance is None: diff --git a/tests/test_module.py b/tests/test_module.py index b9a7471f..680341da 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,5 +1,6 @@ import abc import dataclasses +import doctest import functools as ft import inspect import pickle @@ -1257,3 +1258,61 @@ def test_tree_roundtrip_serialise_roundtrip(): x = pickle.loads(pickle.dumps(x)) tree2 = jtu.tree_structure(x) assert tree == tree2 + + +# https://github.com/patrick-kidger/equinox/issues/1016 +def test_doctest(): + class Example(eqx.Module): + def foo(self) -> int: + """ + >>> example = Example() + >>> example.foo() + 1 + """ + return 1 + + @property + def bar(self) -> int: + """ + >>> example = Example() + >>> example.bar + 1 + """ + return 1 + + @classmethod + def baz(self) -> int: + """ + >>> Example.baz() + 1 + """ + return 1 + + @eqx.filter_jit + def biz(self) -> int: + """ + >>> example = Example() + >>> example.biz() + 1 + """ + return 1 + + @staticmethod + def buz() -> int: + """ + >>> example = Example() + >>> example.buz() + 1 + """ + return 1 + + tests = doctest.DocTestFinder().find(Example) + tests = [test for test in tests if test.examples] + + assert len(tests) == 5 + + assert any("foo" in str(test) for test in tests) + assert any("bar" in str(test) for test in tests) + assert any("baz" in str(test) for test in tests) + assert any("biz" in str(test) for test in tests) + assert any("buz" in str(test) for test in tests)