77import re
88import textwrap
99from functools import partial
10- from typing import Callable , Dict , Optional , Tuple , TypeVar , Union
10+ from typing import Callable , Dict , List , Optional , Sequence , Tuple , TypeVar , Union
1111
1212# TODO: give this a proper public API?
1313_process_registry = collections .defaultdict (list )
@@ -61,7 +61,7 @@ def decorate(f: Callable) -> Callable:
6161 return decorate
6262
6363
64- def _get_doc (obj : Union [str , Callable ]) -> str :
64+ def get_docstring (obj : Union [str , Callable ]) -> str :
6565 """
6666 Get docstring of a method or function.
6767 """
@@ -76,7 +76,7 @@ def extract_params(doc: Union[str, Callable]) -> Dict[str, str]:
7676 """
7777 Extract parameters (``:param name:`` format) from a docstring.
7878 """
79- doc = _get_doc (doc )
79+ doc = get_docstring (doc )
8080 params_regex = re .compile (r"^:param\s+(?P<param>\w+)\s*:(?P<doc>.*(\n +.*)*)" , re .MULTILINE )
8181 return {m .group ("param" ): m .group ("doc" ).strip () for m in params_regex .finditer (doc )}
8282
@@ -85,12 +85,27 @@ def extract_return(doc: Union[str, Callable]) -> Union[str, None]:
8585 """
8686 Extract return value description (``:return:`` format) from a docstring.
8787 """
88- doc = _get_doc (doc )
88+ doc = get_docstring (doc )
8989 return_regex = re .compile (r"^:return\s*:(?P<doc>.*(\n +.*)*)" , re .MULTILINE )
9090 matches = [m .group ("doc" ).strip () for m in return_regex .finditer (doc )]
9191 assert 0 <= len (matches ) <= 1
9292 return matches [0 ] if matches else None
9393
94+
95+ def extract_main_description (doc : Union [str , Callable ]) -> List [str ]:
96+ """
97+ Extract main description from a docstring:
98+ paragraphs before the params/returns description.
99+ """
100+ paragraphs = []
101+ for part in re .split (r"\s*\n(?:\s*\n)+" , get_docstring (doc )):
102+ if re .match (r"\s*:" , part ):
103+ break
104+ paragraphs .append (part .strip ("\n " ))
105+ assert len (paragraphs ) > 0
106+ return paragraphs
107+
108+
94109def assert_same_param_docs (doc_a : Union [str , Callable ], doc_b : Union [str , Callable ], only_intersection : bool = False ):
95110 """
96111 Compare parameters (``:param name:`` format) from two docstrings.
@@ -112,3 +127,17 @@ def assert_same_return_docs(doc_a: Union[str, Callable], doc_b: Union[str, Calla
112127 Compare return value descriptions from two docstrings.
113128 """
114129 assert extract_return (doc_a ) == extract_return (doc_b )
130+
131+
132+ def assert_same_main_description (doc_a : Union [str , Callable ], doc_b : Union [str , Callable ], ignore : Sequence [str ] = ()):
133+ """
134+ Compare main description from two docstrings.
135+ """
136+ description_a = extract_main_description (doc_a )
137+ description_b = extract_main_description (doc_b )
138+
139+ for s in ignore :
140+ description_a = [p .replace (s , "<IGNORED>" ) for p in description_a ]
141+ description_b = [p .replace (s , "<IGNORED>" ) for p in description_b ]
142+
143+ assert description_a == description_b
0 commit comments