Skip to content

Commit 3e2f5f1

Browse files
Extend the C min and max functions to more than two variables (pyccel#2041)
This PR aims to extend the C support for min and max functions to more than two variables. It leverages the `stc/common.h` functionalities to define two variadic functions for min and max in `stdlib/STC_EXTENSIONS/Common_extensions.h`. This PR also fixed the import scheme for extension headers: Previously, stc headers and extension headers were included separately this is problematic because at the end of `stc` headers, we find a call to `#include "priv/template2.h"` which undefines what was defined by the header unless `imore` is defined. This prohibits the use of these in the extension header. The proposed solution combines the `stc` header with its respective extension and defines `imore`. --------- Co-authored-by: Emily Bourne <[email protected]>
1 parent 49abd63 commit 3e2f5f1

File tree

8 files changed

+130
-58
lines changed

8 files changed

+130
-58
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ All notable changes to this project will be documented in this file.
5050
- #738 : Add support for homogeneous tuples with scalar elements as arguments.
5151
- Add a warning about containers in lists.
5252
- #2016 : Add support for translating arithmetic magic methods (methods cannot yet be used from Python).
53+
- #1980 : Extend The C support for min and max to more than two variables
5354
- \[INTERNALS\] Add abstract class `SetMethod` to handle calls to various set methods.
5455
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
5556
- \[INTERNALS\] Add a `__call__` method to `FunctionDef` to create `FunctionCall` instances.

pyccel/ast/core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3711,6 +3711,30 @@ def define_target(self, new_target):
37113711
else:
37123712
self._target[new_target] = None
37133713

3714+
def remove_target(self, target_to_remove):
3715+
"""
3716+
Remove a target from the imports.
3717+
3718+
Remove a target from the imports.
3719+
I.e., if `imp` is an Import defined as:
3720+
>>> from numpy import ones, cos
3721+
3722+
and we call `imp.remove_target('cos')`
3723+
then it becomes:
3724+
>>> from numpy import ones
3725+
3726+
Parameters
3727+
----------
3728+
target_to_remove : str | AsName | iterable[str | AsName]
3729+
The import target(s) to remove.
3730+
"""
3731+
assert pyccel_stage != "syntactic"
3732+
if iterable(target_to_remove):
3733+
for t in target_to_remove:
3734+
self._target.pop(t, None)
3735+
else:
3736+
self._target.pop(target_to_remove, None)
3737+
37143738
def find_module_target(self, new_target):
37153739
"""
37163740
Find the specified target amongst the targets of the Import.

pyccel/codegen/printing/ccode.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from pyccel.ast.bind_c import BindCPointer
1616

17-
from pyccel.ast.builtins import PythonRange, PythonComplex
17+
from pyccel.ast.builtins import PythonRange, PythonComplex, PythonMin
1818
from pyccel.ast.builtins import PythonPrint, PythonType, VariableIterator
19+
1920
from pyccel.ast.builtins import PythonList, PythonTuple, PythonSet, PythonDict, PythonLen
2021

2122
from pyccel.ast.builtin_methods.dict_methods import DictItems
@@ -248,8 +249,14 @@
248249
'stdbool',
249250
'assert']}
250251

251-
import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET',
252-
'List_extensions' : '_TOOLS_LIST'}
252+
import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET',
253+
'List_extensions' : '_TOOLS_LIST',
254+
'Common_extensions' : '_TOOLS_COMMON'}
255+
256+
257+
stc_header_mapping = {'List_extensions': 'stc/vec',
258+
'Set_extensions': 'stc/hset',
259+
'Common_extensions': 'stc/common'}
253260

254261
class CCodePrinter(CodePrinter):
255262
"""
@@ -675,6 +682,32 @@ def init_stc_container(self, expr, assignment_var):
675682
init = f'{container_name} = c_init({dtype}, {keyraw});\n'
676683
return init
677684

685+
def invalidate_stc_headers(self, imports):
686+
"""
687+
Invalidate STC headers when STC extension headers are present.
688+
689+
This function iterates over the list of imports and removes any targets
690+
from STC headers if the target is present in their corresponding
691+
STC extension headers.
692+
The STC extension headers take care of including the standard
693+
headers.
694+
695+
Parameters
696+
----------
697+
imports : list of Import
698+
The list of Import objects representing the header files to include.
699+
700+
Returns
701+
-------
702+
None
703+
The function modifies the `imports` list in-place.
704+
"""
705+
for imp in imports:
706+
if imp.source in stc_header_mapping:
707+
for imp2 in imports:
708+
if imp2.source == stc_header_mapping[imp.source]:
709+
imp2.remove_target(imp.target)
710+
678711
def rename_imported_methods(self, expr):
679712
"""
680713
Rename class methods from user-defined imports.
@@ -711,12 +744,13 @@ def _print_PythonAbs(self, expr):
711744
func = "labs"
712745
return "{}({})".format(func, self._print(expr.arg))
713746

714-
def _print_PythonMin(self, expr):
747+
def _print_PythonMinMax(self, expr):
715748
arg = expr.args[0]
716749
if arg.dtype.primitive_type is PrimitiveFloatingPointType() and len(arg) == 2:
717750
self.add_import(c_imports['math'])
718-
return "fmin({}, {})".format(self._print(arg[0]),
719-
self._print(arg[1]))
751+
arg1 = self._print(arg[0])
752+
arg2 = self._print(arg[1])
753+
return f"f{expr.name}({arg1}, {arg2})"
720754
elif arg.dtype.primitive_type is PrimitiveIntegerType() and len(arg) == 2:
721755
if isinstance(arg[0], Variable):
722756
arg1 = self._print(arg[0])
@@ -734,38 +768,22 @@ def _print_PythonMin(self, expr):
734768
self._additional_code += self._print(assign2)
735769
arg2 = self._print(arg2_temp)
736770

737-
return f"({arg1} < {arg2} ? {arg1} : {arg2})"
771+
op = '<' if isinstance(expr, PythonMin) else '>'
772+
return f"({arg1} {op} {arg2} ? {arg1} : {arg2})"
773+
elif len(arg) > 2 and isinstance(arg.dtype.primitive_type, (PrimitiveFloatingPointType, PrimitiveIntegerType)):
774+
key = self.get_declare_type(arg[0])
775+
self.add_import(Import('stc/common', AsName(VariableTypeAnnotation(arg.dtype), key)))
776+
self.add_import(Import('Common_extensions', AsName(VariableTypeAnnotation(arg.dtype), key)))
777+
return f'{key}_{expr.name}({len(arg)}, {", ".join(self._print(a) for a in arg)})'
738778
else:
739-
return errors.report("min in C is only supported for 2 scalar arguments", symbol=expr,
779+
return errors.report(f"{expr.name} in C does not support arguments of type {arg.dtype}", symbol=expr,
740780
severity='fatal')
741781

742-
def _print_PythonMax(self, expr):
743-
arg = expr.args[0]
744-
if arg.dtype.primitive_type is PrimitiveFloatingPointType() and len(arg) == 2:
745-
self.add_import(c_imports['math'])
746-
return "fmax({}, {})".format(self._print(arg[0]),
747-
self._print(arg[1]))
748-
elif arg.dtype.primitive_type is PrimitiveIntegerType() and len(arg) == 2:
749-
if isinstance(arg[0], Variable):
750-
arg1 = self._print(arg[0])
751-
else:
752-
arg1_temp = self.scope.get_temporary_variable(PythonNativeInt())
753-
assign1 = Assign(arg1_temp, arg[0])
754-
self._additional_code += self._print(assign1)
755-
arg1 = self._print(arg1_temp)
756-
757-
if isinstance(arg[1], Variable):
758-
arg2 = self._print(arg[1])
759-
else:
760-
arg2_temp = self.scope.get_temporary_variable(PythonNativeInt())
761-
assign2 = Assign(arg2_temp, arg[1])
762-
self._additional_code += self._print(assign2)
763-
arg2 = self._print(arg2_temp)
782+
def _print_PythonMin(self, expr):
783+
return self._print_PythonMinMax(expr)
764784

765-
return f"({arg1} > {arg2} ? {arg1} : {arg2})"
766-
else:
767-
return errors.report("max in C is only supported for 2 scalar arguments", symbol=expr,
768-
severity='fatal')
785+
def _print_PythonMax(self, expr):
786+
return self._print_PythonMinMax(expr)
769787

770788
def _print_SysExit(self, expr):
771789
code = ""
@@ -857,6 +875,7 @@ def _print_ModuleHeader(self, expr):
857875

858876
# Print imports last to be sure that all additional_imports have been collected
859877
imports = [*expr.module.imports, *self._additional_imports.values()]
878+
self.invalidate_stc_headers(imports)
860879
imports = ''.join(self._print(i) for i in imports)
861880

862881
self._in_header = False
@@ -1033,7 +1052,20 @@ def _print_Import(self, expr):
10331052
source = source.name[-1].python_value
10341053
else:
10351054
source = self._print(source)
1036-
if source.startswith('stc/') or source in import_header_guard_prefix:
1055+
if source == 'Common_extensions':
1056+
code = ''
1057+
for t in expr.target:
1058+
element_decl = f'#define i_key {t.local_alias}\n'
1059+
header_guard_prefix = import_header_guard_prefix.get(source, '')
1060+
header_guard = f'{header_guard_prefix}_{t.local_alias.upper()}'
1061+
code += ''.join((f'#ifndef {header_guard}\n',
1062+
f'#define {header_guard}\n',
1063+
element_decl,
1064+
f'#include <{stc_header_mapping[source]}.h>\n',
1065+
f'#include <{source}.h>\n',
1066+
f'#endif // {header_guard}\n\n'))
1067+
return code
1068+
elif source.startswith('stc/') or source in import_header_guard_prefix:
10371069
code = ''
10381070
for t in expr.target:
10391071
class_type = t.object.class_type
@@ -1055,6 +1087,8 @@ def _print_Import(self, expr):
10551087
f'#define {header_guard}\n',
10561088
f'#define i_type {container_type}\n',
10571089
element_decl,
1090+
'#define i_more\n' if source in import_header_guard_prefix else '',
1091+
f'#include <{stc_header_mapping[source]}.h>\n' if source in import_header_guard_prefix else '',
10581092
f'#include <{source}.h>\n',
10591093
f'#endif // {header_guard}\n\n'))
10601094
return code
@@ -2630,6 +2664,7 @@ def _print_Program(self, expr):
26302664
decs = ''.join(self._print(Declare(v)) for v in variables)
26312665

26322666
imports = [*expr.imports, *self._additional_imports.values()]
2667+
self.invalidate_stc_headers(imports)
26332668
imports = ''.join(self._print(i) for i in imports)
26342669

26352670
self.exit_scope()

pyccel/codegen/utilities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"numpy_c" : ("numpy", CompileObj("numpy_c.c",folder="numpy")),
5050
"Set_extensions" : ("STC_Extensions", CompileObj("Set_Extensions.h", folder="STC_Extensions", has_target_file = False)),
5151
"List_extensions" : ("STC_Extensions", CompileObj("List_Extensions.h", folder="STC_Extensions", has_target_file = False)),
52+
"Common_extensions" : ("STC_Extensions", CompileObj("Common_Extensions.h", folder="STC_Extensions", has_target_file = False)),
5253
"gFTL_functions/Set_extensions" : ("gFTL_functions", CompileObj("Set_Extensions.inc", folder="gFTL_functions", has_target_file = False)),
5354
}
5455
internal_libs["cwrapper_ndarrays"] = ("cwrapper_ndarrays", CompileObj("cwrapper_ndarrays.c",folder="cwrapper_ndarrays",
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <stdarg.h>
2+
3+
static inline i_key c_JOIN(i_key,_min)(size_t count, ...) {
4+
va_list args;
5+
va_start(args, count);
6+
i_key min_value = va_arg(args, i_key);
7+
for (size_t i = 1; i < count; ++i) {
8+
i_key value = va_arg(args, i_key);
9+
if (value < min_value) {
10+
min_value = value;
11+
}
12+
}
13+
va_end(args);
14+
return min_value;
15+
}
16+
17+
static inline i_key c_JOIN(i_key, _max)(size_t count, ...) {
18+
va_list args;
19+
va_start(args, count);
20+
i_key max_value = va_arg(args, i_key);
21+
for (size_t i = 1; i < count; ++i) {
22+
i_key value = va_arg(args, i_key);
23+
if (value > max_value) {
24+
max_value = value;
25+
}
26+
}
27+
va_end(args);
28+
return max_value;
29+
}
30+
31+
#undef i_key

pyccel/stdlib/STC_Extensions/List_extensions.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#define LIST_EXTENSIONS_H
33

44

5-
#define _c_MEMB(name) c_JOIN(i_type, name)
6-
75
// This function represents a call to the .pop() method.
86
// i_type: Class type (e.g., hset_int64_t).
97
// i_key: Data type of the elements in the set (e.g., int64_t).
@@ -24,4 +22,5 @@ static inline i_key _c_MEMB(_pull_elem)(i_type* self, intptr_t pop_idx) {
2422

2523
#undef i_type
2624
#undef i_key
25+
#include <stc/priv/template2.h>
2726
#endif

pyccel/stdlib/STC_Extensions/Set_extensions.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#define SET_EXTENSIONS_H
33
#include <stdarg.h>
44

5-
#define _c_MEMB(name) c_JOIN(i_type, name)
6-
75
// This function represents a call to the .pop() method.
86
// i_type: Class type (e.g., hset_int64_t).
97
// i_key: Data type of the elements in the set (e.g., int64_t).
@@ -44,4 +42,5 @@ static inline i_type _c_MEMB(_union)(i_type* self, int n, ...) {
4442

4543
#undef i_type
4644
#undef i_key
45+
#include <stc/priv/template2.h>
4746
#endif

tests/epyccel/test_builtins.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,6 @@ def f(x : 'float', y : 'float'):
109109

110110
assert np.isclose(epyc_f(*float_args), f(*float_args), rtol=RTOL, atol=ATOL)
111111

112-
@pytest.mark.parametrize( 'language', (
113-
pytest.param("fortran", marks = pytest.mark.fortran),
114-
pytest.param("c", marks = [
115-
pytest.mark.skip(reason="min not implemented in C for more than 2 args"),
116-
pytest.mark.c]
117-
),
118-
pytest.param("python", marks = pytest.mark.python)
119-
)
120-
)
121112
def test_min_3_args(language):
122113
@template('T', [int, float])
123114
def f(x : 'T', y : 'T', z : 'T'):
@@ -228,19 +219,10 @@ def f(x : 'float', y : 'float'):
228219

229220
assert np.isclose(epyc_f(*float_args), f(*float_args), rtol=RTOL, atol=ATOL)
230221

231-
@pytest.mark.parametrize( 'language', (
232-
pytest.param("fortran", marks = pytest.mark.fortran),
233-
pytest.param("c", marks = [
234-
pytest.mark.skip(reason="max not implemented in C for more than 2 args"),
235-
pytest.mark.c]
236-
),
237-
pytest.param("python", marks = pytest.mark.python)
238-
)
239-
)
240222
def test_max_3_args(language):
241223
@template('T', [int, float])
242224
def f(x : 'T', y : 'T', z : 'T'):
243-
return min(x, y, z)
225+
return max(x, y, z)
244226

245227
epyc_f = epyccel(f, language=language)
246228

0 commit comments

Comments
 (0)