Skip to content

Commit 45eeb8d

Browse files
authored
Optimize eval-setitem expressions as single eval expressions (#2695)
1 parent aaa23e8 commit 45eeb8d

File tree

3 files changed: 162 additions (+162) and 67 deletions (-67).

mars/dataframe/indexing/setitem.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def tile(cls, op: "DataFrameSetitem"):
166166
]
167167
value_chunk_index_values = [v.index_value for v in value.chunks]
168168
is_identical = len(target_chunk_index_values) == len(
169-
target_chunk_index_values
169+
value_chunk_index_values
170170
) and all(
171171
c.key == v.key
172172
for c, v in zip(target_chunk_index_values, value_chunk_index_values)

mars/optimization/logical/tileable/arithmetic_query.py

Lines changed: 120 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -13,22 +13,30 @@
1313
# limitations under the License.
1414

1515
import weakref
16-
from typing import Optional, Tuple
16+
from typing import NamedTuple, Optional
1717

18+
import numpy as np
1819
from pandas.api.types import is_scalar
1920

2021
from .... import dataframe as md
2122
from ....core import Tileable, get_output_types, ENTITY_TYPE
2223
from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
2324
from ....dataframe.base.eval import DataFrameEval
2425
from ....dataframe.indexing.getitem import DataFrameIndex
26+
from ....dataframe.indexing.setitem import DataFrameSetitem
2527
from ....typing import OperandType
2628
from ....utils import implements
2729
from ..core import OptimizationRecord, OptimizationRecordType
2830
from ..tileable.core import register_tileable_optimization_rule
2931
from .core import OptimizationRule
3032

3133

34+
class EvalExtractRecord(NamedTuple):
35+
tileable: Optional[Tileable] = None
36+
expr: Optional[str] = None
37+
variables: Optional[dict] = None
38+
39+
3240
def _get_binop_builder(op_str: str):
3341
def builder(lhs: str, rhs: str):
3442
return f"({lhs}) {op_str} ({rhs})"
@@ -60,9 +68,16 @@ def builder(lhs: str, rhs: str):
6068

6169
@register_tileable_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
6270
class SeriesArithmeticToEval(OptimizationRule):
71+
_var_counter = 0
72+
73+
@classmethod
74+
def _next_var_id(cls):
75+
cls._var_counter += 1
76+
return cls._var_counter
77+
6378
@implements(OptimizationRule.match)
6479
def match(self, op: OperandType) -> bool:
65-
_, expr = self._extract_eval_expression(op.outputs[0])
80+
_, expr, _ = self._extract_eval_expression(op.outputs[0])
6681
return expr is not None
6782

6883
@staticmethod
@@ -91,14 +106,17 @@ def _is_select_dataframe_column(tileable) -> bool:
91106
)
92107

93108
@classmethod
94-
def _extract_eval_expression(
95-
cls, tileable
96-
) -> Tuple[Optional[Tileable], Optional[str]]:
109+
def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
97110
if is_scalar(tileable):
98-
return None, repr(tileable)
111+
if isinstance(tileable, (int, bool, str, bytes, np.integer, np.bool_)):
112+
return EvalExtractRecord(expr=repr(tileable))
113+
else:
114+
var_name = f"__eval_scalar_var{cls._next_var_id()}"
115+
var_dict = {var_name: tileable}
116+
return EvalExtractRecord(expr=f"@{var_name}", variables=var_dict)
99117

100118
if not isinstance(tileable, ENTITY_TYPE): # pragma: no cover
101-
return None, None
119+
return EvalExtractRecord()
102120

