Skip to content

Commit 78d56de

Browse files
authored
Add a pyccel.lambdify function with similar functionality to sympy.lambdify (pyccel#1829)
Given a SymPy expression `f` and type annotations, `lambdify` returns a "pyccelised" function `f_fast` that can be used in the same Python session. Fixes pyccel#1827.
1 parent e36726f commit 78d56de

File tree

10 files changed

+213
-11
lines changed

10 files changed

+213
-11
lines changed

.dict_custom.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ NumPy
55
NumPy's
66
CuPy
77
CuPy's
8+
SymPy
9+
SymPy's
810
BLAS
911
LAPACK
1012
MPI

.github/actions/pytest_run/action.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ runs:
2828
shell: ${{ inputs.shell_cmd }}
2929
working-directory: ./tests
3030
- name: Test C translation
31-
run: python -m pytest -n auto -rX ${FLAGS} -m "not (parallel or xdist_incompatible) and c ${{ inputs.pytest_mark }}" --ignore=symbolic --ignore=ndarrays 2>&1 | tee s1_outfile.out
31+
run: python -m pytest -n auto -rX ${FLAGS} -m "not (parallel or xdist_incompatible) and c ${{ inputs.pytest_mark }}" --ignore=ndarrays 2>&1 | tee s1_outfile.out
3232
shell: ${{ inputs.shell_cmd }}
3333
working-directory: ./tests
3434
id: pytest_1
@@ -45,19 +45,19 @@ runs:
4545
id: pytest_2
4646
- name: Test multi-file C translations
4747
run: |
48-
python -m pytest -rX ${FLAGS} -m "xdist_incompatible and not parallel and c ${{ inputs.pytest_mark }}" --ignore=symbolic --ignore=ndarrays 2>&1 | tee s3_outfile.out
48+
python -m pytest -rX ${FLAGS} -m "xdist_incompatible and not parallel and c ${{ inputs.pytest_mark }}" --ignore=ndarrays 2>&1 | tee s3_outfile.out
4949
pyccel-clean
5050
shell: ${{ inputs.shell_cmd }}
5151
working-directory: ./tests
5252
id: pytest_3
5353
- name: Test Fortran translations
54-
run: python -m pytest -n auto -rX ${FLAGS} -m "not (parallel or xdist_incompatible) and not (c or python) ${{ inputs.pytest_mark }}" --ignore=symbolic --ignore=ndarrays 2>&1 | tee s4_outfile.out
54+
run: python -m pytest -n auto -rX ${FLAGS} -m "not (parallel or xdist_incompatible) and not (c or python) ${{ inputs.pytest_mark }}" --ignore=ndarrays 2>&1 | tee s4_outfile.out
5555
shell: ${{ inputs.shell_cmd }}
5656
working-directory: ./tests
5757
id: pytest_4
5858
- name: Test multi-file Fortran translations
5959
run: |
60-
python -m pytest -rX ${FLAGS} -m "xdist_incompatible and not parallel and not (c or python) ${{ inputs.pytest_mark }}" --ignore=symbolic --ignore=ndarrays 2>&1 | tee s5_outfile.out
60+
python -m pytest -rX ${FLAGS} -m "xdist_incompatible and not parallel and not (c or python) ${{ inputs.pytest_mark }}" --ignore=ndarrays 2>&1 | tee s5_outfile.out
6161
pyccel-clean
6262
shell: ${{ inputs.shell_cmd }}
6363
working-directory: ./tests

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file.
1414
- #1750 : Add Python support for set method `remove()`.
1515
- #1787 : Ensure `STC` is installed with Pyccel.
1616
- #1743 : Add Python support for set method `discard()`.
17+
- #1830 : Add a `pyccel.lambdify.lambdify` function to accelerate SymPy expressions.
1718
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
1819
- \[DEVELOPER\] Added an improved traceback to the developer-mode errors for errors in function calls.
1920

docs/quickstart.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,35 @@ Out[9]: 210.99245283018868
453453
```
454454
After subtracting the amount of time required to create an array copy from the given times, we can conclude that the pyccelised function is approximately 210 times faster than the original Python function.
455455

456+
### Interactive Usage with `lambdify`
457+
458+
While Pyccel is usually used to accelerate Python code, it is also possible to accelerate other expressions. The Pyccel library provides the `lambdify` Python function. This function is similar to SymPy's [`lambdify`](https://docs.sympy.org/latest/modules/utilities/lambdify.html) function, given a SymPy expression `f` and type annotations, `lambdify` returns a "pyccelised" function `f_fast` that can be used in the same Python session.
459+
For example:
460+
```python
461+
import numpy as np
462+
import sympy as sp
463+
from pyccel import lambdify
464+
465+
x = sp.Symbol('x')
466+
expr = x**2 + x*5
467+
f = lambdify(expr, {x : 'float'})
468+
print(f(3.0))
469+
470+
expr2 = x-x
471+
f2 = lambdify(expr, {x : 'float'}, result_type = 'float')
472+
print(f2(3.0))
473+
474+
expr = x**2 + x*5 + 4.5
475+
f3 = lambdify(expr, {x : 'T'}, templates = {'T': ['float[:]', 'float[:,:]']})
476+
x_1d = np.ones(4)
477+
x_2d = np.ones((4,2))
478+
print(f3(x_1d))
479+
print(f3(x_2d))
480+
```
481+
In practice `lambdify` uses SymPy's `NumPyPrinter` to generate code which is passed to the `epyccel` function. The `epyccel` function copies the code into a temporary Python file in the `__epyccel__` directory.
482+
Once the file has been copied, `epyccel` calls the `pyccel` command to generate a Python C extension module that contains a single pyccelised function.
483+
Then finally, it imports this function and returns it to the caller.
484+
456485
## Other Features
457486

458487
Pyccel's generated code can use parallel multi-threading through [OpenMP](https://en.wikipedia.org/wiki/OpenMP); please read [our documentation](https://github.com/pyccel/pyccel/blob/master/tutorial/openmp.md) for more details.

pyccel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .version import __version__
2+
from .commands.lambdify import lambdify

pyccel/commands/lambdify.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
File describing commands associated with the lambdify function which converts a
3+
SymPy expression into a Pyccel-accelerated function.
4+
"""
5+
import sympy as sp
6+
from packaging import version
7+
8+
from pyccel.epyccel import epyccel
9+
from pyccel.utilities.strings import random_string
10+
11+
if version.parse(sp.__version__) >= version.parse('1.8'):
12+
from sympy.printing.numpy import NumPyPrinter
13+
else:
14+
from sympy.printing.pycode import NumPyPrinter
15+
16+
def lambdify(expr : sp.Expr, args : 'dict[sp.Symbol, str]', *, result_type : str = None,
17+
templates : 'dict[str,list[str]]' = None, **kwargs):
18+
"""
19+
Convert a SymPy expression into a Pyccel-accelerated function.
20+
21+
Convert a SymPy expression into a function that allows for fast
22+
numeric evaluation. This is done using SymPy's NumPyPrinter to
23+
generate code that can be accelerated by Pyccel.
24+
25+
Parameters
26+
----------
27+
expr : sp.Expr
28+
The SymPy expression that should be returned from the function.
29+
args : dict[sp.Symbol, str]
30+
A dictionary of the arguments of the function being created.
31+
The keys are variables representing the arguments that will be
32+
passed to the function. The values are the the type annotations
33+
for those functions.
34+
result_type : str, optional
35+
The type annotation for the result of the function. This argument
36+
is optional but it is recommended to provide it as SymPy
37+
expressions do not always evaluate to the expected type. For
38+
example if the SymPy expression simplifies to 0 then the default
39+
type will be int even if the arguments are floats.
40+
templates : dict[str,list[str]], optional
41+
A description of any templates that should be added to the
42+
function. The keys are the symbols which can be used as type
43+
specifiers, the values are a list of the type annotations which
44+
are valid types for the symbol. See
45+
<https://github.com/pyccel/pyccel/blob/devel/docs/templates.md>
46+
for more details.
47+
**kwargs : dict
48+
Additional arguments that are passed to epyccel.
49+
50+
Returns
51+
-------
52+
func
53+
A Pyccel-accelerated function which allows the evaluation of
54+
the SymPy expression.
55+
56+
See Also
57+
--------
58+
sympy.lambdify
59+
<https://docs.sympy.org/latest/modules/utilities/lambdify.html>.
60+
epyccel
61+
The function that accelerates the generated code.
62+
"""
63+
if not (isinstance(args, dict) and all(isinstance(k, sp.Symbol) and isinstance(v, str) for k,v in args.items())):
64+
raise TypeError("Argument 'args': Expected a dictionary mapping SymPy symbols to string type annotations.")
65+
66+
expr = NumPyPrinter().doprint(expr)
67+
args = ', '.join(f'{a} : "{annot}"' for a, annot in args.items())
68+
func_name = 'func_'+random_string(8)
69+
if result_type:
70+
if not isinstance(result_type, str):
71+
raise TypeError("Argument 'result_type': Expected a string type annotation.")
72+
signature = f'def {func_name}({args}) -> "{result_type}":'
73+
else:
74+
signature = f'def {func_name}({args}):'
75+
if templates:
76+
if not (isinstance(templates, dict) and all(isinstance(k, str) and hasattr(v, '__iter__') for k,v in templates.items()) \
77+
and all(all(isinstance(type_annot, str) for type_annot in v) for v in templates.values())):
78+
raise TypeError("Argument 'templates': Expected a dictionary mapping strings describing type specifiers to lists of string type annotations.")
79+
80+
decorators = '\n'.join(f'@template("{key}", ['+', '.join(f'"{annot}"' for annot in annotations)+'])' \
81+
for key, annotations in templates.items())
82+
else:
83+
decorators = ''
84+
code = f' return {expr}'
85+
numpy_import = 'import numpy\n'
86+
func = '\n'.join((numpy_import, decorators, signature, code))
87+
package = epyccel(func, **kwargs)
88+
return getattr(package, func_name)
89+

pyccel/epyccel.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,11 @@ def epyccel_seq(function_or_module, *,
118118
119119
Parameters
120120
----------
121-
function_or_module : function | module
121+
function_or_module : function | module | str
122122
Python function or module to be accelerated.
123+
If a string is passed then it is assumed to be the code from a module which
124+
should be accelerated. The module must be capable of running as a standalone
125+
file so it must include any necessary import statements.
123126
language : {'fortran', 'c', 'python'}
124127
Language of generated code (default: 'fortran').
125128
compiler : str, optional
@@ -167,7 +170,7 @@ def epyccel_seq(function_or_module, *,
167170
# Store current directory
168171
base_dirpath = os.getcwd()
169172

170-
if isinstance(function_or_module, (FunctionType, type)):
173+
if isinstance(function_or_module, (FunctionType, type, str)):
171174
dirpath = os.getcwd()
172175

173176
elif isinstance(function_or_module, ModuleType):
@@ -197,6 +200,11 @@ def epyccel_seq(function_or_module, *,
197200

198201
module_name, module_lock = get_unique_name(pymod.__name__, epyccel_dirpath)
199202

203+
elif isinstance(function_or_module, str):
204+
code = function_or_module
205+
206+
module_name, module_lock = get_unique_name('mod', epyccel_dirpath)
207+
200208
else:
201209
raise TypeError('> Expecting a FunctionType, type or a ModuleType')
202210

@@ -277,8 +285,10 @@ def epyccel( python_function_or_module, **kwargs ):
277285
278286
Parameters
279287
----------
280-
python_function_or_module : function | module
288+
python_function_or_module : function | module | str
281289
Python function or module to be accelerated.
290+
If a string is passed then it is assumed to be the code from a module which
291+
should be accelerated..
282292
**kwargs :
283293
Additional keyword arguments for configuring the compilation and acceleration process.
284294
Available options are defined in epyccel_seq.
@@ -300,7 +310,7 @@ def epyccel( python_function_or_module, **kwargs ):
300310
>>> one_f = epyccel(one, language='fortran')
301311
>>> one_c = epyccel(one, language='c')
302312
"""
303-
assert isinstance( python_function_or_module, (FunctionType, type, ModuleType) )
313+
assert isinstance( python_function_or_module, (FunctionType, type, ModuleType, str) )
304314

305315
comm = kwargs.pop('comm', None)
306316
root = kwargs.pop('root', 0)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
"termcolor >= 1.0.0",
2222
"textx >= 2.2",
2323
"astunparse >= 1.6.0", # astunparse is only needed for Python3.8, we should use ast.unparse when we drop Python3.8.
24+
"packaging",
2425
]
2526

2627
[project.optional-dependencies]

tests/symbolic/mappings.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
""" A file containing 2D analytical mappings to test Pyccel lambdify function.
2+
"""
3+
4+
class PolarMapping:
5+
"""
6+
Represents a Polar 2D Mapping object (Annulus).
7+
"""
8+
expressions = {'x': 'c1 + (rmin*(1-x1)+rmax*x1)*cos(x2)',
9+
'y': 'c2 + (rmin*(1-x1)+rmax*x1)*sin(x2)'}
10+
constants = {'rmin': 0.0, 'rmax': 1.0, 'c1' : 0.0, 'c2' : 0.0}
11+
12+
#==============================================================================
13+
class TargetMapping:
14+
"""
15+
Represents a Target 2D Mapping object.
16+
"""
17+
expressions = {'x': 'c1 + (1-k)*x1*cos(x2) - D*x1**2',
18+
'y': 'c2 + (1+k)*x1*sin(x2)'}
19+
constants = {'c1': 0.0, 'c2': 0.0, 'k' : 0.5, 'D' : '1.0'}
20+
21+
#==============================================================================
22+
class CzarnyMapping:
23+
"""
24+
Represents a Czarny 2D Mapping object.
25+
"""
26+
expressions = {'x': '(1 - sqrt( 1 + eps*(eps + 2*x1*cos(x2)) )) / eps',
27+
'y': 'c2 + (b / sqrt(1-eps**2/4) * x1 * sin(x2)) /'
28+
'(2 - sqrt( 1 + eps*(eps + 2*x1*cos(x2)) ))'}
29+
constants = {'eps' : 0.1, 'c2' : 0.0, 'b' : 1.0}
30+
31+

tests/symbolic/test_symbolic.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,27 @@
44
import os
55
import pytest
66

7+
import numpy as np
8+
import sympy as sp
9+
10+
import mappings
11+
12+
from pyccel import lambdify as pyc_lambdify
713
from pyccel.parser.parser import Parser
814
from pyccel.codegen.codegen import Codegen
915
from pyccel.errors.errors import Errors
1016

17+
RTOL = 1e-12
18+
ATOL = 1e-16
19+
1120
base_dir = os.path.dirname(os.path.realpath(__file__))
1221
path_dir = os.path.join(base_dir, 'scripts')
1322

1423
files = sorted(os.listdir(path_dir))
15-
files = [f for f in files if (f.endswith(".py"))]
24+
files = [f for f in files if f.endswith(".py")]
1625

1726
@pytest.mark.parametrize( "f", files )
18-
@pytest.mark.xfail(reason="Broken symbolic function support, see issue #330")
27+
@pytest.mark.skip(reason="Broken symbolic function support, see issue #330")
1928
def test_symbolic(f):
2029

2130
pyccel = Parser(f)
@@ -34,6 +43,35 @@ def test_symbolic(f):
3443
errors = Errors()
3544
errors.reset()
3645

46+
def test_lambdify(language):
47+
r1 = np.linspace(0.0, 1.0, 100)
48+
p1 = np.linspace(0.0, 2*np.pi, 100)
49+
r,p = np.meshgrid(r1, p1)
50+
x,y = sp.symbols('x1,x2')
51+
for m in (mappings.PolarMapping, mappings.TargetMapping, mappings.CzarnyMapping):
52+
expr_x = sp.sympify(m.expressions['x']).subs(m.constants)
53+
expr_y = sp.sympify(m.expressions['y']).subs(m.constants)
54+
sp_x = sp.lambdify([x, y], expr_x)
55+
sp_y = sp.lambdify([x, y], expr_y)
56+
pyc_x = pyc_lambdify(expr_x, {x : 'float[:,:]', y : 'float[:,:]'}, result_type = 'float[:,:]',
57+
language = language)
58+
pyc_y = pyc_lambdify(expr_y, {x : 'float[:,:]', y : 'float[:,:]'}, result_type = 'float[:,:]',
59+
language = language)
60+
61+
assert np.allclose(sp_x(r, p), pyc_x(r, p), rtol=RTOL, atol=ATOL)
62+
assert np.allclose(sp_y(r, p), pyc_y(r, p), rtol=RTOL, atol=ATOL)
63+
64+
pyc_x = pyc_lambdify(expr_x, {x : 'T', y : 'T'}, templates = {'T': ['float[:]', 'float[:,:]']},
65+
language = language)
66+
pyc_y = pyc_lambdify(expr_y, {x : 'T', y : 'T'}, templates = {'T': ['float[:]', 'float[:,:]']},
67+
language = language)
68+
69+
assert np.allclose(sp_x(r, p), pyc_x(r, p), rtol=RTOL, atol=ATOL)
70+
assert np.allclose(sp_y(r, p), pyc_y(r, p), rtol=RTOL, atol=ATOL)
71+
72+
assert np.allclose(sp_x(r[0,:], p[0,:]), pyc_x(r[0,:], p[0,:]), rtol=RTOL, atol=ATOL)
73+
assert np.allclose(sp_y(r[0,:], p[0,:]), pyc_y(r[0,:], p[0,:]), rtol=RTOL, atol=ATOL)
74+
3775
######################
3876
if __name__ == '__main__':
3977
print('*********************************')
@@ -43,6 +81,6 @@ def test_symbolic(f):
4381
print('*********************************')
4482

4583
for f in files:
46-
print('> testing {0}'.format(str(f)))
84+
print(f'> testing {f}')
4785
test_symbolic(f)
4886
print('\n')

0 commit comments

Comments
 (0)