Skip to content

Commit 486e6de

Browse files
authored
Change the order of the arguments of FunctionDef (pyccel#2101)
In order to fix pyccel#337 the handling of function results must be changed. Specifically, the handling of the case where nothing is returned is non-trivial. It is therefore better to handle this with a default value, however adding a default value requires the order of the arguments of the `FunctionDef` class to be changed. This PR makes that change in preparation
1 parent 003987a commit 486e6de

File tree

9 files changed

+66
-68
lines changed

9 files changed

+66
-68
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ All notable changes to this project will be documented in this file.
125125
- \[INTERNALS\] Stop using ndarrays as an intermediate step to return arrays from Fortran code.
126126
- \[INTERNALS\] Unify the strategy for handling additional imports in the printing stage for different languages.
127127
- \[INTERNALS\] Make `Iterable` into a super-class instead of a storage class.
128+
- \[INTERNALS\] Change the order of the constructor arguments of `FunctionDef`.
128129

129130
### Deprecated
130131

pyccel/ast/core.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,12 +2047,12 @@ class FunctionDef(ScopedAstNode):
20472047
arguments : iterable of FunctionDefArgument
20482048
The arguments to the function.
20492049
2050-
results : iterable
2051-
The direct outputs of the function.
2052-
20532050
body : iterable
20542051
The body of the function.
20552052
2053+
results : iterable
2054+
The direct outputs of the function.
2055+
20562056
global_vars : list of Symbols
20572057
Variables which will not be passed into the function.
20582058
@@ -2161,8 +2161,9 @@ def __init__(
21612161
self,
21622162
name,
21632163
arguments,
2164-
results,
21652164
body,
2165+
results = (),
2166+
*,
21662167
global_vars=(),
21672168
cls_name=None,
21682169
is_static=False,
@@ -2210,15 +2211,10 @@ def __init__(
22102211

22112212
if iterable(body):
22122213
body = CodeBlock(body)
2213-
elif not isinstance(body,CodeBlock):
2214-
raise TypeError('body must be an iterable or a CodeBlock')
2214+
assert isinstance(body,CodeBlock)
22152215

22162216
# results
2217-
2218-
if not iterable(results):
2219-
raise TypeError('results must be an iterable')
2220-
if not all(isinstance(r, FunctionDefResult) for r in results):
2221-
raise TypeError('results must be all be FunctionDefResults')
2217+
assert iterable(results) and all(isinstance(r, FunctionDefResult) for r in results)
22222218

22232219
if cls_name:
22242220

@@ -2527,28 +2523,28 @@ def __getnewargs__(self):
25272523
args = (
25282524
self._name,
25292525
self._arguments,
2530-
self._results,
25312526
self._body)
25322527

25332528
kwargs = {
2534-
'global_vars':self._global_vars,
2535-
'cls_name':self._cls_name,
2536-
'is_static':self._is_static,
2537-
'imports':self._imports,
2538-
'decorators':self._decorators,
2539-
'headers':self._headers,
2540-
'is_recursive':self._is_recursive,
2541-
'is_pure':self._is_pure,
2542-
'is_elemental':self._is_elemental,
2543-
'is_private':self._is_private,
2544-
'is_header':self._is_header,
2545-
'functions':self._functions,
2546-
'is_external':self._is_external,
2547-
'is_imported':self._is_imported,
2548-
'is_semantic':self._is_semantic,
2549-
'interfaces':self._interfaces,
2550-
'docstring':self._docstring,
2551-
'scope':self._scope}
2529+
'results':self._results,
2530+
'global_vars':self._global_vars,
2531+
'cls_name':self._cls_name,
2532+
'is_static':self._is_static,
2533+
'imports':self._imports,
2534+
'decorators':self._decorators,
2535+
'headers':self._headers,
2536+
'is_recursive':self._is_recursive,
2537+
'is_pure':self._is_pure,
2538+
'is_elemental':self._is_elemental,
2539+
'is_private':self._is_private,
2540+
'is_header':self._is_header,
2541+
'functions':self._functions,
2542+
'is_external':self._is_external,
2543+
'is_imported':self._is_imported,
2544+
'is_semantic':self._is_semantic,
2545+
'interfaces':self._interfaces,
2546+
'docstring':self._docstring,
2547+
'scope':self._scope}
25522548
return args, kwargs
25532549

25542550
def __reduce_ex__(self, i):
@@ -3076,7 +3072,7 @@ def __init__(
30763072
memory_handling='stack',
30773073
**kwargs
30783074
):
3079-
super().__init__(name, arguments, results, body=[], scope=None, **kwargs)
3075+
super().__init__(name, arguments, body=[], results=results, scope=None, **kwargs)
30803076
if not isinstance(is_argument, bool):
30813077
raise TypeError('Expecting a boolean for is_argument')
30823078

pyccel/ast/cwrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ def __init__(self, name, *args, external_funcs = (), declarations = (), init_fun
510510
self._external_funcs = external_funcs
511511
self._declarations = declarations
512512
if import_func is None:
513-
self._import_func = FunctionDef(f'{name}_import', (),
514-
(FunctionDefResult(Variable(CNativeInt(), '_', is_temp=True)),), ())
513+
self._import_func = FunctionDef(f'{name}_import', (), (),
514+
(FunctionDefResult(Variable(CNativeInt(), '_', is_temp=True)),))
515515
else:
516516
self._import_func = import_func
517517
super().__init__(name, *args, init_func = init_func, **kwargs)
@@ -915,7 +915,7 @@ class PyModInitFunc(FunctionDef):
915915

916916
def __init__(self, name, body, static_vars, scope):
917917
self._static_vars = static_vars
918-
super().__init__(name, (), (), body, scope=scope)
918+
super().__init__(name, (), body, scope=scope)
919919

920920
@property
921921
def declarations(self):

pyccel/ast/utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pyccel.decorators as pyccel_decorators
1313
from pyccel.errors.errors import Errors, PyccelError
1414

15-
from .core import (AsName, Import, FunctionDef, FunctionCall,
15+
from .core import (AsName, Import, FunctionCall,
1616
Allocate, Duplicate, Assign, For, CodeBlock,
1717
Concatenate, Module, PyccelFunctionDef)
1818

pyccel/codegen/printing/fcode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,8 @@ def _define_gFTL_element(self, element_type, imports_and_macros, element_name):
562562
imports_and_macros.append(complex_tool_import)
563563
compare_func = FunctionDef('complex_comparison',
564564
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
565-
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
565+
[],
566+
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))])
566567
lt_def = compare_func(tmpVar_x, tmpVar_y)
567568
else:
568569
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))

