Skip to content

Commit 003987a

Browse files
authored
Fix calling class methods from constructors (pyccel#2087)
Visit class methods in the order that they are called. Fixes pyccel#2085
1 parent 97bdb73 commit 003987a

File tree

9 files changed

+189
-63
lines changed

9 files changed

+189
-63
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ All notable changes to this project will be documented in this file.
9090
- #2041 : Include all type extension methods by default.
9191
- #2082 : Allow the use of a list comprehension to initialise an array.
9292
- #2094 : Fix slicing of array allocated in an if block.
93+
- #2085 : Fix calling class methods before they are defined.
9394

9495
### Changed
9596

pyccel/ast/core.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,6 +3449,47 @@ def add_new_interface(self, interface):
34493449
interface.set_current_user_node(self)
34503450
self._interfaces += (interface,)
34513451

3452+
def update_method(self, syntactic_method, semantic_method):
3453+
"""
3454+
Replace a syntactic_method with its semantic equivalent.
3455+
3456+
Replace a syntactic_method with its semantic equivalent.
3457+
3458+
Parameters
3459+
----------
3460+
syntactic_method : FunctionDef
3461+
The method that has already been added to the class.
3462+
semantic_method : FunctionDef
3463+
The method that will replace the syntactic_method.
3464+
"""
3465+
assert isinstance(semantic_method, FunctionDef)
3466+
assert syntactic_method in self._methods
3467+
assert semantic_method.is_semantic
3468+
syntactic_method.remove_user_node(self)
3469+
semantic_method.set_current_user_node(self)
3470+
self._methods = tuple(m for m in self._methods if m is not syntactic_method) + (semantic_method,)
3471+
3472+
def update_interface(self, syntactic_interface, semantic_interface):
3473+
"""
3474+
Replace a syntactic_interface with its semantic equivalent.
3475+
3476+
Replace a syntactic_interface with its semantic equivalent.
3477+
3478+
Parameters
3479+
----------
3480+
syntactic_interface : FunctionDef
3481+
The interface that has already been added to the class.
3482+
semantic_interface : FunctionDef
3483+
The interface that will replace the syntactic_interface.
3484+
"""
3485+
assert isinstance(semantic_interface, Interface)
3486+
assert syntactic_interface in self._methods
3487+
assert semantic_interface.is_semantic
3488+
syntactic_interface.remove_user_node(self)
3489+
semantic_interface.set_current_user_node(self)
3490+
self._methods = tuple(m for m in self._methods if m is not syntactic_interface)
3491+
self._interfaces = tuple(m for m in self._interfaces if m is not syntactic_interface) + (semantic_interface,)
3492+
34523493
def get_method(self, name, raise_error = True):
34533494
"""
34543495
Get the method `name` of the current class.
@@ -3478,6 +3519,11 @@ def get_method(self, name, raise_error = True):
34783519
ValueError
34793520
Raised if the method cannot be found.
34803521
"""
3522+
method = next((i for i in chain(self.methods, self.interfaces) \
3523+
if i.name == name and i.pyccel_staging == 'syntactic'), None)
3524+
if method:
3525+
return method
3526+
34813527
if self.scope is not None:
34823528
# Collect translated name from scope
34833529
try:

pyccel/codegen/printing/ccode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from pyccel.ast.builtin_methods.dict_methods import DictItems
2323

24-
from pyccel.ast.core import Declare, For, CodeBlock
24+
from pyccel.ast.core import Declare, For, CodeBlock, ClassDef
2525
from pyccel.ast.core import FuncAddressDeclare, FunctionCall, FunctionCallArgument
2626
from pyccel.ast.core import Allocate, Deallocate
2727
from pyccel.ast.core import FunctionAddress
@@ -636,7 +636,8 @@ def _handle_inline_func_call(self, expr):
636636
if parent_assign:
637637
body.substitute(new_res_vars, orig_res_vars)
638638

639-
if func.global_vars or func.global_funcs:
639+
if func.global_vars or func.global_funcs and \
640+
not func.get_direct_user_nodes(lambda u: isinstance(u, ClassDef)):
640641
mod = func.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
641642
self.add_import(Import(mod.name, [AsName(v, v.name) \
642643
for v in (*func.global_vars, *func.global_funcs)]))

pyccel/codegen/printing/fcode.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from pyccel.ast.core import FunctionDef, FunctionDefArgument, FunctionDefResult
3232
from pyccel.ast.core import SeparatorComment, Comment
33-
from pyccel.ast.core import ConstructorCall
33+
from pyccel.ast.core import ConstructorCall, ClassDef
3434
from pyccel.ast.core import FunctionCallArgument
3535
from pyccel.ast.core import FunctionAddress
3636
from pyccel.ast.core import Return, Module, For
@@ -464,7 +464,9 @@ def _handle_inline_func_call(self, expr, assign_lhs = None):
464464

465465
for i in func.imports:
466466
self.add_import(i)
467-
if func.global_vars or func.global_funcs:
467+
468+
if (func.global_vars or func.global_funcs) and \
469+
not func.get_direct_user_nodes(lambda u: isinstance(u, ClassDef)):
468470
mod = func.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
469471
current_mod = expr.get_user_nodes(Module, excluded_nodes=(FunctionCall,))[0]
470472
if current_mod is not mod:
@@ -2495,7 +2497,8 @@ def _print_ClassDef(self, expr):
24952497

24962498
aliases = []
24972499
names = []
2498-
methods = ''.join(f'procedure :: {method.name} => {method.cls_name}\n' for method in expr.methods)
2500+
methods = ''.join(f'procedure :: {method.name} => {method.cls_name}\n' for method in expr.methods \
2501+
if not method.is_inline)
24992502
for i in expr.interfaces:
25002503
names = ','.join(f.cls_name for f in i.functions if not f.is_inline)
25012504
if names:

pyccel/codegen/printing/pycode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ def _print_FunctionCall(self, expr):
542542
args = expr.args
543543
if func.arguments and func.arguments[0].bound_argument:
544544
func_name = f'{self._print(args[0])}.{func_name}'
545+
if 'property' in func.decorators:
546+
return func_name
545547
args = args[1:]
546548
args_str = ', '.join(self._print(i) for i in args)
547549
code = f'{func_name}({args_str})'

pyccel/parser/semantic.py

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2966,7 +2966,6 @@ def _visit_DottedName(self, expr):
29662966

29672967
# look for a class method
29682968
if isinstance(rhs, FunctionCall):
2969-
method = cls_base.get_method(rhs_name)
29702969
macro = self.scope.find(rhs_name, 'macros')
29712970
if macro is not None:
29722971
master = macro.master
@@ -2976,7 +2975,15 @@ def _visit_DottedName(self, expr):
29762975
args = macro.apply(args)
29772976
return FunctionCall(master, args, self._current_function)
29782977

2978+
method = cls_base.get_method(rhs_name)
2979+
29792980
args = [FunctionCallArgument(visited_lhs), *self._handle_function_args(rhs.args)]
2981+
if not method.is_semantic:
2982+
if not method.is_inline:
2983+
method = self._annotate_the_called_function_def(method)
2984+
else:
2985+
method = self._annotate_the_called_function_def(method, function_call_args=args)
2986+
29802987
if cls_base.name == 'numpy.ndarray':
29812988
numpy_class = method.cls_name
29822989
self.insert_import('numpy', AsName(numpy_class, numpy_class.name))
@@ -2992,6 +2999,12 @@ def _visit_DottedName(self, expr):
29922999
# class property?
29933000
else:
29943001
method = cls_base.get_method(rhs_name)
3002+
if not method.is_semantic:
3003+
if not method.is_inline:
3004+
method = self._annotate_the_called_function_def(method)
3005+
else:
3006+
method = self._annotate_the_called_function_def(method,
3007+
function_call_args=(FunctionCallArgument(visited_lhs),))
29953008
assert 'property' in method.decorators
29963009
if cls_base.name == 'numpy.ndarray':
29973010
numpy_class = method.cls_name
@@ -4348,7 +4361,7 @@ def unpack(ann):
43484361
if cls_name:
43494362
# update the class methods
43504363
if not is_interface:
4351-
bound_class.add_new_method(func)
4364+
bound_class.update_method(expr, func)
43524365

43534366
new_semantic_funcs += [func]
43544367
if expr.python_ast:
@@ -4374,7 +4387,7 @@ def unpack(ann):
43744387
if expr.python_ast:
43754388
new_semantic_funcs.set_current_ast(expr.python_ast)
43764389
if cls_name:
4377-
bound_class.add_new_interface(new_semantic_funcs)
4390+
bound_class.update_interface(expr, new_semantic_funcs)
43784391
self.insert_function(new_semantic_funcs)
43794392

43804393
return EmptyNode()
@@ -4455,78 +4468,71 @@ def _visit_ClassDef(self, expr):
44554468
docstring = docstring, class_type = dtype)
44564469
self.scope.parent_scope.insert_class(cls)
44574470

