Skip to content

Commit afc3e4a

Browse files
authored
Using is_uniform to avoid repeated checks (#1517)
Now we have the `is_uniform` attribute; this PR uses it to speed up the match test with `SequenceBlank` and `SequenceNullBlank`. Also, ensure that this property is computed and stored in list/expression creation operations (`Table`, `Sum`, `Product`). * using `is_uniform` property in matching `SequenceBlank` and `SequenceNullBlank` patterns * `get_result` now accepts an optional attribute `is_uniform` to set this property on its result. * `Table`, `Sum`, and `Product` now set the element_property `is_uniform` to provide a faster evaluation of large lists. * convert_expression_elements now handles is_uniform * Start a workflow to benchmark some basic evaluations and pattern matching.
1 parent d98cc2c commit afc3e4a

File tree

9 files changed

+103
-52
lines changed

9 files changed

+103
-52
lines changed

mathics/builtin/arithmetic.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
Test,
4343
)
4444
from mathics.core.convert.sympy import SympyExpression, from_sympy
45-
from mathics.core.element import BaseElement
45+
from mathics.core.element import BaseElement, ElementsProperties
4646
from mathics.core.evaluation import Evaluation
4747
from mathics.core.expression import Expression
4848
from mathics.core.expression_predefined import (
@@ -784,8 +784,12 @@ class Product(IterationFunction, SympyFunction, PrefixOperator):
784784
sympy_name = "Product"
785785
throw_iterb = False
786786

787-
def get_result(self, elements):
788-
return Expression(SymbolTimes, *elements)
787+
def get_result(self, elements, is_uniform=False):
788+
return Expression(
789+
SymbolTimes,
790+
*elements,
791+
elements_properties=ElementsProperties(is_uniform=is_uniform),
792+
)
789793

790794
def to_sympy(self, expr, **kwargs):
791795
if expr.has_form("Product", 2) and expr.elements[1].has_form("List", 3):
@@ -1021,8 +1025,12 @@ class Sum(IterationFunction, SympyFunction, PrefixOperator):
10211025
# Do not throw warning message for symbolic iteration bounds
10221026
throw_iterb = False
10231027

1024-
def get_result(self, elements) -> Expression:
1025-
return Expression(SymbolPlus, *elements)
1028+
def get_result(self, elements, is_uniform=False) -> Expression:
1029+
return Expression(
1030+
SymbolPlus,
1031+
*elements,
1032+
elements_properties=ElementsProperties(is_uniform=is_uniform),
1033+
)
10261034

10271035
def to_sympy(self, expr, **kwargs) -> Optional[SympyExpression]:
10281036
"""

mathics/builtin/functional/apply_fns_to_lists.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Many mathematical functions are automatically taken to be "listable", so that \
99
they are always applied to every element in a list.
1010
"""
11-
11+
from dataclasses import replace as dc_replace
1212
from typing import Iterable
1313

1414
from mathics.core.atoms import Integer, Integer0, Integer1, Integer3
@@ -94,7 +94,10 @@ def callback(level):
9494
if isinstance(level, Atom):
9595
return level
9696
else:
97-
return Expression(f, *level.elements)
97+
elem_prop = level.elements_properties
98+
if elem_prop is not None:
99+
elem_prop = dc_replace(elem_prop, elements_fully_evaluated=False)
100+
return Expression(f, *level.elements, elements_properties=elem_prop)
98101

99102
heads = self.get_option(options, "Heads", evaluation) is SymbolTrue
100103
result, _ = walk_levels(expr, start, stop, heads=heads, callback=callback)
@@ -154,6 +157,10 @@ def callback(level):
154157

155158
heads = self.get_option(options, "Heads", evaluation) is SymbolTrue
156159
result, _ = walk_levels(expr, start, stop, heads=heads, callback=callback)
160+
elem_prop = result.elements_properties
161+
if elem_prop is not None:
162+
elem_prop.elements_fully_evaluated = False
163+
result.elements_properties
157164

158165
return result
159166

@@ -288,6 +295,9 @@ def callback(level, pos: Iterable):
288295
result, depth = walk_levels(
289296
expr, start, stop, heads=heads, callback=callback, include_pos=True
290297
)
298+
elem_prop = result.elements_properties
299+
if elem_prop is not None:
300+
elem_prop.elements_fully_evaluated = False
291301

292302
return result
293303

mathics/builtin/list/constructing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,10 +566,12 @@ class Table(IterationFunction):
566566

567567
summary_text = "make a table of values of an expression"
568568

569-
def get_result(self, elements) -> ListExpression:
569+
def get_result(self, elements, is_uniform=False) -> ListExpression:
570570
return ListExpression(
571571
*elements,
572-
elements_properties=ElementsProperties(elements_fully_evaluated=True),
572+
elements_properties=ElementsProperties(
573+
elements_fully_evaluated=True, is_uniform=is_uniform
574+
),
573575
)
574576

575577

mathics/builtin/patterns/basic.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
BLANKSEQUENCE_GENERAL_PATTERN_SORT_KEY,
2020
BLANKSEQUENCE_WITH_PATTERN_PATTERN_SORT_KEY,
2121
)
22-
from mathics.core.symbols import BaseElement
22+
from mathics.core.symbols import BaseElement, Symbol
2323