103121
if tileable in _extract_result_cache:
104122
return _extract_result_cache[tileable]
@@ -109,69 +127,75 @@ def _extract_eval_expression(
109127
result = cls._extract_unary(tileable)
110128
elif isinstance(tileable.op, DataFrameBinopUfunc):
111129
if tileable.op.fill_value is not None or tileable.op.level is not None:
112-
result = None, None
130+
result = EvalExtractRecord()
113131
else:
114132
result = cls._extract_binary(tileable)
115133
else:
116-
result = None, None
134+
result = EvalExtractRecord()
117135

118136
_extract_result_cache[tileable] = result
119137
return result
120138

121139
@classmethod
122-
def _extract_column_select(
123-
cls, tileable
124-
) -> Tuple[Optional[Tileable], Optional[str]]:
125-
return tileable.inputs[0], f"`{tileable.op.col_names}`"
140+
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
141+
return EvalExtractRecord(tileable.inputs[0], f"`{tileable.op.col_names}`")
126142

127143
@classmethod
128-
def _extract_unary(cls, tileable) -> Tuple[Optional[Tileable], Optional[str]]:
144+
def _extract_unary(cls, tileable) -> EvalExtractRecord:
129145
op = tileable.op
130146
func_name = getattr(op, "_func_name") or getattr(op, "_bin_func_name")
131147
if func_name not in _func_name_to_builder: # pragma: no cover
132-
return None, None
148+
return EvalExtractRecord()
133149

134-
in_tileable, expr = cls._extract_eval_expression(op.inputs[0])
150+
in_tileable, expr, variables = cls._extract_eval_expression(op.inputs[0])
135151
if in_tileable is None:
136-
return None, None
152+
return EvalExtractRecord()
137153

138154
cls._add_collapsable_predecessor(tileable, op.inputs[0])
139-
return in_tileable, _func_name_to_builder[func_name](expr)
155+
return EvalExtractRecord(
156+
in_tileable, _func_name_to_builder[func_name](expr), variables
157+
)
140158

141159
@classmethod
142-
def _extract_binary(cls, tileable) -> Tuple[Optional[Tileable], Optional[str]]:
160+
def _extract_binary(cls, tileable) -> EvalExtractRecord:
143161
op = tileable.op
144162
func_name = getattr(op, "_func_name", None) or getattr(op, "_bit_func_name")
145163
if func_name not in _func_name_to_builder: # pragma: no cover
146-
return None, None
164+
return EvalExtractRecord()
147165

148-
lhs_tileable, lhs_expr = cls._extract_eval_expression(op.lhs)
166+
lhs_tileable, lhs_expr, lhs_vars = cls._extract_eval_expression(op.lhs)
149167
if lhs_tileable is not None:
150168
cls._add_collapsable_predecessor(tileable, op.lhs)
151-
rhs_tileable, rhs_expr = cls._extract_eval_expression(op.rhs)
169+
rhs_tileable, rhs_expr, rhs_vars = cls._extract_eval_expression(op.rhs)
152170
if rhs_tileable is not None:
153171
cls._add_collapsable_predecessor(tileable, op.rhs)
154172

155173
if lhs_expr is None or rhs_expr is None:
156-
return None, None
174+
return EvalExtractRecord()
157175
if (
158176
lhs_tileable is not None
159177
and rhs_tileable is not None
160178
and lhs_tileable.key != rhs_tileable.key
161179
):
162-
return None, None
180+
return EvalExtractRecord()
181+
182+
variables = (lhs_vars or dict()).copy()
183+
variables.update(rhs_vars or dict())
163184
in_tileable = next(t for t in [lhs_tileable, rhs_tileable] if t is not None)
164-
return in_tileable, _func_name_to_builder[func_name](lhs_expr, rhs_expr)
185+
return EvalExtractRecord(
186+
in_tileable, _func_name_to_builder[func_name](lhs_expr, rhs_expr), variables
187+
)
165188

166189
@implements(OptimizationRule.apply)
167190
def apply(self, op: OperandType):
168191
node = op.outputs[0]
169-
in_tileable, expr = self._extract_eval_expression(node)
192+
in_tileable, expr, variables = self._extract_eval_expression(node)
170193

171194
new_op = DataFrameEval(
172195
_key=node.op.key,
173196
_output_types=get_output_types(node),
174197
expr=expr,
198+
variables=variables or dict(),
175199
parser="pandas",
176200
is_query=False,
177201
)
@@ -195,39 +219,41 @@ def apply(self, op: OperandType):
195219
pass
196220

197221

198-
@register_tileable_optimization_rule([DataFrameIndex])
199-
class DataFrameBoolEvalToQuery(OptimizationRule):
200-
def match(self, op: "DataFrameIndex") -> bool:
222+
class _DataFrameEvalRewriteRule(OptimizationRule):
223+
def match(self, op: OperandType) -> bool:
224+
optimized_eval_op = self._get_optimized_eval_op(op)
201225
if (
202-
op.col_names is not None
203-
or not isinstance(op.mask, md.Series)
204-
or op.mask.dtype != bool
226+
not isinstance(optimized_eval_op, DataFrameEval)
227+
or optimized_eval_op.is_query
228+
or optimized_eval_op.inputs[0].key != op.inputs[0].key
205229
):
206230
return False
207-
optimized = self._records.get_optimization_result(op.mask)
208-
mask_op = optimized.op if optimized is not None else op.mask.op
209-
if not isinstance(mask_op, DataFrameEval) or mask_op.is_query:
210-
return False
211231
return True
212232

213-
def apply(self, op: "DataFrameIndex"):
233+
def _build_new_eval_op(self, op: OperandType):
234+
raise NotImplementedError
235+
236+
def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
237+
in_columnar_node = self._get_input_columnar_node(op)
238+
optimized = self._records.get_optimization_result(in_columnar_node)
239+
return optimized.op if optimized is not None else in_columnar_node.op
240+
241+
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
242+
raise NotImplementedError
243+
244+
def apply(self, op: DataFrameIndex):
214245
node = op.outputs[0]
215246
in_tileable = op.inputs[0]
247+
in_columnar_node = self._get_input_columnar_node(op)
248+
249+
new_op = self._build_new_eval_op(op)
250+
new_op._key = node.op.key
216251

217-
optimized = self._records.get_optimization_result(op.mask)
218-
mask_op = optimized.op if optimized is not None else op.mask.op
219-
new_op = DataFrameEval(
220-
_key=node.op.key,
221-
_output_types=get_output_types(node),
222-
expr=mask_op.expr,
223-
parser="pandas",
224-
is_query=True,
225-
)
226252
new_node = new_op.new_tileable([in_tileable], **node.params).data
227253
new_node._key = node.key
228254
new_node._id = node.id
229255

230-
self._add_collapsable_predecessor(node, op.mask)
256+
self._add_collapsable_predecessor(node, in_columnar_node)
231257
self._remove_collapsable_predecessors(node)
232258

233259
self._replace_node(node, new_node)
@@ -242,3 +268,50 @@ def apply(self, op: "DataFrameIndex"):
242268
self._graph.results[i] = new_node
243269
except ValueError:
244270
pass
271+
272+
273+
@register_tileable_optimization_rule([DataFrameIndex])
274+
class DataFrameBoolEvalToQuery(_DataFrameEvalRewriteRule):
275+
def match(self, op: DataFrameIndex) -> bool:
276+
if (
277+
op.col_names is not None
278+
or not isinstance(op.mask, md.Series)
279+
or op.mask.dtype != bool
280+
):
281+
return False
282+
return super().match(op)
283+
284+
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
285+
return op.mask
286+
287+
def _build_new_eval_op(self, op: OperandType):
288+
in_eval_op = self._get_optimized_eval_op(op)
289+
return DataFrameEval(
290+
_output_types=get_output_types(op.outputs[0]),
291+
expr=in_eval_op.expr,
292+
variables=in_eval_op.variables,
293+
parser="pandas",
294+
is_query=True,
295+
)
296+
297+
298+
@register_tileable_optimization_rule([DataFrameSetitem])
299+
class DataFrameEvalSetItemToEval(_DataFrameEvalRewriteRule):
300+
def match(self, op: DataFrameSetitem):
301+
if not isinstance(op.indexes, str) or not isinstance(op.value, md.Series):
302+
return False
303+
return super().match(op)
304+
305+
def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
306+
return op.value
307+
308+
def _build_new_eval_op(self, op: DataFrameSetitem):
309+
in_eval_op = self._get_optimized_eval_op(op)
310+
return DataFrameEval(
311+
_output_types=get_output_types(op.outputs[0]),
312+
expr=f"`{op.indexes}` = {in_eval_op.expr}",
313+
variables=in_eval_op.variables,
314+
parser="pandas",
315+
is_query=False,
316+
self_target=True,
317+
)

mars/optimization/logical/tileable/tests/test_arithmetic_query.py

Lines changed: 41 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import re
16+
1517
import numpy as np
1618
import pandas as pd
1719

@@ -22,6 +24,13 @@
2224
from .. import optimize
2325

2426

27+
_var_pattern = re.compile(r"@__eval_scalar_var\d+")
28+
29+
30+
def _norm_vars(var_str):
31+
return _var_pattern.sub("@scalar", var_str)
32+
33+
2534
@enter_mode(build=True)
2635
def test_arithmetic_query(setup):
2736
raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
@@ -62,7 +71,6 @@ def test_arithmetic_query(setup):
6271

6372
pd.testing.assert_series_equal(df2.execute().fetch(), -raw["A"] + raw["B"] * 5)
6473

65-
raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
6674
df1 = md.DataFrame(raw, chunk_size=10)
6775
df2 = -df1["A"] + df1["B"] * 5 + 3 * df1["C"]
6876
graph = TileableGraph([df1["A"].data, df2.data])
@@ -74,22 +82,6 @@ def test_arithmetic_query(setup):
7482
r_df2, _r_col_a = fetch(execute(df2, df1["A"]))
7583
pd.testing.assert_series_equal(r_df2, -raw["A"] + raw["B"] * 5 + 3 * raw["C"])
7684

77-
df1 = md.DataFrame(raw, chunk_size=10)
78-
df2 = md.DataFrame(raw2, chunk_size=10)
79-
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
80-
df3["K"] = df4 = df3["A"] * (1 - df3["B"])
81-
graph = TileableGraph([df3.data])
82-
next(TileableGraphBuilder(graph).build())
83-
records = optimize(graph)
84-
opt_df4 = records.get_optimization_result(df4.data)
85-
assert opt_df4.op.expr == "(`A`) * ((1) - (`B`))"
86-
assert len(graph) == 5
87-
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1
88-
89-
r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
90-
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
91-
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)
92-
9385

