Skip to content

Commit fa5ea32

Browse files
authored
Allow lambdify to generate functions with out arguments (pyccel#1867)
Add a `use_out` parameter to `pyccel.lambdify` to avoid unnecessary memory allocation. Auto-generate a docstring for functions generated via calls to `pyccel.lambdify`.
1 parent e7c5b85 commit fa5ea32

File tree

4 files changed

+90
-10
lines changed

4 files changed

+90
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ All notable changes to this project will be documented in this file.
2020
- #1656 : Ensure `gFTL` is installed with Pyccel.
2121
- #1830 : Add a `pyccel.lambdify.lambdify` function to accelerate SymPy expressions.
2222
- #1844 : Add line numbers and code to errors from built-in function calls.
23+
- #1867 : Add a `use_out` parameter to `pyccel.lambdify` to avoid unnecessary memory allocation.
24+
- #1867 : Auto-generate a docstring for functions generated via calls to `pyccel.lambdify`.
2325
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
2426
- \[DEVELOPER\] Added an improved traceback to the developer-mode errors for errors in function calls.
2527

docs/quickstart.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,22 @@ In practice `lambdify` uses SymPy's `NumPyPrinter` to generate code which is pas
482482
Once the file has been copied, `epyccel` calls the `pyccel` command to generate a Python C extension module that contains a single pyccelised function.
483483
Then finally, it imports this function and returns it to the caller.
484484

485+
In order to make functions even faster it may be desirable to avoid unnecessary allocations inside the function. This functionality is similar to using an `out` argument in NumPy. Pyccel makes this functionality possible through the use of the `use_out` argument.
486+
For example:
487+
```python
488+
import numpy as np
489+
import sympy as sp
490+
from pyccel import lambdify
491+
492+
x = sp.Symbol('x')
493+
expr = x**2 + x*5
494+
f = lambdify(expr, {x : 'float[:,:]'}, result_type = 'float[:,:]')
495+
x_2d = np.ones((4,2))
496+
y_2d = np.empty_like(x_2d)
497+
f(x_2d, y_2d)
498+
print(y_2d)
499+
```
500+
485501
## Other Features
486502

487503
Pyccel's generated code can use parallel multi-threading through [OpenMP](https://en.wikipedia.org/wiki/OpenMP); please read [our documentation](https://github.com/pyccel/pyccel/blob/devel/docs/openmp.md) for more details.

pyccel/commands/lambdify.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from sympy.printing.pycode import NumPyPrinter
1515

1616
def lambdify(expr : sp.Expr, args : 'dict[sp.Symbol, str]', *, result_type : str = None,
17-
templates : 'dict[str,list[str]]' = None, **kwargs):
17+
templates : 'dict[str, list[str]]' = None, use_out = False,
18+
**kwargs):
1819
"""
1920
Convert a SymPy expression into a Pyccel-accelerated function.
2021
@@ -37,13 +38,18 @@ def lambdify(expr : sp.Expr, args : 'dict[sp.Symbol, str]', *, result_type : str
3738
expressions do not always evaluate to the expected type. For
3839
example if the SymPy expression simplifies to 0 then the default
3940
type will be int even if the arguments are floats.
40-
templates : dict[str,list[str]], optional
41+
templates : dict[str, list[str]], optional
4142
A description of any templates that should be added to the
4243
function. The keys are the symbols which can be used as type
4344
specifiers, the values are a list of the type annotations which
4445
are valid types for the symbol. See
4546
<https://github.com/pyccel/pyccel/blob/devel/docs/templates.md>
4647
for more details.
48+
use_out : bool, default=False
49+
If true the function will modify an argument called 'out' instead
50+
of returning a newly allocated array. If this argument is set then
51+
result_type must be provided. This only works if the result is an
52+
array type.
4753
**kwargs : dict
4854
Additional arguments that are passed to epyccel.
4955
@@ -62,16 +68,37 @@ def lambdify(expr : sp.Expr, args : 'dict[sp.Symbol, str]', *, result_type : str
6268
"""
6369
if not (isinstance(args, dict) and all(isinstance(k, sp.Symbol) and isinstance(v, str) for k,v in args.items())):
6470
raise TypeError("Argument 'args': Expected a dictionary mapping SymPy symbols to string type annotations.")
71+
if result_type is not None and not isinstance(result_type, str):
72+
raise TypeError("Argument 'result_type': Expected a string type annotation.")
6573