2424
# This tells documentation how to sort this module
2525
sort_order = "mathics.builtin.rules-and-patterns.basic"
2626

2727

2828
class _Blank(PatternObject, ABC):
2929
arg_counts = [0, 1]
30-
30+
target_head: OptionalType[Symbol]
3131
_instance = None
3232

3333
def __new__(cls, *args, **kwargs):
@@ -49,12 +49,14 @@ def init(
4949
) -> None:
5050
super().init(expr, evaluation=evaluation)
5151
if expr.elements:
52-
self.head = expr.elements[0]
52+
target_head = expr.elements[0]
53+
assert isinstance(target_head, Symbol)
54+
self.target_head = target_head
5355
else:
54-
# FIXME: elswhere, some code wants to
56+
# FIXME: elsewhere, some code wants to
5557
# get the attributes of head.
5658
# So is this really the best thing to do here?
57-
self.head = None
59+
self.target_head = None
5860

5961

6062
class Blank(_Blank):
@@ -100,15 +102,17 @@ class Blank(_Blank):
100102
summary_text = "match to any single expression"
101103

102104
def match(self, expression: BaseElement, pattern_context: dict):
105+
if expression.has_form("Sequence", 0):
106+
return
107+
108+
target_head = self.target_head
109+
if target_head is not None and expression.get_head() is not target_head:
110+
return
111+
112+
# Match!
103113
vars_dict = pattern_context["vars_dict"]
104114
yield_func = pattern_context["yield_func"]
105-
106-
if not expression.has_form("Sequence", 0):
107-
if self.head is not None:
108-
if expression.get_head().sameQ(self.head):
109-
yield_func(vars_dict, None)
110-
else:
111-
yield_func(vars_dict, None)
115+
yield_func(vars_dict, None)
112116

113117
@property
114118
def element_order(self):
@@ -157,19 +161,26 @@ class BlankNullSequence(_Blank):
157161

158162
def match(self, expression: Expression, pattern_context: dict):
159163
"""Match with a BlankNullSequence"""
160-
vars_dict = pattern_context["vars_dict"]
161-
yield_func = pattern_context["yield_func"]
162-
elements = expression.get_sequence()
163-
if self.head:
164-
ok = True
164+
165+
target_head = self.target_head
166+
if target_head:
167+
elements = expression.get_sequence()
168+
is_uniform = False
169+
if isinstance(expression, Expression):
170+
element_properties = expression.elements_properties
171+
if element_properties is not None:
172+
is_uniform = element_properties.is_uniform
165173
for element in elements:
166-
if element.get_head() != self.head:
167-
ok = False
174+
if target_head is not element.get_head():
175+
return
176+
# If the expression is uniform, no further checks are necessary.
177+
if is_uniform:
168178
break
169-
if ok:
170-
yield_func(vars_dict, None)
171-
else:
172-
yield_func(vars_dict, None)
179+
180+
# Match!
181+
vars_dict = pattern_context["vars_dict"]
182+
yield_func = pattern_context["yield_func"]
183+
yield_func(vars_dict, None)
173184

174185
@property
175186
def element_order(self) -> tuple:
@@ -240,21 +251,29 @@ class BlankSequence(_Blank):
240251
summary_text = "match to a non-empty sequence of elements"
241252

242253
def match(self, expression: Expression, pattern_context: dict):
243-
vars_dict = pattern_context["vars_dict"]
244-
yield_func = pattern_context["yield_func"]
245254
elements = expression.get_sequence()
255+
246256
if not elements:
247257
return
248-
if self.head:
249-
ok = True
258+
259+
target_head = self.target_head
260+
if target_head:
261+
is_uniform = False
262+
if isinstance(expression, Expression):
263+
element_properties = expression.elements_properties
264+
if element_properties is not None:
265+
is_uniform = element_properties.is_uniform
250266
for element in elements:
251-
if element.get_head() != self.head:
252-
ok = False
267+
if target_head is not element.get_head():
268+
return
269+
# If the expression is uniform, no further checks are necessary.
270+
if is_uniform:
253271
break
254-
if ok:
255-
yield_func(vars_dict, None)
256-
else:
257-
yield_func(vars_dict, None)
272+
273+
# Match!
274+
vars_dict = pattern_context["vars_dict"]
275+
yield_func = pattern_context["yield_func"]
276+
yield_func(vars_dict, None)
258277

