Skip to content

Commit 1ee5ac3

Browse files
committed
Fix test
1 parent 793ab88 commit 1ee5ac3

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

src/xdoctest/static_analysis.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,16 @@ def __init__(self, source: typing.Any = None):
171171
source (None | str):
172172
"""
173173
super(TopLevelVisitor, self).__init__()
174-
self.calldefs = OrderedDict()
175-
self.source = source
176-
self.sourcelines = None
174+
self.calldefs: OrderedDict[str, CallDefNode] = OrderedDict()
175+
self.source: typing.Optional[str] = source
176+
self.sourcelines: typing.Optional[list[str]] = None
177177

178-
self._current_classname = None
178+
self._current_classname: typing.Optional[str] = None
179179
# Keep track of when we leave a top level definition
180-
self._finish_queue = deque()
180+
self._finish_queue: deque[CallDefNode] = deque()
181181

182182
# new
183-
self.assignments = []
183+
self.assignments: list[typing.Any] = []
184184

185185
def syntax_tree(self) -> ast.AST:
186186
"""
@@ -189,9 +189,10 @@ def syntax_tree(self) -> ast.AST:
189189
Returns:
190190
ast.Module:
191191
"""
192+
assert self.source is not None
192193
self.sourcelines = self.source.splitlines()
193-
source_utf8 = self.source.encode('utf8')
194-
pt = ast.parse(source_utf8)
194+
# ast.parse expects a string; ensure we pass the original source
195+
pt = ast.parse(self.source)
195196
return pt
196197

197198
def process_finished(self, node: typing.Any):
@@ -202,6 +203,7 @@ def process_finished(self, node: typing.Any):
202203
node (ast.AST):
203204
"""
204205
if self._finish_queue:
206+
lineno_end: int | None
205207
if isinstance(node, int):
206208
lineno_end = node
207209
else:
@@ -327,8 +329,9 @@ def visit_If(self, node: typing.Any):
327329
if all(
328330
[
329331
isinstance(node.test.ops[0], ast.Eq),
330-
node.test.left.id == '__name__',
331-
node.test.comparators[0].value == '__main__',
332+
getattr(node.test.left, 'id', None) == '__name__',
333+
getattr(node.test.comparators[0], 'value', None)
334+
== '__main__',
332335
]
333336
):
334337
# Ignore main block
@@ -337,8 +340,9 @@ def visit_If(self, node: typing.Any):
337340
if all(
338341
[
339342
isinstance(node.test.ops[0], ast.Eq),
340-
node.test.left.id == '__name__',
341-
node.test.comparators[0].s == '__main__',
343+
getattr(node.test.left, 'id', None) == '__name__',
344+
getattr(node.test.comparators[0], 's', None)
345+
== '__main__',
342346
]
343347
):
344348
# Ignore main block
@@ -799,6 +803,14 @@ def _parse_static_node_value(node):
799803
"""
800804
Extract a constant value from a node if possible
801805
"""
806+
807+
# Prefer using ast.literal_eval when possible as it handles constants
808+
# and container literals robustly across Python versions.
809+
try:
810+
return ast.literal_eval(node)
811+
except Exception:
812+
pass
813+
802814
import numbers
803815

804816
if isinstance(node, ast.Constant) and isinstance(
@@ -807,15 +819,20 @@ def _parse_static_node_value(node):
807819
value = node.value
808820
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
809821
value = node.value
810-
elif isinstance(node, ast.List):
811-
value = list(map(_parse_static_node_value, node.elts))
812-
elif isinstance(node, ast.Tuple):
813-
value = tuple(map(_parse_static_node_value, node.elts))
814-
elif isinstance(node, (ast.Dict)):
815-
keys = map(_parse_static_node_value, node.keys)
816-
values = map(_parse_static_node_value, node.values)
822+
# Accept sequence-like nodes (List/Tuple in different Python versions)
823+
elif hasattr(node, 'elts'):
824+
# Sequence-like node (list/tuple) — accept any iterable of elts
825+
elts = [(_parse_static_node_value(e)) for e in getattr(node, 'elts')]
826+
# Preserve tuple vs list if possible by checking node class name
827+
if getattr(node, '__class__', None).__name__ == 'Tuple':
828+
value = tuple(elts)
829+
else:
830+
value = list(elts)
831+
# Handle mapping-like nodes
832+
elif hasattr(node, 'keys') and hasattr(node, 'values'):
833+
keys = list(map(_parse_static_node_value, node.keys))
834+
values = list(map(_parse_static_node_value, node.values))
817835
value = OrderedDict(zip(keys, values))
818-
# value = dict(zip(keys, values))
819836
# Avoid direct reference to ast.NameConstant which is deprecated in
820837
# Python 3.14; access it via getattr so linters won't emit a deprecation
821838
# warning while preserving compatibility with older Pythons.

src/xdoctest/utils/util_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def _extension_module_tags() -> list[str]:
573573
return tags
574574

575575

576-
def _static_parse(varname: typing.Any, fpath: typing.Any) -> str | None:
576+
def _static_parse(varname: typing.Any, fpath: typing.Any) -> typing.Any:
577577
"""
578578
Statically parse the a constant variable from a python file
579579

0 commit comments

Comments
 (0)