Skip to content

Commit 26bcaf1

Browse files
author
mmatera
committed
using sympy lambdify for compilation
1 parent 9912b72 commit 26bcaf1

File tree

4 files changed

+129
-72
lines changed

4 files changed

+129
-72
lines changed

mathics/builtin/compilation.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from types import FunctionType
1212

1313
from mathics.builtin.box.compilation import CompiledCodeBox
14-
from mathics.core.atoms import Integer, String
14+
from mathics.core.atoms import Complex, Integer, Real, String
1515
from mathics.core.attributes import A_HOLD_ALL, A_PROTECTED
1616
from mathics.core.builtin import Builtin
1717
from mathics.core.convert.expression import to_mathics_list
@@ -83,7 +83,6 @@ class Compile(Builtin):
8383

8484
def eval(self, vars, expr, evaluation: Evaluation):
8585
"Compile[vars_, expr_]"
86-
8786
if not vars.has_form("List", None):
8887
evaluation.message("Compile", "invars")
8988
return
@@ -167,7 +166,11 @@ def to_sympy(self, *args, **kwargs):
167166
raise NotImplementedError
168167

169168
def __hash__(self):
170-
return hash(("CompiledCode", ctypes.addressof(self.cfunc))) # XXX hack
169+
try:
170+
return hash(("CompiledCode", ctypes.addressof(self.cfunc))) # XXX hack
171+
except TypeError:
172+
return hash(("CompiledCode", self.cfunc,)) # XXX hack
173+
171174