259278
def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
260279
return (1, None)

mathics/builtin/procedural.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ class Do(IterationFunction):
309309
allow_loopcontrol = True
310310
summary_text = "evaluate an expression looping over a variable"
311311

312-
def get_result(self, _items):
312+
def get_result(self, _items, is_uniform=False):
313313
return SymbolNull
314314

315315

mathics/core/builtin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ class IterationFunction(Builtin, ABC):
10231023
allow_loopcontrol = False
10241024
throw_iterb = True
10251025

1026-
def get_result(self, elements) -> Expression:
1026+
def get_result(self, elements, is_uniform=False) -> Expression:
10271027
raise NotImplementedError
10281028

10291029
def eval_symbol(self, expr, iterator, evaluation):
@@ -1159,6 +1159,8 @@ def eval_iter(self, expr, i, imin, imax, di, evaluation):
11591159
).evaluate(evaluation)
11601160

11611161
result = []
1162+
last_head = None
1163+
is_uniform = True
11621164
while True:
11631165
cont = Expression(SymbolLessEqual, index, normalised_range).evaluate(
11641166
evaluation
@@ -1186,6 +1188,10 @@ def eval_iter(self, expr, i, imin, imax, di, evaluation):
11861188
evaluation,
11871189
)
11881190
result.append(item)
1191+
if last_head is None:
1192+
last_head = item.get_head()
1193+
elif is_uniform and last_head is not item.get_head():
1194+
is_uniform = False
11891195
except ContinueInterrupt:
11901196
if self.allow_loopcontrol:
11911197
pass
@@ -1202,7 +1208,7 @@ def eval_iter(self, expr, i, imin, imax, di, evaluation):
12021208
else:
12031209
raise
12041210
index = Expression(SymbolPlus, index, Integer1).evaluate(evaluation)
1205-
return self.get_result(result)
1211+
return self.get_result(result, is_uniform=is_uniform)
12061212

12071213
def eval_list(self, expr, i, items, evaluation):
12081214
"%(name)s[expr_, {i_Symbol, {items___}}]"

mathics/core/expression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,21 +1956,23 @@ def construct(elements):
19561956

19571957
# Note: this function is called a *lot* so it needs to be fast.
19581958
def convert_expression_elements(
1959-
elements: Iterable, conversion_fn: Callable = from_python
1959+
elements: Iterable, conversion_fn: Callable = from_python, is_uniform: bool = True
19601960
) -> Tuple[tuple, ElementsProperties, Optional[tuple]]:
19611961
"""
19621962
Convert and return tuple of Elements from the Python-like items in
19631963
`elements`, along with elements properties of the elements tuple,
19641964
and a tuple of literal values if it elements are all literal
19651965
otherwise, None.
1966+
By default, is is assumed that `elements` are *uniform*, which is the typical case
1967+
of elements coming from applying a numerical function to a set of different arguments.
19661968
19671969
The return information is suitable for use to the Expression() constructor.
19681970
19691971
"""
19701972

19711973
# All of the properties start out optimistic (True) and are reset when that
19721974
# proves wrong.
1973-
elements_properties = ElementsProperties(True, True, True)
1975+
elements_properties = ElementsProperties(True, True, True, is_uniform)
19741976

19751977
is_literal = True
19761978
values = [] # If is_literal, "values" contains the (Python) literal values

mathics/eval/parts.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ def walk_levels(
167167

168168
# FIXME: we could keep track of elements properties here.
169169
elements = []
170+
elem_prop = expr.elements_properties
170171
for index, element in enumerate(expr.elements):
171-
element, element_depth = walk_levels(
172+
new_element, element_depth = walk_levels(
172173
element,
173174
start,
174175
stop,
@@ -179,8 +180,11 @@ def walk_levels(
179180
cur_pos + [index + 1],
180181
)
181182
depth = max(element_depth + 1, depth)
182-
elements.append(element)
183-
new_expr = make_expression(head, *elements)
183+
elements.append(new_element)
184+
if new_element is not element:
185+
elem_prop = None
186+
187+
new_expr = make_expression(head, *elements, elements_properties=elem_prop)
184188

185189
if is_in_level(current, depth, start, stop):
186190
if include_pos:

test/timings/test_uniform_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
table_non_uniform_expr = session.evaluate(
1111
"nonuniformTable=Table[If[i==0,1,1./(1.+i^2)],{i, 0,1000}]"
1212
)
13-
# assert table_uniform_expr.elements_properties.is_uniform
14-
# assert not table_non_uniform_expr.elements_properties.is_uniform
13+
assert table_uniform_expr.elements_properties.is_uniform
14+
assert not table_non_uniform_expr.elements_properties.is_uniform
1515

1616

1717
@pytest.mark.skipif(

0 commit comments

Comments
 (0)