@@ -22,7 +22,7 @@ def _get_overloaded_doc(func):
2222 return "1. " + doc
2323
2424
25- def _correct_overloads (target , base_func , extra_funcs ):
25+ def _correct_overloads (target , * funcs ):
2626 """
2727 Fix the docstring of copied overloaded functions
2828
@@ -41,27 +41,27 @@ def _correct_overloads(target, base_func, extra_funcs):
4141 This function does these things.
4242 """
4343 target_name = target .__name__
44- old_names = set (fn .__name__ for fn in extra_funcs ) | {base_func .__name__ }
45- base_doc = _get_overloaded_doc (base_func )
46- extra_doc = "" .join ([_get_overloaded_doc (func ) for func in extra_funcs ])
44+ if target .__doc__ :
45+ funcs = list (funcs ) + [target ]
46+ old_names = set (fn .__name__ for fn in funcs )
47+ doc = "" .join ([_get_overloaded_doc (func ) for func in funcs ])
4748
4849 # Remove pybind11 inserted strings
49- replacements = {
50- f"{ old_name } (*args, **kwargs)\n " : "" for old_name in old_names
51- }
50+ replacements = {f"{ old_name } (*args, **kwargs)\n " : "" for old_name in old_names }
5251 replacements ["Overloaded function.\n " ] = ""
5352 for old , new in replacements .items ():
54- extra_doc = extra_doc .replace (old , new )
53+ doc = doc .replace (old , new )
5554
5655 # Fix overload numbering
5756 overload_counter = 1
58- doc_blocks = _re .split (
59- rf"\d+\. (?:{ '|' .join (old_names )} )" , base_doc + extra_doc
60- )
57+ doc_blocks = _re .split (rf"\d+\. (?:{ '|' .join (old_names )} )" , doc )
6158 new_doc = doc_blocks [0 ]
6259 for doc_block in doc_blocks [1 :]:
6360 new_doc += f"{ overload_counter } . { target_name } " + doc_block
6461 overload_counter += 1
62+
63+ # Add pybind11 strings at the start
64+ new_doc = f"{ target_name } (*args, **kwargs)\n Overloaded function.\n \n " + new_doc .strip ("\n " )
6565 return new_doc
6666
6767
@@ -73,12 +73,15 @@ def copydoc(func, *extra_funcs):
7373 @copydoc(Transf1.__init__)
7474 def __init___(self) -> None:
7575 pass
76+
77+ If *target* has its own docstring, this will be added to the end of the new
78+ docstring.
7679 """
7780 new_doc = func .__doc__
7881
7982 def wrapper (target ):
80- if extra_funcs :
81- target .__doc__ = _correct_overloads (target , func , extra_funcs )
83+ if extra_funcs or target . __doc__ :
84+ target .__doc__ = _correct_overloads (target , func , * extra_funcs )
8285 else :
8386 target .__doc__ = new_doc
8487 return target
0 commit comments