pyccel/codegen/wrapper/c_to_python_wrapper.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ def f(a, b):
455455
"which indicates which function should be called.")
456456

457457
# Build the function
458-
func = FunctionDef(name, [FunctionDefArgument(a) for a in args], [FunctionDefResult(type_indicator)],
459-
body, docstring=docstring, scope=func_scope)
458+
func = FunctionDef(name, [FunctionDefArgument(a) for a in args], body,
459+
[FunctionDefResult(type_indicator)], docstring=docstring, scope=func_scope)
460460

461461
return func, argument_type_flags
462462

@@ -787,7 +787,7 @@ def _build_module_import_function(self, expr):
787787
result = func_scope.get_temporary_variable(CNativeInt())
788788
self.exit_scope()
789789
self._error_exit_code = Nil()
790-
import_func = FunctionDef(func_name, (), (FunctionDefResult(result),), body, is_static=True, scope = func_scope)
790+
import_func = FunctionDef(func_name, (), body, (FunctionDefResult(result),), is_static=True, scope = func_scope)
791791

792792
return API_var, import_func
793793

@@ -883,8 +883,8 @@ def _get_class_allocator(self, class_dtype, func = None):
883883

884884
self.exit_scope()
885885

886-
return PyFunctionDef(func_name, func_args, func_results,
887-
body, scope=func_scope, original_function = None)
886+
return PyFunctionDef(func_name, func_args, body, func_results,
887+
scope=func_scope, original_function = None)
888888

