Skip to content

Commit ec45901

Browse files
Doc: Refactor copydoc
1 parent fdf706a commit ec45901

1 file changed

Lines changed: 16 additions & 13 deletions

File tree

src/libsemigroups_pybind11/detail/decorators.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _get_overloaded_doc(func):
2222
return "1. " + doc
2323

2424

25-
def _correct_overloads(target, base_func, extra_funcs):
25+
def _correct_overloads(target, *funcs):
2626
"""
2727
Fix the docstring of copied overloaded functions
2828
@@ -41,27 +41,27 @@ def _correct_overloads(target, base_func, extra_funcs):
4141
This function does these things.
4242
"""
4343
target_name = target.__name__
44-
old_names = set(fn.__name__ for fn in extra_funcs) | {base_func.__name__}
45-
base_doc = _get_overloaded_doc(base_func)
46-
extra_doc = "".join([_get_overloaded_doc(func) for func in extra_funcs])
44+
if target.__doc__:
45+
funcs = list(funcs) + [target]
46+
old_names = set(fn.__name__ for fn in funcs)
47+
doc = "".join([_get_overloaded_doc(func) for func in funcs])
4748

4849
# Remove pybind11 inserted strings
49-
replacements = {
50-
f"{old_name}(*args, **kwargs)\n": "" for old_name in old_names
51-
}
50+
replacements = {f"{old_name}(*args, **kwargs)\n": "" for old_name in old_names}
5251
replacements["Overloaded function.\n"] = ""
5352
for old, new in replacements.items():
54-
extra_doc = extra_doc.replace(old, new)
53+
doc = doc.replace(old, new)
5554

5655
# Fix overload numbering
5756
overload_counter = 1
58-
doc_blocks = _re.split(
59-
rf"\d+\. (?:{'|'.join(old_names)})", base_doc + extra_doc
60-
)
57+
doc_blocks = _re.split(rf"\d+\. (?:{'|'.join(old_names)})", doc)
6158
new_doc = doc_blocks[0]
6259
for doc_block in doc_blocks[1:]:
6360
new_doc += f"{overload_counter}. {target_name}" + doc_block
6461
overload_counter += 1
62+
63+
# Add pybind11 strings at the start
64+
new_doc = f"{target_name}(*args, **kwargs)\nOverloaded function.\n\n" + new_doc.strip("\n")
6565
return new_doc
6666

6767

@@ -73,12 +73,15 @@ def copydoc(func, *extra_funcs):
7373
@copydoc(Transf1.__init__)
7474
def __init___(self) -> None:
7575
pass
76+
77+
If *target* has its own docstring, this will be added to the end of the new
78+
docstring.
7679
"""
7780
new_doc = func.__doc__
7881

7982
def wrapper(target):
80-
if extra_funcs:
81-
target.__doc__ = _correct_overloads(target, func, extra_funcs)
83+
if extra_funcs or target.__doc__:
84+
target.__doc__ = _correct_overloads(target, func, *extra_funcs)
8285
else:
8386
target.__doc__ = new_doc
8487
return target

0 commit comments

Comments
 (0)