9486
@enter_mode(build=True)
9587
def test_bool_eval_to_query(setup):
@@ -111,7 +103,7 @@ def test_bool_eval_to_query(setup):
111103
opt_df2 = records.get_optimization_result(df2.data)
112104
assert isinstance(opt_df2.op, DataFrameEval)
113105
assert opt_df2.op.is_query
114-
assert opt_df2.op.expr == "((`A`) > (0.5)) & ((`C`) < (0.5))"
106+
assert _norm_vars(opt_df2.op.expr) == "((`A`) > (@scalar)) & ((`C`) < (@scalar))"
115107

116108
pd.testing.assert_frame_equal(
117109
df2.execute().fetch(), raw[(raw["A"] > 0.5) & (raw["C"] < 0.5)]
@@ -138,7 +130,37 @@ def test_bool_eval_to_query(setup):
138130
next(TileableGraphBuilder(graph).build())
139131
records = optimize(graph)
140132
opt_df2 = records.get_optimization_result(df2.data)
141-
assert opt_df2.op.expr == "(`b`) < (Timestamp('2022-03-20 00:00:00'))"
133+
assert _norm_vars(opt_df2.op.expr) == "(`b`) < (@scalar)"
142134

143135
r_df2 = fetch(execute(df2))
144136
pd.testing.assert_frame_equal(r_df2, raw[raw.b < pd.Timestamp("2022-3-20")])
137+
138+
139+
@enter_mode(build=True)
140+
def test_eval_setitem_to_eval(setup):
141+
raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
142+
raw2 = pd.DataFrame(np.random.rand(100, 5), columns=list("ABCDE"))
143+
144+
# does not support non-eval value setting
145+
df1 = md.DataFrame(raw, chunk_size=10)
146+
df1["K"] = 345
147+
graph = TileableGraph([df1.data])
148+
next(TileableGraphBuilder(graph).build())
149+
records = optimize(graph)
150+
assert records.get_optimization_result(df1.data) is None
151+
152+
df1 = md.DataFrame(raw, chunk_size=10)
153+
df2 = md.DataFrame(raw2, chunk_size=10)
154+
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
155+
df3["K"] = df3["A"] * (1 - df3["B"])
156+
graph = TileableGraph([df3.data])
157+
next(TileableGraphBuilder(graph).build())
158+
records = optimize(graph)
159+
opt_df3 = records.get_optimization_result(df3.data)
160+
assert opt_df3.op.expr == "`K` = (`A`) * ((1) - (`B`))"
161+
assert len(graph) == 4
162+
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1
163+
164+
r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
165+
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
166+
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)

0 commit comments

Comments
 (0)