4458-
methods = list(expr.methods)
4459-
init_func = None
4471+
methods = expr.methods
4472+
for method in methods:
4473+
cls.add_new_method(method)
44604474

4461-
if not any(method.name == '__init__' for method in methods):
4475+
syntactic_init_func = next((method for method in methods if method.name == '__init__'), None)
4476+
if syntactic_init_func is None:
44624477
argument = FunctionDefArgument(Variable(dtype, 'self', cls_base = cls), bound_argument = True)
44634478
self.scope.insert_symbol('__init__')
44644479
scope = self.create_new_function_scope('__init__')
44654480
init_func = FunctionDef('__init__', [argument], (), [], cls_name=cls.name, scope=scope)
44664481
self.exit_function_scope()
44674482
self.insert_function(init_func)
44684483
cls.add_new_method(init_func)
4469-
methods.append(init_func)
4470-
4471-
for (i, method) in enumerate(methods):
4472-
m_name = method.name
4473-
if m_name == '__init__':
4474-
if init_func is None:
4475-
self._visit(method)
4476-
init_func = self.scope.functions.pop(m_name)
4477-
4478-
if isinstance(init_func, Interface):
4479-
errors.report("Pyccel does not support interface constructor", symbol=method,
4480-
severity='fatal')
4481-
methods.pop(i)
4482-
4483-
# create a new attribute to check allocation
4484-
deallocater_lhs = Variable(dtype, 'self', cls_base = cls, is_argument=True)
4485-
deallocater = DottedVariable(lhs = deallocater_lhs, name = self.scope.get_new_name('is_freed'),
4486-
class_type = PythonNativeBool(), is_private=True)
4487-
cls.add_new_attribute(deallocater)
4488-
deallocater_assign = Assign(deallocater, LiteralFalse())
4489-
init_func.body.insert2body(deallocater_assign, back=False)
4490-
break
4491-
4492-
if not init_func:
4493-
errors.report(UNDEFINED_INIT_METHOD, symbol=name,
4494-
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
4495-
severity='error')
4484+
else:
4485+
self._visit(syntactic_init_func)
4486+
init_func = self.scope.functions.pop('__init__')
44964487

