Skip to content

Commit 7fc957b

Browse files
laramiel2bndy5
andauthored
feat(apidoc.cpp): conditionally add space between certain AST tokens (#483)
Removing extra spaces around namespace separators is a simple fix to improve type matching when the result of get_extent_spelling() is passed to _substitute_internal_type_names() --------- Co-authored-by: Brendan <2bndy5@gmail.com>
1 parent b194f33 commit 7fc957b

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

sphinx_immaterial/apidoc/cpp/api_parser.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -939,21 +939,31 @@ def get_extent_spelling(translation_unit: TranslationUnit, extent: SourceRange)
939939
whitespace. This results in excessive whitespace, but that does not matter
940940
because this is intended to be parsed by the Sphinx cpp domain anyway.
941941
"""
942+
no_spaces = (
943+
(TokenKind.KEYWORD, TokenKind.PUNCTUATION),
944+
(TokenKind.IDENTIFIER, TokenKind.PUNCTUATION),
945+
(TokenKind.PUNCTUATION, TokenKind.KEYWORD),
946+
(TokenKind.PUNCTUATION, TokenKind.IDENTIFIER),
947+
)
942948

943949
def get_spellings():
944950
prev_token = None
945951
COMMENT = TokenKind.COMMENT
946952
for token in translation_unit.get_tokens(extent=extent):
947953
if prev_token is not None:
948954
yield prev_token.spelling
955+
if (prev_token.kind, token.kind) not in no_spaces:
956+
yield " "
949957
prev_token = None
950958
if token.kind == COMMENT:
959+
yield " "
951960
continue
952961
prev_token = token
953962
# We need to handle the last token specially, because clang sometimes parses
954963
# ">>" as a single token but the extent may cover only the first of the two
955964
# angle brackets.
956965
if prev_token is not None:
966+
yield " "
957967
spelling = prev_token.spelling
958968
token_end = cast(SourceLocation, prev_token.extent.end)
959969
offset_diff = token_end.offset - cast(SourceLocation, extent.end).offset
@@ -962,7 +972,7 @@ def get_spellings():
962972
else:
963973
yield spelling
964974

965-
return " ".join(get_spellings())
975+
return "".join(get_spellings())
966976

967977

968978
def get_related_comments(decl: Cursor):
@@ -1532,10 +1542,7 @@ def _transform_unexposed_decl(config: Config, decl: Cursor) -> Optional[VarEntit
15321542
# exposed as an unexposed decl.
15331543

15341544
source_code = get_extent_spelling(decl.translation_unit, decl.extent)
1535-
1536-
# Note: Since `source_code` is reconstructed from the tokens, we don't need to
1537-
# worry about inconsistency in spacing.
1538-
if not source_code.startswith("template <"):
1545+
if not re.search(r"^\s*template\s*<", source_code):
15391546
return None
15401547

15411548
# Assume that it is a variable template
@@ -2344,6 +2351,7 @@ def _normalize_entity_requires(entity: CppApiEntity):
23442351
if _is_function(entity):
23452352
func_entity = cast(FunctionEntity, entity)
23462353
declaration = func_entity["declaration"]
2354+
declaration = _substitute_internal_type_names(config, declaration)
23472355
if replacements:
23482356
declaration = _apply_identifier_replacements(declaration, replacements)
23492357
if (

tests/cpp_api_parser_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,45 @@ def test_variable_template_specialization():
284284
"HasA",
285285
"HasA-int",
286286
]
287+
288+
289+
def test_type_replacements():
290+
config = api_parser.Config(
291+
input_path="a.cpp",
292+
input_content=rb"""
293+
struct source_location {};
294+
295+
namespace foo {
296+
struct SourceLocation {};
297+
}
298+
299+
/// Default method
300+
foo::SourceLocation Default();
301+
302+
/// Logging method
303+
void LogSourceLocation(foo::SourceLocation loc = Default());
304+
305+
/// Class
306+
class ClassWithSourceLocation {
307+
public:
308+
/// Constructor.
309+
ClassWithSourceLocation(foo::SourceLocation loc = Default())
310+
: loc_(loc) {}
311+
312+
foo::SourceLocation loc_;
313+
};
314+
""",
315+
type_replacements={
316+
"foo::SourceLocation": "source_location",
317+
},
318+
)
319+
320+
output = api_parser.generate_output(config)
321+
assert not output.get("errors")
322+
print(output)
323+
324+
assert len(output.get("entities", {}).values()) == 4
325+
for x in output["entities"].values():
326+
d = x.get("declaration")
327+
if d:
328+
assert cast(str, d).find("source_location") != -1

0 commit comments

Comments
 (0)