6674
expr = NumPyPrinter().doprint(expr)
67-
args = ', '.join(f'{a} : "{annot}"' for a, annot in args.items())
75+
args_code = ', '.join(f'{a} : "{annot}"' for a, annot in args.items())
6876
func_name = 'func_'+random_string(8)
69-
if result_type:
70-
if not isinstance(result_type, str):
71-
raise TypeError("Argument 'result_type': Expected a string type annotation.")
72-
signature = f'def {func_name}({args}) -> "{result_type}":'
77+
78+
docstring = " \n".join((' """',
79+
" Expression evaluation created with `pyccel.lambdify`.",
80+
"",
81+
" Function evaluating the expression:",
82+
f" {expr}",
83+
"",
84+
" Parameters",
85+
" ----------\n"))
86+
docstring += '\n'.join(f" {a} : {type_annot}" for a, type_annot in args.items())
87+
88+
if use_out:
89+
if not result_type:
90+
raise TypeError("The result_type must be provided if use_out is true.")
91+
else:
92+
signature = f'def {func_name}({args_code}, out : "{result_type}"):'
93+
docstring += f"\n out : {result_type}"
94+
elif result_type:
95+
signature = f'def {func_name}({args_code}) -> "{result_type}":'
96+
docstring += "\n".join(("\n",
97+
" Returns",
98+
" -------",
99+
f" {result_type}"))
73100
else:
74-
signature = f'def {func_name}({args}):'
101+
signature = f'def {func_name}({args_code}):'
75102
if templates:
76103
if not (isinstance(templates, dict) and all(isinstance(k, str) and hasattr(v, '__iter__') for k,v in templates.items()) \
77104
and all(all(isinstance(type_annot, str) for type_annot in v) for v in templates.values())):
@@ -81,9 +108,15 @@ def lambdify(expr : sp.Expr, args : 'dict[sp.Symbol, str]', *, result_type : str
81108
for key, annotations in templates.items())
82109
else:
83110
decorators = ''
84-
code = f' return {expr}'
111+
if use_out:
112+
code = f' out[:] = {expr}'
113+
else:
114+
code = f' return {expr}'
85115
numpy_import = 'import numpy\n'
86-
func = '\n'.join((numpy_import, decorators, signature, code))
116+
117+
docstring += '\n """'
118+
119+
func = '\n'.join((numpy_import, decorators, signature, docstring, code))
87120
package = epyccel(func, **kwargs)
88121
return getattr(package, func_name)
89122

tests/symbolic/test_symbolic.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,35 @@ def test_lambdify(language):
7272
assert np.allclose(sp_x(r[0,:], p[0,:]), pyc_x(r[0,:], p[0,:]), rtol=RTOL, atol=ATOL)
7373
assert np.allclose(sp_y(r[0,:], p[0,:]), pyc_y(r[0,:], p[0,:]), rtol=RTOL, atol=ATOL)
7474

75+
def test_lambdify_out_arg(language):
76+
r1 = np.linspace(0.0, 1.0, 100)
77+
p1 = np.linspace(0.0, 2*np.pi, 100)
78+
r,p = np.meshgrid(r1, p1)
79+
x,y = sp.symbols('x1,x2')
80+
for m in (mappings.PolarMapping, mappings.TargetMapping, mappings.CzarnyMapping):
81+
expr_x = sp.sympify(m.expressions['x']).subs(m.constants)
82+
expr_y = sp.sympify(m.expressions['y']).subs(m.constants)
83+
sp_x = sp.lambdify([x, y], expr_x)
84+
sp_y = sp.lambdify([x, y], expr_y)
85+
pyc_x = pyc_lambdify(expr_x, {x : 'float[:,:]', y : 'float[:,:]'}, result_type = 'float[:,:]',
86+
use_out = True, language = language)
87+
pyc_y = pyc_lambdify(expr_y, {x : 'float[:,:]', y : 'float[:,:]'}, result_type = 'float[:,:]',
88+
use_out = True, language = language)
89+
90+
print(pyc_x.__doc__)
91+
92+
sp_out_x = np.empty_like(r)
93+
sp_out_y = np.empty_like(r)
94+
pyc_out_x = np.empty_like(r)
95+
pyc_out_y = np.empty_like(r)
96+
sp_out_x = sp_x(r, p)
97+
sp_out_y = sp_y(r, p)
98+
pyc_x(r, p, pyc_out_x)
99+
pyc_y(r, p, out = pyc_out_y)
100+
101+
assert np.allclose(sp_out_x, pyc_out_x, rtol=RTOL, atol=ATOL)
102+
assert np.allclose(sp_out_y, pyc_out_y, rtol=RTOL, atol=ATOL)
103+
75104
######################
76105
if __name__ == '__main__':
77106
print('*********************************')

0 commit comments

Comments
 (0)