Skip to content

Commit c03cd70

Browse files
committed
Implement function for getting fields from inheritance chain
1 parent 8729faa commit c03cd70

File tree

2 files changed

+132
-17
lines changed

2 files changed

+132
-17
lines changed

src/ir/ast.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=dangerous-default-value
2-
from typing import List
2+
from typing import List, Set, Union
33
from copy import deepcopy
44

55
import src.ir.type_utils as tu
@@ -630,7 +630,7 @@ def get_overridable_fields(self):
630630
if f.can_override
631631
]
632632

633-
def get_callable_functions(self, class_decls) -> List[FunctionDeclaration]:
633+
def get_callable_functions(self, class_decls) -> Set[FunctionDeclaration]:
634634
"""All functions that can be called in instantiations of this class
635635
"""
636636
# Get functions that are implemented in the current class
@@ -668,6 +668,48 @@ def get_callable_functions(self, class_decls) -> List[FunctionDeclaration]:
668668

669669
return functions
670670

671+
def get_all_fields(self, class_decls) -> Set[FieldDeclaration]:
672+
"""
673+
All fields (including the inheritted ones) that can be accessed by
674+
instantiations of this class.
675+
"""
676+
fields = set(self.fields)
677+
field_names = {f.name for f in fields}
678+
679+
if not self.superclasses:
680+
return fields
681+
682+
# Retrieve fields from the inheritance chain.
683+
super_cls = self.superclasses[0]
684+
class_decl = tu.get_superclass_decl(super_cls, class_decls)
685+
686+
if not class_decl:
687+
return fields
688+
689+
type_var_map = tu.get_superclass_type_var_map(super_cls, class_decl)
690+
691+
parent_fields = class_decl.get_all_fields(class_decls)
692+
693+
# substitute type variables in parent's functions
694+
for f in parent_fields:
695+
if f.name in field_names:
696+
# We override this field in the current class
697+
continue
698+
new_f = deepcopy(f)
699+
new_f.field_type = types.substitute_type(f.get_type(),
700+
type_var_map)
701+
fields.add(new_f)
702+
703+
return fields
704+
705+
def get_all_attributes(self, class_decls) -> Set[Union[FunctionDeclaration, FieldDeclaration]]:
706+
"""
707+
Get all attributes (fields + functions) from the inheritance chain
708+
"""
709+
attributes = self.get_callable_functions(class_decls)
710+
attributes.update(self.get_all_fields(class_decls))
711+
return attributes
712+
671713
def get_abstract_functions(self, class_decls) -> List[FunctionDeclaration]:
672714
# Get the abstract functions that are declared in the current class.
673715
functions = {

tests/test_ir.py

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from src.ir.kotlin_types import *
88

99

10-
def assert_abstract_funcs(actual, expected):
10+
def assert_declarations(actual, expected):
1111
assert len(actual) == len(expected)
1212
table = {f.name: f for f in expected}
1313
for f in actual:
@@ -183,8 +183,8 @@ def test_get_abstract_functions():
183183
functions=[func2])
184184

185185
assert cls1.get_abstract_functions([cls1, cls2]) == {func1}
186-
assert_abstract_funcs(cls2.get_abstract_functions([cls1, cls2]),
187-
{func1, func2})
186+
assert_declarations(cls2.get_abstract_functions([cls1, cls2]),
187+
{func1, func2})
188188

189189

190190
def test_get_abstract_functions_chain():
@@ -199,18 +199,17 @@ def test_get_abstract_functions_chain():
199199
functions=[func2])
200200

201201
assert cls1.get_abstract_functions([cls1, cls2, cls3]) == {func1}
202-
assert_abstract_funcs(cls2.get_abstract_functions([cls1, cls2, cls3]),
203-
{func1})
204-
assert_abstract_funcs(cls3.get_abstract_functions([cls1, cls2, cls3]),
205-
{func1, func2})
202+
assert_declarations(cls2.get_abstract_functions([cls1, cls2, cls3]),
203+
{func1})
204+
assert_declarations(cls3.get_abstract_functions([cls1, cls2, cls3]),
205+
{func1, func2})
206206

207207
override_func = deepcopy(func1)
208208
func1.body = IntegerConstant(1, Integer)
209209
cls2.functions = [func1]
210-
assert_abstract_funcs(cls2.get_abstract_functions([cls1, cls2, cls3]),
211-
[])
212-
assert_abstract_funcs(cls3.get_abstract_functions([cls1, cls2, cls3]),
213-
{func2})
210+
assert_declarations(cls2.get_abstract_functions([cls1, cls2, cls3]), [])
211+
assert_declarations(cls3.get_abstract_functions([cls1, cls2, cls3]),
212+
{func2})
214213