4497-
for i in methods:
4498-
self._visit(i)
4488+
if isinstance(init_func, Interface):
4489+
errors.report("Pyccel does not support interface constructor", symbol=init_func,
4490+
severity='fatal')
44994491

4500-
if not any(method.name == '__del__' for method in methods):
4492+
# create a new attribute to check allocation
4493+
deallocater_lhs = Variable(dtype, 'self', cls_base = cls, is_argument=True)
4494+
deallocater = DottedVariable(lhs = deallocater_lhs, name = self.scope.get_new_name('is_freed'),
4495+
class_type = PythonNativeBool(), is_private=True)
4496+
cls.add_new_attribute(deallocater)
4497+
deallocater_assign = Assign(deallocater, LiteralFalse())
4498+
init_func.body.insert2body(deallocater_assign, back=False)
4499+
4500+
syntactic_method = next((m for m in cls.methods if not m.is_semantic), None)
4501+
while syntactic_method:
4502+
self._visit(syntactic_method)
4503+
syntactic_method = next((m for m in cls.methods if not m.is_semantic), None)
4504+
4505+
syntactic_del_func = next((method for method in methods if method.name == '__del__'), None)
4506+
if syntactic_del_func is None:
45014507
argument = FunctionDefArgument(Variable(dtype, 'self', cls_base = cls), bound_argument = True)
45024508
self.scope.insert_symbol('__del__')
45034509
scope = self.create_new_function_scope('__del__')
45044510
del_method = FunctionDef('__del__', [argument], (), [Pass()], scope=scope)
45054511
self.exit_function_scope()
45064512
self.insert_function(del_method)
45074513
cls.add_new_method(del_method)
4508-
4509-
for method in cls.methods:
4510-
if method.name == '__del__':
4511-
self._current_function = method.name
4512-
attribute = []
4513-
for attr in cls.attributes:
4514-
if not attr.on_stack:
4515-
attribute.append(attr)
4516-
elif isinstance(attr.class_type, CustomDataType) and not attr.is_alias:
4517-
attribute.append(attr)
4518-
if attribute:
4519-
# Create a new list that store local attributes
4520-
self._allocs.append(set())
4521-
self._pointer_targets.append({})
4522-
self._allocs[-1].update(attribute)
4523-
method.body.insert2body(*self._garbage_collector(method.body))
4524-
self._pointer_targets.pop()
4525-
condition = If(IfSection(PyccelNot(deallocater),
4526-
[method.body]+[Assign(deallocater, LiteralTrue())]))
4527-
method.body = [condition]
4528-
self._current_function = None
4529-
break
4514+
else:
4515+
del_method = cls.get_method('__del__')
4516+
4517+
# Add destructors to __del__ method
4518+
self._current_function = del_method.name
4519+
attribute = []
4520+
for attr in cls.attributes:
4521+
if not attr.on_stack:
4522+
attribute.append(attr)
4523+
elif isinstance(attr.class_type, CustomDataType) and not attr.is_alias:
4524+
attribute.append(attr)
4525+
if attribute:
4526+
# Create a new list that store local attributes
4527+
self._allocs.append(set())
4528+
self._pointer_targets.append({})
4529+
self._allocs[-1].update(attribute)
4530+
del_method.body.insert2body(*self._garbage_collector(del_method.body))
4531+
self._pointer_targets.pop()
4532+
condition = If(IfSection(PyccelNot(deallocater),
4533+
[del_method.body]+[Assign(deallocater, LiteralTrue())]))
4534+
del_method.body = [condition]
4535+
self._current_function = None
45304536