172175
def atom_to_boxes(self, f, evaluation: Evaluation):
173176
return CompiledCodeBox(String(self.__str__()), evaluation=evaluation)
@@ -191,27 +194,38 @@ class CompiledFunction(Builtin):
191194
192195
"""
193196

194-
messages = {"argerr": "Invalid argument `1` should be Integer, Real or boolean."}
197+
messages = {"argerr": "Invalid argument `1` should be Integer, Real, Complex or boolean."}
195198
summary_text = "A CompiledFunction object."
196199

197200
def eval(self, argnames, expr, code, args, evaluation: Evaluation):
198201
"CompiledFunction[argnames_, expr_, code_CompiledCode][args__]"
199-
200202
argseq = args.get_sequence()
201203

202204
if len(argseq) != len(code.args):
203205
return
204206

205207
py_args = []
206-
for arg in argseq:
207-
if isinstance(arg, Integer):
208-
py_args.append(arg.get_int_value())
208+
args_spec = code.args or []
209+
if len(args_spec)!= len(argseq):
210+
evaluation.mesage("CompiledFunction","cfct", Integer(len(argseq)), Integer(len(args_spec)))
211+
return
212+
for arg, spec in zip(argseq, args_spec):
213+
# TODO: check if the types are consistent.
214+
# If not, show a message.
215+
if isinstance(arg, (Integer, Real, Complex)):
216+
val = arg.value
209217
elif arg.sameQ(SymbolTrue):
210-
py_args.append(True)
218+
val = True
211219
elif arg.sameQ(SymbolFalse):
212-
py_args.append(False)
220+
val = False
213221
else:
214-
py_args.append(arg.round_to_float(evaluation))
222+
val = arg.to_python()
223+
try:
224+
val = spec.type(val)
225+
except TypeError:
226+
# Fallback by replace values in expr?
227+
return
228+
py_args.append(val)
215229
try:
216230
result = code.cfunc(*py_args)
217231
except (TypeError, ctypes.ArgumentError):

mathics/compile/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ class CompileArg:
1010
def __init__(self, name, type):
1111
self.name = name
1212
self.type = type
13+
14+
def __repr__(self):
15+
return f"{self.name}:{self.type}"

mathics/core/convert/function.py

Lines changed: 100 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,39 @@
11
# -*- coding: utf-8 -*-
2+
import numpy
3+
from typing import Callable, List, Optional, Tuple
24

3-
from typing import Callable, Optional, Tuple
45

56
from mathics.core.evaluation import Evaluation
67
from mathics.core.expression import Expression, from_python
78
from mathics.core.symbols import Symbol, SymbolFalse, SymbolTrue
8-
from mathics.core.systemsymbols import SymbolBlank, SymbolInteger, SymbolReal
9+
from mathics.core.systemsymbols import SymbolBlank, SymbolInteger, SymbolReal, SymbolComplex, SymbolOr
910
from mathics.eval.nevaluator import eval_N
11+
from mathics.core.convert.lambdify import lambdify_compile, CompileError as LambdifyCompileError
12+
13+
14+
PERMITTED_TYPES = {
15+
Expression(SymbolBlank, SymbolInteger): int,
16+
Expression(SymbolBlank, SymbolReal): float,
17+
Expression(SymbolBlank, SymbolComplex): complex,
18+
Expression(SymbolOr, SymbolTrue, SymbolFalse): bool,
19+
Expression(SymbolOr, SymbolFalse, SymbolTrue): bool,
20+
}
21+
1022

1123
try:
1224
from mathics.compile import CompileArg, CompileError, _compile
1325
from mathics.compile.types import bool_type, int_type, real_type
1426

15-
use_llvm = True
27+
USE_LLVM = True
1628
# _Complex not implemented
17-
permitted_types = {
18-
Expression(SymbolBlank, SymbolInteger): int_type,
19-
Expression(SymbolBlank, SymbolReal): real_type,
20-
SymbolTrue: bool_type,
21-
SymbolFalse: bool_type,
29+
LLVM_TYPE_TRANSLATION = {
30+
int: int_type,
31+
float: real_type,
32+
bool: bool_type,
2233
}
2334
except ImportError:
24-
use_llvm = False
25-
bool_type = bool
26-
int_type = int
27-
real_type = float
35+
USE_LLVM = False
2836

29-
permitted_types = {
30-
Expression(SymbolBlank, SymbolInteger): int,
31-
Expression(SymbolBlank, SymbolReal): float,
32-
SymbolTrue: bool,
33-
SymbolFalse: bool,
34-
}
3537

3638

3739
class CompileDuplicateArgName(Exception):
@@ -44,69 +46,66 @@ def __init__(self, var):
4446
self.var = var
4547

4648

47-
def expression_to_callable(
48-
expr: Expression,
49-
args: Optional[list] = None,
50-
evaluation: Optional[Evaluation] = None,
51-
) -> Optional[Callable]:
49+
50+
def expression_to_llvm(expr: Expression, args:Optional[list]=None, evaluation: Optional[Evaluation]=None):
5251
"""
53-
Return a Python callable from an expression. If llvm is available,
54-
tries to produce llvm code. Otherwise, returns a Python function.
52+
Convert an expression to LLVM code. None if it fails.
5553
expr: Expression
5654
args: a list of CompileArg elements
5755
evaluation: an Evaluation object used if the llvm compilation fails
5856
"""
5957
try:
60-
cfunc = _compile(expr, args) if (use_llvm and args is not None) else None
58+
return _compile(expr, args) if (USE_LLVM and args is not None) else None
6159
except CompileError:
62-
cfunc = None
63-
64-
if cfunc is None:
65-
if evaluation is None:
66-
raise CompileError
67-
try:
68-
69-
def _pythonized_mathics_expr(*x):
70-
inner_evaluation = Evaluation(definitions=evaluation.definitions)
71-
x_mathics = (from_python(u) for u in x[: len(args)])
72-
vars = dict(list(zip([a.name for a in args], x_mathics)))
73-
pyexpr = expr.replace_vars(vars)
74-
pyexpr = eval_N(pyexpr, inner_evaluation)
75-
res = pyexpr.to_python()
76-
return res
60+
return None
7761

78-
# TODO: check if we can use numba to compile this...
79-
cfunc = _pythonized_mathics_expr
80-
except Exception:
81-
cfunc = None
82-
return cfunc
8362

84-
85-
def expression_to_callable_and_args(
86-
expr: Expression,
87-
vars: Optional[list] = None,
88-
evaluation: Optional[Evaluation] = None,
89-
) -> Tuple[Optional[Callable], Optional[list]]:
63+
def expression_to_python_function(
64+
expr: Expression,
65+
args: Optional[list] = None,
66+
evaluation: Optional[Evaluation] = None,
67+
) -> Optional[Callable]:
9068
"""
91-
Return a tuple of Python callable and a list of CompileArgs.
92-
expr: A Mathics Expression object
93-
vars: a list of Symbols or Mathics Lists of the form {Symbol, Type}
69+
Return a Python function from an expression.
70+
expr: Expression
71+
args: a list of CompileArg elements
72+
evaluation: an Evaluation object used if the llvm compilation fails
73+
"""
74+
try:
75+
def _pythonized_mathics_expr(*x):
76+
inner_evaluation = Evaluation(definitions=evaluation.definitions)
77+
x_mathics = (from_python(u) for u in x[: len(args)])
78+
vars = dict(list(zip([a.name for a in args], x_mathics)))
79+
pyexpr = expr.replace_vars(vars)
80+
pyexpr = eval_N(pyexpr, inner_evaluation)
81+
res = pyexpr.to_python()
82+
return res
83+
84+
# TODO: check if we can use numba to compile this...
85+
return _pythonized_mathics_expr
86+
except Exception:
87+
return None
88+
89+
90+
def collect_args(vars)->Optional[List[CompileArg]]:
91+
"""
92+
Convert a List expression into a list of CompileArg objects.
9493
"""
9594
if vars is None:
96-
args = None
95+
return None
9796
else:
9897
args = []
9998
names = []
10099
for var in vars:
101100
if isinstance(var, Symbol):
102101
symb = var
103102
name = symb.get_name()
104-
typ = real_type
103+
typ = float
105104
elif var.has_form("List", 2):
106105
symb, typ = var.elements
107-
if isinstance(symb, Symbol) and typ in permitted_types:
106+
if isinstance(symb, Symbol) and typ in PERMITTED_TYPES:
108107
name = symb.get_name()
109-
typ = permitted_types[typ]
108+
typ = PERMITTED_TYPES[typ]
110109
else:
111110
raise CompileWrongArgType(var)
112111
else:
@@ -116,5 +115,46 @@ def expression_to_callable_and_args(
116115
raise CompileDuplicateArgName(symb)
117116
names.append(name)
118117
args.append(CompileArg(name, typ))
118+
return args
119+
120+
121+
122+
def expression_to_callable_and_args(
123+
expr: Expression,
124+
vars: Optional[list] = None,
125+
evaluation: Optional[Evaluation] = None,
126+
debug:int = 0,
127+
vectorize=False
128+
) -> Tuple[Callable, Optional[list]]:
129+
"""
130+
Return a tuple of Python callable and a list of CompileArgs.
131+
expr: A Mathics Expression object
132+
vars: a list of Symbols or Mathics Lists of the form {Symbol, Type}
133+
"""
134+
args = collect_args(vars)
135+
136+
# First, try to lambdify the expression:
137+
try:
138+
cfunc = lambdify_compile(evaluation, expr, [arg.name for arg in args], debug)
139+
# lambdify_compile returns an already vectorized expression.
140+
return cfunc, args
141+
except LambdifyCompileError:
142+
pass
143+
144+
# Then, try with llvm if available
145+
if USE_LLVM:
146+
try:
147+
llvm_args = None if args is None else [CompileArg(compile_arg.name, LLVM_TYPE_TRANSLATION[compile_arg.typ]) for compile_arg in args]
148+
cfunc = expression_to_llvm(expr, llvm_args, evaluation)
149+
if vectorize:
150+
cfunc = numpy.vectorize(cfunc)
151+
return cfunc, llvm_args
152+
except KeyError:
153+
pass
154+
155+
# Last resource
156+
cfunc = expression_to_python_function(expr, args, evaluation)
157+
if vectorize:
158+
cfunc = numpy.vectorize(cfunc)
159+
return cfunc, args
119160

120-
return expression_to_callable(expr, args, evaluation), args

mathics/core/convert/lambdify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def lambdify_compile(evaluation, expr, names, debug=0):
8181
# Use numpy and scipy to do the evaluation so that operations are vectorized.
8282
# Augment the default numpy mappings with some additional ones not handled by default.
8383
try:
84-
symbols = sympy.symbols(names)
84+
symbols = sympy.symbols(tuple(strip_context(name) for name in names))
8585
# compiled_function = sympy.lambdify(symbols, sympy_expr, mappings)
8686
compiled_function = sympy.lambdify(
8787
symbols, sympy_expr, modules=["numpy", "scipy"]

0 commit comments

Comments
 (0)