215214

216215
def test_get_abstract_functions_parameterized():
@@ -232,8 +231,8 @@ def test_get_abstract_functions_parameterized():
232231
exp_func1.params[0].param_type = String
233232
exp_func1.ret_type = String
234233
exp_func1.inferred_type = String
235-
assert_abstract_funcs(cls2.get_abstract_functions([cls1, cls2]),
236-
[exp_func1])
234+
assert_declarations(cls2.get_abstract_functions([cls1, cls2]),
235+
[exp_func1])
237236

238237

239238
def test_get_abstract_functions_parameterized_chain():
@@ -262,5 +261,79 @@ def test_get_abstract_functions_parameterized_chain():
262261
exp_func1.params[0].param_type = actual_t
263262
exp_func1.ret_type = actual_t
264263
exp_func1.inferred_type = actual_t
265-
assert_abstract_funcs(cls2.get_abstract_functions([cls1, cls2, cls3]),
266-
[exp_func1])
264+
assert_declarations(cls2.get_abstract_functions([cls1, cls2, cls3]),
265+
[exp_func1])
266+
267+
268+
def test_get_fields():
269+
field1 = FieldDeclaration("foo", String)
270+
field2 = FieldDeclaration("bar", Any)
271+
cls1 = ClassDeclaration("A", [], 0, fields=[field1])
272+
cls2 = ClassDeclaration("B",
273+
[SuperClassInstantiation(cls1.get_type(), [])],
274+
fields=[field2])
275+
276+
assert cls1.get_all_fields([cls1, cls2]) == {field1}
277+
assert_declarations(cls2.get_all_fields([cls1, cls2]),
278+
{field1, field2})
279+
280+
cls1 = ClassDeclaration("A", [], fields=[field1])
281+
cls2 = ClassDeclaration("B",
282+
[SuperClassInstantiation(cls1.get_type(), [])],
283+
fields=[])
284+
cls3 = ClassDeclaration("C",
285+
[SuperClassInstantiation(cls2.get_type(), [])],
286+
fields=[field2])
287+
288+
assert cls1.get_all_fields([cls1, cls2, cls3]) == {field1}
289+
assert_declarations(cls2.get_all_fields([cls1, cls2, cls3]),
290+
{field1})
291+
assert_declarations(cls3.get_all_fields([cls1, cls2, cls3]),
292+
{field1, field2})
293+
294+
295+
def test_get_fields_parameterized():
296+
type_param1 = TypeParameter("T")
297+
field1 = FieldDeclaration("foo", type_param1)
298+
field2 = FieldDeclaration("bar", type_param1)
299+
cls1 = ClassDeclaration("A", [],
300+
type_parameters=[type_param1],
301+
fields=[field1, field2])
302+
cls2 = ClassDeclaration(
303+
"B", [SuperClassInstantiation(cls1.get_type().new([String]), [])],
304+
fields=[]
305+
)
306+
307+
assert cls1.get_all_fields([cls1, cls2]) == {field1, field2}
308+
exp_field1 = deepcopy(field1)
309+
exp_field2 = deepcopy(field2)
310+
exp_field1.field_type = String
311+
exp_field2.field_type = String
312+
assert_declarations(cls2.get_all_fields([cls1, cls2]),
313+
[field1, field2])
314+
315+
316+
cls1 = ClassDeclaration("A", [],
317+
type_parameters=[type_param1],
318+
fields=[field1])
319+
320+
t_con = TypeConstructor("Foo", [TypeParameter("T")])
321+
type_param2 = TypeParameter("T")
322+
field2 = FieldDeclaration("bar", type_param2)
323+
t = t_con.new([type_param2])
324+
cls2 = ClassDeclaration(
325+
"B", [SuperClassInstantiation(cls1.get_type().new([t]), [])],
326+
type_parameters=[type_param2],
327+
fields=[field2]
328+
)
329+
field3 = FieldDeclaration("bar", String)
330+
cls3 = ClassDeclaration(
331+
"C", [SuperClassInstantiation(cls2.get_type().new([String]), [])],
332+
fields=[field3]
333+
)
334+
actual_t = t_con.new([String])
335+
exp_field = deepcopy(field1)
336+
exp_field.field_type = actual_t
337+
assert_declarations(cls2.get_all_fields([cls1, cls2, cls3]),
338+
[exp_field, field3])
339+

0 commit comments

Comments
 (0)