Skip to content

Commit aacf497

Browse files
committed
Add support for compound type
Rename named tuple Variable to Argument in astvisitors.py Rename ReturnTypeVisitor to TypeVisitor which now support compound type
1 parent dd74bbd commit aacf497

File tree

5 files changed

+112
-46
lines changed

5 files changed

+112
-46
lines changed

py2puml/parsing/astvisitors.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import ast
22
from typing import Dict, List, Tuple, Type
33
from ast import (
44
NodeVisitor, arg, expr,
@@ -12,50 +12,56 @@
1212
from py2puml.parsing.compoundtypesplitter import CompoundTypeSplitter, SPLITTING_CHARACTERS
1313
from py2puml.parsing.moduleresolver import ModuleResolver
1414

15-
Variable = namedtuple('Variable', ['id', 'type_expr'])
15+
Argument = namedtuple('Argument', ['id', 'type_expr'])
1616

1717

18-
class SignatureVariablesCollector(NodeVisitor):
18+
class SignatureArgumentsCollector(NodeVisitor):
1919
"""
20-
Collects the variables and their type annotations from the signature of a method
20+
Collects the arguments name and type annotations from the signature of a method
2121
"""
2222
def __init__(self, skip_self=False, *args, **kwargs):
2323
super().__init__(*args, **kwargs)
2424
self.skip_self = skip_self
2525
self.class_self_id: str = None
26-
self.variables: List[Variable] = []
26+
self.arguments: List[Argument] = []
27+
self.datatypes = {}
2728

2829
def visit_arg(self, node: arg):
29-
variable = Variable(node.arg, node.annotation)
30-
31-
# first constructor variable is the name for the 'self' reference
30+
argument = Argument(node.arg, node.annotation)
31+
if node.annotation:
32+
type_visitor = TypeVisitor()
33+
datatype = type_visitor.visit(node.annotation)
34+
else:
35+
datatype = None
36+
self.datatypes[node.arg] = datatype
37+
# first constructor argument is the name for the 'self' reference
3238
if self.class_self_id is None and not self.skip_self:
33-
self.class_self_id = variable.id
39+
self.class_self_id = argument.id
3440
# other arguments are constructor parameters
35-
self.variables.append(variable)
41+
self.arguments.append(argument)
3642

3743

3844
class AssignedVariablesCollector(NodeVisitor):
3945
'''Parses the target of an assignment statement to detect whether the value is assigned to a variable or an instance attribute'''
4046
def __init__(self, class_self_id: str, annotation: expr):
4147
self.class_self_id: str = class_self_id
4248
self.annotation: expr = annotation
43-
self.variables: List[Variable] = []
44-
self.self_attributes: List[Variable] = []
49+
self.variables: List[Argument] = []
50+
self.self_attributes: List[Argument] = []
4551

4652
def visit_Name(self, node: Name):
4753
'''
4854
Detects declarations of new variables
4955
'''
5056
if node.id != self.class_self_id:
51-
self.variables.append(Variable(node.id, self.annotation))
57+
self.variables.append(Argument(node.id, self.annotation))
5258

5359
def visit_Attribute(self, node: Attribute):
5460
'''
5561
Detects declarations of new attributes on 'self'
5662
'''
5763
if isinstance(node.value, Name) and node.value.id == self.class_self_id:
58-
self.self_attributes.append(Variable(node.attr, self.annotation))
64+
self.self_attributes.append(Argument(node.attr, self.annotation))
5965

6066
def visit_Subscript(self, node: Subscript):
6167
'''
@@ -76,8 +82,8 @@ def visit_FunctionDef(self, node: FunctionDef):
7682
self.uml_methods.append(method_visitor.uml_method)
7783

7884

79-
class ReturnTypeVisitor(NodeVisitor):
80-
85+
class TypeVisitor(NodeVisitor):
86+
""" Returns a string representation of a data type. Supports nested compound data types """
8187
def __init__(self, *args, **kwargs):
8288
super().__init__(*args, **kwargs)
8389

@@ -87,8 +93,24 @@ def visit_Name(self, node):
8793
def visit_Constant(self, node):
8894
return node.value
8995

90-
def visit_Subscript(self, node):
91-
return node.value.id
96+
def visit_Subscript(self, node: Subscript):
97+
""" Visit node of type ast.Subscript and returns a string representation of the compound datatype. Iterate
98+
over elements contained in slice attribute by calling recursively visit() method of new instances of
99+
TypeVisitor. This allows to resolve nested compound datatype. """
100+
101+
datatypes = []
102+
103+
if hasattr(node.slice.value, 'elts'):
104+
for child_node in node.slice.value.elts:
105+
child_visitor = TypeVisitor()
106+
datatypes.append(child_visitor.visit(child_node))
107+
else:
108+
child_visitor = TypeVisitor()
109+
datatypes.append(child_visitor.visit(node.slice.value))
110+
111+
joined_datatypes = ', '.join(datatypes)
112+
113+
return f'{node.value.id}[{joined_datatypes}]'
92114

93115

94116
class MethodVisitor(NodeVisitor):
@@ -100,32 +122,33 @@ class MethodVisitor(NodeVisitor):
100122

101123
def __init__(self, *args, **kwargs):
102124
super().__init__(*args, **kwargs)
103-
self.variables_namespace: List[Variable] = []
125+
self.variables_namespace: List[Argument] = []
104126
self.uml_method: UmlMethod
105127

106128
def visit_FunctionDef(self, node: FunctionDef):
107129
decorators = [decorator.id for decorator in node.decorator_list]
108130
is_static = 'staticmethod' in decorators
109131
is_class = 'classmethod' in decorators
110-
variables_collector = SignatureVariablesCollector(skip_self=is_static)
111-
variables_collector.visit(node)
112-
self.variables_namespace = variables_collector.variables
132+
arguments_collector = SignatureArgumentsCollector(skip_self=is_static)
133+
arguments_collector.visit(node)
134+
self.variables_namespace = arguments_collector.arguments
113135

114136
self.uml_method = UmlMethod(name=node.name, is_static=is_static, is_class=is_class)
115137

116-
for argument in variables_collector.variables:
117-
if argument.id == variables_collector.class_self_id:
138+
for argument in arguments_collector.arguments:
139+
if argument.id == arguments_collector.class_self_id:
118140
self.uml_method.arguments[argument.id] = None
119141
if argument.type_expr:
120142
if hasattr(argument.type_expr, 'id'):
121143
self.uml_method.arguments[argument.id] = argument.type_expr.id
122144
else:
123-
self.uml_method.arguments[argument.id] = f'Subscript {argument.type_expr.value.id}' #FIXME
145+
146+
self.uml_method.arguments[argument.id] = arguments_collector.datatypes[argument.id]
124147
else:
125148
self.uml_method.arguments[argument.id] = None
126149

127150
if node.returns is not None:
128-
return_visitor = ReturnTypeVisitor()
151+
return_visitor = TypeVisitor()
129152
self.uml_method.return_type = return_visitor.visit(node.returns)
130153

131154

@@ -140,7 +163,7 @@ def __init__(self, constructor_source: str, class_name: str, root_fqn: str, modu
140163
self.root_fqn = root_fqn
141164
self.module_resolver = module_resolver
142165
self.class_self_id: str
143-
self.variables_namespace: List[Variable] = []
166+
self.variables_namespace: List[Argument] = []
144167
self.uml_attributes: List[UmlAttribute] = []
145168
self.uml_relations_by_target_fqn: Dict[str, UmlRelation] = {}
146169

@@ -153,7 +176,7 @@ def extend_relations(self, target_fqns: List[str]):
153176
)
154177
})
155178

156-
def get_from_namespace(self, variable_id: str) -> Variable:
179+
def get_from_namespace(self, variable_id: str) -> Argument:
157180
return next((
158181
variable
159182
# variables namespace is iterated antichronologically
@@ -168,10 +191,10 @@ def generic_visit(self, node):
168191
def visit_FunctionDef(self, node: FunctionDef):
169192
# retrieves constructor arguments ('self' reference and typed arguments)
170193
if node.name == '__init__':
171-
variables_collector = SignatureVariablesCollector()
194+
variables_collector = SignatureArgumentsCollector()
172195
variables_collector.visit(node)
173196
self.class_self_id: str = variables_collector.class_self_id
174-
self.variables_namespace = variables_collector.variables
197+
self.variables_namespace = variables_collector.arguments
175198

176199
self.generic_visit(node)
177200

tests/asserts/variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11

22
from ast import get_source_segment
33

4-
from py2puml.parsing.astvisitors import Variable
4+
from py2puml.parsing.astvisitors import Argument
55

6-
def assert_Variable(variable: Variable, id: str, type_str: str, source_code: str):
6+
def assert_Variable(variable: Argument, id: str, type_str: str, source_code: str):
77
assert variable.id == id
88
if type_str is None:
99
assert variable.type_expr == None, 'no type annotation'

tests/modules/withmethods/withmethods.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tests import modules
77
from tests.modules.withenum import TimeUnit
88

9+
910
class Coordinates:
1011
def __init__(self, x: float, y: float) -> None:
1112
self.x = x
@@ -23,7 +24,7 @@ def from_values(x: int, y: str) -> Point:
2324
def get_coordinates(self) -> Tuple[float, str]:
2425
return self.x, self.y
2526

26-
def __init__(self, x: int, y: str) -> None:
27+
def __init__(self, x: int, y: Tuple[bool]) -> None:
2728
self.coordinates: Coordinates = Coordinates(x, float(y))
2829
self.day_unit: withenum.TimeUnit = withenum.TimeUnit.DAYS
2930
self.hour_unit: modules.withenum.TimeUnit = modules.withenum.TimeUnit.HOURS

tests/puml_files/with_methods.puml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ class tests.modules.withmethods.withmethods.Point {
1010
hour_unit: TimeUnit
1111
time_resolution: Tuple[str, TimeUnit]
1212
x: int
13-
y: str
13+
y: Tuple[bool]
1414
{static} Point from_values(int x, str y)
1515
Tuple[float, str] get_coordinates(self)
16-
__init__(self, int x, str y)
16+
__init__(self, int x, Tuple[bool] y)
1717
int do_something(self, posarg_nohint, str posarg_hint, posarg_default)
1818
}
1919
class tests.modules.withmethods.withinheritedmethods.ThreeDimensionalPoint {

tests/py2puml/parsing/test_astvisitors.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import unittest
22
from typing import Dict, Tuple, List
33

44
from ast import parse, AST, get_source_segment
@@ -7,7 +7,7 @@
77

88
from pytest import mark
99

10-
from py2puml.parsing.astvisitors import AssignedVariablesCollector, SignatureVariablesCollector, Variable, shorten_compound_type_annotation
10+
from py2puml.parsing.astvisitors import AssignedVariablesCollector, TypeVisitor, SignatureArgumentsCollector, Argument, shorten_compound_type_annotation
1111
from py2puml.parsing.moduleresolver import ModuleResolver
1212

1313
from tests.asserts.variable import assert_Variable
@@ -32,18 +32,18 @@ def test_SignatureVariablesCollector_collect_arguments():
3232
constructor_source: str = dedent(getsource(ParseMyConstructorArguments.__init__.__code__))
3333
constructor_ast: AST = parse(constructor_source)
3434

35-
collector = SignatureVariablesCollector()
35+
collector = SignatureArgumentsCollector()
3636
collector.visit(constructor_ast)
3737

3838
assert collector.class_self_id == 'me'
39-
assert len(collector.variables) == 7, 'all the arguments must be detected'
40-
assert_Variable(collector.variables[0], 'me', None, constructor_source)
41-
assert_Variable(collector.variables[1], 'an_int', 'int', constructor_source)
42-
assert_Variable(collector.variables[2], 'an_untyped', None, constructor_source)
43-
assert_Variable(collector.variables[3], 'a_compound_type', 'Tuple[float, Dict[str, List[bool]]]', constructor_source)
44-
assert_Variable(collector.variables[4], 'a_default_string', 'str', constructor_source)
45-
assert_Variable(collector.variables[5], 'args', None, constructor_source)
46-
assert_Variable(collector.variables[6], 'kwargs', None, constructor_source)
39+
assert len(collector.arguments) == 7, 'all the arguments must be detected'
40+
assert_Variable(collector.arguments[0], 'me', None, constructor_source)
41+
assert_Variable(collector.arguments[1], 'an_int', 'int', constructor_source)
42+
assert_Variable(collector.arguments[2], 'an_untyped', None, constructor_source)
43+
assert_Variable(collector.arguments[3], 'a_compound_type', 'Tuple[float, Dict[str, List[bool]]]', constructor_source)
44+
assert_Variable(collector.arguments[4], 'a_default_string', 'str', constructor_source)
45+
assert_Variable(collector.arguments[5], 'args', None, constructor_source)
46+
assert_Variable(collector.arguments[6], 'kwargs', None, constructor_source)
4747

4848
@mark.parametrize(
4949
'class_self_id,assignment_code,annotation_as_str,self_attributes,variables', [
@@ -190,3 +190,45 @@ def test_shorten_compound_type_annotation(full_annotation: str, short_annotation
190190
shortened_annotation, full_namespaced_definitions = shorten_compound_type_annotation(full_annotation, module_resolver)
191191
assert shortened_annotation == short_annotation
192192
assert full_namespaced_definitions == namespaced_definitions
193+
194+
195+
class TestTypeVisitor(unittest.TestCase):
196+
197+
def test_return_type_int(self):
198+
source_code = 'def dummy_function() -> int:\n pass'
199+
ast = parse(source_code)
200+
node = ast.body[0].returns
201+
visitor = TypeVisitor()
202+
actual_rtype = visitor.visit(node)
203+
expected_rtype = 'int'
204+
self.assertEqual(expected_rtype, actual_rtype)
205+
206+
def test_return_type_compound(self):
207+
""" Test non-nested compound datatype"""
208+
source_code = 'def dummy_function() -> Tuple[float, str]:\n pass'
209+
ast = parse(source_code)
210+
node = ast.body[0].returns
211+
visitor = TypeVisitor()
212+
actual_rtype = visitor.visit(node)
213+
expected_rtype = 'Tuple[float, str]'
214+
self.assertEqual(expected_rtype, actual_rtype)
215+
216+
def test_return_type_compound_nested(self):
217+
""" Test nested compound datatype"""
218+
source_code = 'def dummy_function() -> Tuple[float, Dict[str, List[bool]]]:\n pass'
219+
ast = parse(source_code)
220+
node = ast.body[0].returns
221+
visitor = TypeVisitor()
222+
actual_rtype = visitor.visit(node)
223+
expected_rtype = 'Tuple[float, Dict[str, List[bool]]]'
224+
self.assertEqual(expected_rtype, actual_rtype)
225+
226+
def test_return_type_user_defined(self):
227+
""" Test user-defined class datatype"""
228+
source_code = 'def dummy_function() -> Point:\n pass'
229+
ast = parse(source_code)
230+
node = ast.body[0].returns
231+
visitor = TypeVisitor()
232+
actual_rtype = visitor.visit(node)
233+
expected_rtype = 'Point'
234+
self.assertEqual(expected_rtype, actual_rtype)

0 commit comments

Comments
 (0)