Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
python -m pip install -r ./tests/requirements.txt

- name: Checks with pre-commit
uses: pre-commit/action@v2.0.3
uses: pre-commit/action@v3.0.1

- name: Test with pytest
run: |
Expand Down
9 changes: 9 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,15 @@ def __init__(self, method):
if getattr(self.method, "__isabstractmethod__", False):
self.__isabstractmethod__ = self.method.__isabstractmethod__

# Fixes https://github.com/patrick-kidger/equinox/issues/1016
@property
def __module__(self):
return self.method.__module__

@property
def __doc__(self):
return self.method.__doc__

def __get__(self, instance, owner):
del owner
if instance is None:
Expand Down
59 changes: 59 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import dataclasses
import doctest
import functools as ft
import inspect
import pickle
Expand Down Expand Up @@ -1257,3 +1258,61 @@ def test_tree_roundtrip_serialise_roundtrip():
x = pickle.loads(pickle.dumps(x))
tree2 = jtu.tree_structure(x)
assert tree == tree2


# https://github.com/patrick-kidger/equinox/issues/1016
def test_doctest():
class Example(eqx.Module):
def foo(self) -> int:
"""
>>> example = Example()
>>> example.foo()
1
"""
return 1

@property
def bar(self) -> int:
"""
>>> example = Example()
>>> example.bar
1
"""
return 1

@classmethod
def baz(self) -> int:
"""
>>> Example.baz()
1
"""
return 1

@eqx.filter_jit
def biz(self) -> int:
"""
>>> example = Example()
>>> example.biz()
1
"""
return 1

@staticmethod
def buz() -> int:
"""
>>> example = Example()
>>> example.buz()
1
"""
return 1

tests = doctest.DocTestFinder().find(Example)
tests = [test for test in tests if test.examples]

assert len(tests) == 5

assert any("foo" in str(test) for test in tests)
assert any("bar" in str(test) for test in tests)
assert any("baz" in str(test) for test in tests)
assert any("biz" in str(test) for test in tests)
assert any("buz" in str(test) for test in tests)