45314537
self.exit_class_scope()
45324538

tests/epyccel/classes/classes_8.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# pylint: disable=missing-function-docstring, missing-module-docstring, missing-class-docstring
2+
3+
class A:
4+
def __init__(self, x : float):
5+
self._x = x
6+
self._y : float
7+
self._construct_y()
8+
self._construct_y_from_z(self.y)
9+
self._construct_y_from_z(3)
10+
11+
def _construct_y(self):
12+
self._y = self.x + 3
13+
14+
@property
15+
def x(self):
16+
return self._x
17+
18+
@property
19+
def y(self):
20+
return self._y
21+
22+
def _construct_y_from_z(self, z : 'int | float'):
23+
self._y = self._x + z

tests/epyccel/classes/classes_9.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# pylint: disable=missing-function-docstring, missing-module-docstring, missing-class-docstring
2+
from pyccel.decorators import inline
3+
4+
class A:
5+
def __init__(self, x : float):
6+
self._x = x
7+
self._y = self._calculate_y(2)
8+
9+
@property
10+
def x(self):
11+
return self._x
12+
13+
@inline
14+
def _calculate_y(self, n : int):
15+
return self.x + n
16+
17+
def get_A_contents(self):
18+
return self.x, self.y
19+
20+
@inline
21+
@property
22+
def y(self):
23+
return self._y

tests/epyccel/test_epyccel_classes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,27 @@ def test_classes_6(language):
210210
assert p_py.get_attributes(3) == p_l.get_attributes(3)
211211
assert p_py.get_attributes(4.5) == p_l.get_attributes(4.5)
212212

213+
def test_classes_8(language):
214+
import classes.classes_8 as mod
215+
modnew = epyccel(mod, language = language)
216+
217+
a_py = mod.A(3.0)
218+
a_l = modnew.A(3.0)
219+
220+
assert a_py.x == a_l.x
221+
assert a_py.y == a_l.y
222+
223+
def test_classes_9(language):
224+
import classes.classes_9 as mod
225+
modnew = epyccel(mod, language = language)
226+
227+
a_py = mod.A(3.0)
228+
a_l = modnew.A(3.0)
229+
230+
assert a_py.get_A_contents() == a_l.get_A_contents()
231+
assert a_py.x == a_l.x
232+
assert a_py.y == a_l.y
233+
213234
def test_generic_methods(language):
214235
import classes.generic_methods as mod
215236
modnew = epyccel(mod, language = language)

0 commit comments

Comments
 (0)