Skip to content

Commit 2381034

Browse files
Optimise min/max to avoid unnecessary temporary variables (pyccel#2028)
# PR Summary Fixes: pyccel#2025 The PR optimizes the generated C code by reducing unnecessary temporary variables in min/max functions when direct variable inputs are used. For example: ```python if __name__ == '__main__': a = 3 b = 4 c = min(a,b) d = max(a,b) e = min(a+b, a-b) ``` previously leads to the following generated output: ```C int main() { int64_t a; int64_t b; int64_t c; int64_t d; int64_t e; int64_t Dummy_0000; int64_t Dummy_0001; int64_t Dummy_0002; int64_t Dummy_0003; int64_t Dummy_0004; int64_t Dummy_0005; a = INT64_C(3); b = INT64_C(4); Dummy_0000 = a; Dummy_0001 = b; c = (Dummy_0000 < Dummy_0001 ? Dummy_0000 : Dummy_0001); Dummy_0002 = a; Dummy_0003 = b; d = (Dummy_0002 > Dummy_0003 ? Dummy_0002 : Dummy_0003); Dummy_0004 = a + b; Dummy_0005 = a - b; e = (Dummy_0004 < Dummy_0005 ? Dummy_0004 : Dummy_0005); return 0; } ``` But now it will lead to the following generated output: ```C int main() { int64_t a; int64_t b; int64_t c; int64_t d; int64_t e; int64_t Dummy_0000; int64_t Dummy_0001; a = INT64_C(3); b = INT64_C(4); c = (a < b ? a : b); d = (a > b ? a : b); Dummy_0000 = a + b; Dummy_0001 = a - b; e = (Dummy_0000 < Dummy_0001 ? Dummy_0000 : Dummy_0001); return 0; } ``` --------- Signed-off-by: Emmanuel Ferdman <[email protected]> Co-authored-by: Emily Bourne <[email protected]>
1 parent 5bd938d commit 2381034

File tree

4 files changed

+74
-12
lines changed

4 files changed

+74
-12
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ Contributors
3434
* Said Mazouz
3535
* Shoaib Moeen
3636
* Kush Choudhary
37+
* Emmmanuel Ferdman

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ All notable changes to this project will be documented in this file.
5252

5353
### Fixed
5454

55+
- #2025 : Optimise min/max to avoid unnecessary temporary variables.
5556
- #1720 : Fix Undefined Variable error when the function definition is after the variable declaration.
5657
- #1763 Use `np.result_type` to avoid mistakes in non-trivial NumPy type promotion rules.
5758
- Fix some cases where a Python built-in type is returned in place of a NumPy type.

pyccel/codegen/printing/ccode.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -714,12 +714,22 @@ def _print_PythonMin(self, expr):
714714
return "fmin({}, {})".format(self._print(arg[0]),
715715
self._print(arg[1]))
716716
elif arg.dtype.primitive_type is PrimitiveIntegerType() and len(arg) == 2:
717-
arg1 = self.scope.get_temporary_variable(PythonNativeInt())
718-
arg2 = self.scope.get_temporary_variable(PythonNativeInt())
719-
assign1 = Assign(arg1, arg[0])
720-
assign2 = Assign(arg2, arg[1])
721-
self._additional_code += self._print(assign1)
722-
self._additional_code += self._print(assign2)
717+
if isinstance(arg[0], Variable):
718+
arg1 = self._print(arg[0])
719+
else:
720+
arg1_temp = self.scope.get_temporary_variable(PythonNativeInt())
721+
assign1 = Assign(arg1_temp, arg[0])
722+
self._additional_code += self._print(assign1)
723+
arg1 = self._print(arg1_temp)
724+
725+
if isinstance(arg[1], Variable):
726+
arg2 = self._print(arg[1])
727+
else:
728+
arg2_temp = self.scope.get_temporary_variable(PythonNativeInt())
729+
assign2 = Assign(arg2_temp, arg[1])
730+
self._additional_code += self._print(assign2)
731+
arg2 = self._print(arg2_temp)
732+
723733
return f"({arg1} < {arg2} ? {arg1} : {arg2})"
724734
else:
725735
return errors.report("min in C is only supported for 2 scalar arguments", symbol=expr,
@@ -732,12 +742,22 @@ def _print_PythonMax(self, expr):
732742
return "fmax({}, {})".format(self._print(arg[0]),
733743
self._print(arg[1]))
734744
elif arg.dtype.primitive_type is PrimitiveIntegerType() and len(arg) == 2:
735-
arg1 = self.scope.get_temporary_variable(PythonNativeInt())
736-
arg2 = self.scope.get_temporary_variable(PythonNativeInt())
737-
assign1 = Assign(arg1, arg[0])
738-
assign2 = Assign(arg2, arg[1])
739-
self._additional_code += self._print(assign1)
740-
self._additional_code += self._print(assign2)
745+
if isinstance(arg[0], Variable):
746+
arg1 = self._print(arg[0])
747+
else:
748+
arg1_temp = self.scope.get_temporary_variable(PythonNativeInt())
749+
assign1 = Assign(arg1_temp, arg[0])
750+
self._additional_code += self._print(assign1)
751+
arg1 = self._print(arg1_temp)
752+
753+
if isinstance(arg[1], Variable):
754+
arg2 = self._print(arg[1])
755+
else:
756+
arg2_temp = self.scope.get_temporary_variable(PythonNativeInt())
757+
assign2 = Assign(arg2_temp, arg[1])
758+
self._additional_code += self._print(assign2)
759+
arg2 = self._print(arg2_temp)
760+
741761
return f"({arg1} > {arg2} ? {arg1} : {arg2})"
742762
else:
743763
return errors.report("max in C is only supported for 2 scalar arguments", symbol=expr,

tests/epyccel/test_builtins.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,26 @@ def f(x : 'T', y : 'T'):
188188
assert np.array_equal(epyc_f(*int_args), f(*int_args))
189189
assert np.allclose(epyc_f(*float_args), f(*float_args), rtol=RTOL, atol=ATOL)
190190

191+
def test_min_temp_var_first_arg(language):
192+
def f(x: 'int', y: 'int'):
193+
return min(x + 1, y)
194+
195+
epyc_f = epyccel(f, language=language)
196+
197+
x, y = randint(min_int, max_int), randint(min_int, max_int)
198+
199+
assert epyc_f(x, y) == f(x, y)
200+
201+
def test_min_temp_var_second_arg(language):
202+
def f(x: 'int', y: 'int'):
203+
return min(x, y + 2)
204+
205+
epyc_f = epyccel(f, language=language)
206+
207+
x, y = randint(min_int, max_int), randint(min_int, max_int)
208+
209+
assert epyc_f(x, y) == f(x, y)
210+
191211
def test_max_2_args_i(language):
192212
def f(x : 'int', y : 'int'):
193213
return max(x, y)
@@ -287,6 +307,26 @@ def f(x : 'T', y : 'T'):
287307
assert np.array_equal(epyc_f(*int_args), f(*int_args))
288308
assert np.allclose(epyc_f(*float_args), f(*float_args), rtol=RTOL, atol=ATOL)
289309

310+
def test_max_temp_var_first_arg(language):
311+
def f(x: 'int', y: 'int'):
312+
return max(x + 1, y)
313+
314+
epyc_f = epyccel(f, language=language)
315+
316+
x, y = randint(min_int, max_int), randint(min_int, max_int)
317+
318+
assert epyc_f(x, y) == f(x, y)
319+
320+
def test_max_temp_var_second_arg(language):
321+
def f(x: 'int', y: 'int'):
322+
return max(x, y + 2)
323+
324+
epyc_f = epyccel(f, language=language)
325+
326+
x, y = randint(min_int, max_int), randint(min_int, max_int)
327+
328+
assert epyc_f(x, y) == f(x, y)
329+
290330
@pytest.mark.parametrize( 'language', (
291331
pytest.param("fortran", marks = pytest.mark.fortran),
292332
pytest.param("c", marks = [

0 commit comments

Comments
 (0)