889889
def _get_class_initialiser(self, init_function, cls_dtype):
890890
"""
@@ -980,7 +980,7 @@ def _get_class_initialiser(self, init_function, cls_dtype):
980980
if not a.bound_argument:
981981
self._python_object_map.pop(a)
982982

983-
function = PyFunctionDef(func_name, func_args, func_results, body, scope=func_scope,
983+
function = PyFunctionDef(func_name, func_args, body, func_results, scope=func_scope,
984984
docstring = init_function.docstring, original_function = original_func)
985985

986986
self.scope.functions[func_name] = function
@@ -1046,7 +1046,7 @@ def _get_class_destructor(self, del_function, cls_dtype, wrapper_scope):
10461046

10471047
self.exit_scope()
10481048

1049-
function = PyFunctionDef(func_name, [FunctionDefArgument(func_arg)], [], body, scope=func_scope,
1049+
function = PyFunctionDef(func_name, [FunctionDefArgument(func_arg)], body, scope=func_scope,
10501050
original_function = original_func)
10511051

10521052
self.scope.functions[func_name] = function
@@ -1225,24 +1225,24 @@ def _wrap_BindCModule(self, expr):
12251225
# Add external functions for functions wrapping array variables
12261226
for v in expr.variable_wrappers:
12271227
f = v.wrapper_function
1228-
external_funcs.append(FunctionDef(f.name, f.arguments, f.results, [], is_header = True, scope = Scope()))
1228+
external_funcs.append(FunctionDef(f.name, f.arguments, [], f.results, is_header = True, scope = Scope()))
12291229

12301230
# Add external functions for normal functions
12311231
for f in expr.funcs:
1232-
external_funcs.append(FunctionDef(f.name.lower(), f.arguments, f.results, [], is_header = True, scope = Scope()))
1232+
external_funcs.append(FunctionDef(f.name.lower(), f.arguments, [], f.results, is_header = True, scope = Scope()))
12331233

12341234
for c in expr.classes:
12351235
m = c.new_func
1236-
external_funcs.append(FunctionDef(m.name, m.arguments, m.results, [], is_header = True, scope = Scope()))
1236+
external_funcs.append(FunctionDef(m.name, m.arguments, [], m.results, is_header = True, scope = Scope()))
12371237
for m in c.methods:
1238-
external_funcs.append(FunctionDef(m.name, m.arguments, m.results, [], is_header = True, scope = Scope()))
1238+
external_funcs.append(FunctionDef(m.name, m.arguments, [], m.results, is_header = True, scope = Scope()))
12391239
for i in c.interfaces:
12401240
for f in i.functions:
1241-
external_funcs.append(FunctionDef(f.name, f.arguments, f.results, [], is_header = True, scope = Scope()))
1241+
external_funcs.append(FunctionDef(f.name, f.arguments, [], f.results, is_header = True, scope = Scope()))
12421242
for a in c.attributes:
12431243
for f in (a.getter, a.setter):
12441244
if f:
1245-
external_funcs.append(FunctionDef(f.name, f.arguments, f.results, [], is_header = True, scope = Scope()))
1245+
external_funcs.append(FunctionDef(f.name, f.arguments, [], f.results, is_header = True, scope = Scope()))
12461246
pymod.external_funcs = external_funcs
12471247

12481248
return pymod
@@ -1326,8 +1326,8 @@ def _wrap_Interface(self, expr):
13261326

13271327
interface_func = FunctionDef(func_name,
13281328
[FunctionDefArgument(a) for a in func_args],
1329-
[FunctionDefResult(self.get_new_PyObject("result", is_temp=True))],
13301329
body,
1330+
[FunctionDefResult(self.get_new_PyObject("result", is_temp=True))],
13311331
scope=func_scope)
13321332
for a in python_args:
13331333
self._python_object_map.pop(a)
@@ -1467,7 +1467,7 @@ def _wrap_FunctionDef(self, expr):
14671467
if not a.bound_argument:
14681468
self._python_object_map.pop(a)
14691469

1470-
function = PyFunctionDef(func_name, func_args, func_results, body, scope=func_scope,
1470+
function = PyFunctionDef(func_name, func_args, body, func_results, scope=func_scope,
14711471
docstring = expr.docstring, original_function = original_func)
14721472

14731473
self.scope.functions[func_name] = function
@@ -1712,7 +1712,7 @@ def _wrap_DottedVariable(self, expr):
17121712
self.exit_scope()
17131713

17141714
args = [FunctionDefArgument(a) for a in getter_args]
1715-
getter = PyFunctionDef(getter_name, args, (FunctionDefResult(getter_result),), getter_body,
1715+
getter = PyFunctionDef(getter_name, args, getter_body, (FunctionDefResult(getter_result),),
17161716
original_function = expr, scope = getter_scope)
17171717

17181718
# ----------------------------------------------------------------------------------
@@ -1760,7 +1760,7 @@ def _wrap_DottedVariable(self, expr):
17601760
self.exit_scope()
17611761

17621762
args = [FunctionDefArgument(a) for a in setter_args]
1763-
setter = PyFunctionDef(setter_name, args, setter_result, setter_body,
1763+
setter = PyFunctionDef(setter_name, args, setter_body, setter_result,
17641764
original_function = expr, scope = setter_scope)
17651765
self._error_exit_code = Nil()
17661766
self._python_object_map.pop(new_set_val_arg)
@@ -1840,7 +1840,7 @@ def _wrap_BindCClassProperty(self, expr):
18401840
self.exit_scope()
18411841

18421842
args = [FunctionDefArgument(a) for a in getter_args]
1843-
getter = PyFunctionDef(getter_name, args, (FunctionDefResult(getter_result),), getter_body,
1843+
getter = PyFunctionDef(getter_name, args, getter_body, (FunctionDefResult(getter_result),),
18441844
original_function = expr.getter, scope = getter_scope)
18451845

18461846
# ----------------------------------------------------------------------------------
@@ -1886,7 +1886,7 @@ def _wrap_BindCClassProperty(self, expr):
18861886
self.exit_scope()
18871887

18881888
args = [FunctionDefArgument(a) for a in setter_args]
1889-
setter = PyFunctionDef(setter_name, args, setter_result, setter_body,
1889+
setter = PyFunctionDef(setter_name, args, setter_body, setter_result,
18901890
original_function = expr, scope = setter_scope)
18911891
else:
18921892
setter = None

pyccel/codegen/wrapper/fortran_to_c_wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _wrap_FunctionDef(self, expr):
277277

278278
self.exit_scope()
279279

280-
func = BindCFunctionDef(name, func_arguments, func_results, body, scope=func_scope, original_function = expr,
280+
func = BindCFunctionDef(name, func_arguments, body, func_results, scope=func_scope, original_function = expr,
281281
docstring = expr.docstring, result_pointer_map = expr.result_pointer_map)
282282

283283
self.scope.functions[name] = func
@@ -541,7 +541,7 @@ def _wrap_DottedVariable(self, expr):
541541
self._additional_exprs.clear()
542542
self.exit_scope()
543543

544-
getter = BindCFunctionDef(getter_name, (getter_arg,), (getter_result,), getter_body,
544+
getter = BindCFunctionDef(getter_name, (getter_arg,), getter_body, (getter_result,),
545545
original_function = expr, scope = getter_scope)
546546

547547
# ----------------------------------------------------------------------------------
@@ -579,7 +579,7 @@ def _wrap_DottedVariable(self, expr):
579579
setter_body.append(Assign(attrib, set_val))
580580
self.exit_scope()
581581

582-
setter = BindCFunctionDef(setter_name, setter_args, (), setter_body,
582+
setter = BindCFunctionDef(setter_name, setter_args, setter_body,
583583
original_function = expr, scope = setter_scope)
584584
return BindCClassProperty(lhs.cls_base.scope.get_python_name(expr.name),
585585
getter, setter, lhs.dtype)
@@ -623,7 +623,7 @@ def _wrap_ClassDef(self, expr):
623623
c_loc = CLocFunc(local_var, bind_var)
624624
body = [alloc, c_loc]
625625

626-
new_method = BindCFunctionDef(func_name, [], [result], body, original_function = None, scope = func_scope)
626+
new_method = BindCFunctionDef(func_name, [], body, [result], original_function = None, scope = func_scope)
627627

628628
methods = [self._wrap(m) for m in expr.methods]
629629
methods = [m for m in methods if not isinstance(m, EmptyNode)]

pyccel/parser/semantic.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,7 +2392,7 @@ def _visit_Module(self, expr):
23922392
init_func_body = If(IfSection(PyccelNot(init_var),
23932393
init_func_body+[Assign(init_var, LiteralTrue())]))
23942394

2395-
init_func = FunctionDef(init_func_name, [], [], [init_func_body],
2395+
init_func = FunctionDef(init_func_name, [], [init_func_body],
23962396
global_vars = variables, scope=init_scope)
23972397
self.insert_function(init_func)
23982398

@@ -2416,7 +2416,7 @@ def _visit_Module(self, expr):
24162416
import_free_calls+deallocs+[Assign(init_var, LiteralFalse())]))
24172417
# Ensure that the function is correctly defined within the namespaces
24182418
scope = self.create_new_function_scope(free_func_name)
2419-
free_func = FunctionDef(free_func_name, [], [], [free_func_body],
2419+
free_func = FunctionDef(free_func_name, [], [free_func_body],
24202420
global_vars = variables, scope = scope)
24212421
self.exit_function_scope()
24222422
self.insert_function(free_func)
@@ -2456,7 +2456,7 @@ def _visit_Module(self, expr):
24562456

24572457
args = [FunctionDefArgument(a) for a in args]
24582458
results = [FunctionDefResult(r) for r in results]
2459-
func_defs.append(FunctionDef(v.name, args, results, [], is_external = is_external, is_header = True))
2459+
func_defs.append(FunctionDef(v.name, args, [], results, is_external = is_external, is_header = True))
24602460

24612461
if len(func_defs) == 1:
24622462
F = func_defs[0]
@@ -4224,7 +4224,7 @@ def unpack(ann):
42244224
# insert the FunctionDef into the scope
42254225
# to handle the case of a recursive function
42264226
# TODO improve in the case of an interface
4227-
recursive_func_obj = FunctionDef(name, arguments, results, [])
4227+
recursive_func_obj = FunctionDef(name, arguments, [], results)
42284228
self.insert_function(recursive_func_obj)
42294229

42304230
# Create a new list that store local variables for each FunctionDef to handle nested functions
@@ -4352,8 +4352,8 @@ def unpack(ann):
43524352
cls = FunctionDef
43534353
func = cls(name,
43544354
arguments,
4355-
results,
43564355
body,
4356+
results,
43574357
**func_kwargs)
43584358
if not is_recursive:
43594359
recursive_func_obj.invalidate_node()
@@ -4477,7 +4477,7 @@ def _visit_ClassDef(self, expr):
44774477
argument = FunctionDefArgument(Variable(dtype, 'self', cls_base = cls), bound_argument = True)
44784478
self.scope.insert_symbol('__init__')
44794479
scope = self.create_new_function_scope('__init__')
4480-
init_func = FunctionDef('__init__', [argument], (), [], cls_name=cls.name, scope=scope)
4480+
init_func = FunctionDef('__init__', [argument], (), cls_name=cls.name, scope=scope)
44814481
self.exit_function_scope()
44824482
self.insert_function(init_func)
44834483
cls.add_new_method(init_func)
@@ -4507,7 +4507,7 @@ def _visit_ClassDef(self, expr):
45074507
argument = FunctionDefArgument(Variable(dtype, 'self', cls_base = cls), bound_argument = True)
45084508
self.scope.insert_symbol('__del__')
45094509
scope = self.create_new_function_scope('__del__')
4510-
del_method = FunctionDef('__del__', [argument], (), [Pass()], scope=scope)
4510+
del_method = FunctionDef('__del__', [argument], [Pass()], scope=scope)
45114511
self.exit_function_scope()
45124512
self.insert_function(del_method)
45134513
cls.add_new_method(del_method)
@@ -4791,7 +4791,7 @@ def _visit_MacroFunction(self, expr):
47914791

47924792
arguments = [FunctionDefArgument(self._visit(a)[0]) for a in syntactic_args]
47934793
results = [FunctionDefResult(self._visit(r)[0]) for r in syntactic_results]
4794-
interfaces.append(FunctionDef(f_name, arguments, results, []))
4794+
interfaces.append(FunctionDef(f_name, arguments, [], results))
47954795

47964796
# TODO -> Said: must handle interface
47974797

0 commit comments

Comments
 (0)