Skip to content

Commit 563c1b3

Browse files
authored
feat(python): enable newer passes + mimic options from julia (#1592)
* feat(python): enable newer passes + mimic options from julia
* test: try dumping results (drop me)
* fix: remove debug printing
* feat(python): dump mlir source on failure
* fix: reshape to rank 0 tensor
* fix: move dump to a function
1 parent 4c28933 commit 563c1b3

File tree

3 files changed

+131
-68
lines changed

3 files changed

+131
-68
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,15 +861,17 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
861861
auto operandTy = cast<RankedTensorType>(op->getOperand(0).getType());
862862
auto resultTy = cast<RankedTensorType>(op->getResult(0).getType());
863863

864+
bool needsManualReshape = false;
864865
if (!areValidInsertionDims(resultTy, operandTy,
865866
{ds.inductionVarDimension})) {
867+
needsManualReshape = true;
866868
reshapeShape = llvm::to_vector(resultTy.getShape());
867869
}
868870

869871
for (auto user : op->getUsers()) {
870872
userOpToSlicesMap[user].push_back(
871873
DynamicSliceInfo{ds.sliceOp, ds.inductionVarDimension, true,
872-
reshapeShape, ds.offset});
874+
reshapeShape, ds.offset, needsManualReshape});
873875
}
874876
}
875877
}
@@ -1041,7 +1043,7 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
10411043
rewriter.getDenseI64ArrayAttr(permutation))
10421044
.getResult();
10431045

1044-
if (!sliceInfo.reshapeShape.empty()) {
1046+
if (sliceInfo.needsManualReshape) {
10451047
SmallVector<int64_t> reshapedShape(sliceInfo.reshapeShape.begin(),
10461048
sliceInfo.reshapeShape.end());
10471049
reshapedShape.insert(reshapedShape.begin(),

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ struct GreedyWhileLoopBatchFission
213213
bool intermediateReshape;
214214
llvm::SmallVector<int64_t> reshapeShape;
215215
int64_t offset;
216+
bool needsManualReshape;
216217
};
217218

218219
enum class BatchLiftingMode {

src/enzyme_ad/jax/primitives.py

Lines changed: 126 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from typing import Any
66
import itertools
77
import sys
8+
import os
9+
import tempfile
10+
from absl import logging
811

912
import jax
1013
from jax import lax
@@ -107,6 +110,9 @@ def optimization_passes(
107110
reshape_propagate: str = "up",
108111
max_constant_threshold: int = 1024,
109112
enable_batching_passes: bool = True,
113+
enable_licm_optimization_passes: bool = True,
114+
enable_scatter_gather_optimization_passes: bool = True,
115+
enable_pad_optimization_passes: bool = True,
110116
):
111117
transform_passes_list = [
112118
"compare_op_canon<16>",
@@ -124,7 +130,6 @@ def optimization_passes(
124130
"imag_op_canon<16>",
125131
"conj_complex_negate<16>",
126132
"get_dimension_size_op_canon<16>",
127-
"gather_op_canon<16>",
128133
"reshape_op_canon<16>",
129134
"merge_consecutive_reshapes<16>",
130135
"transpose_is_reshape<16>",
@@ -133,7 +138,6 @@ def optimization_passes(
133138
"cse_slice<16>",
134139
"cse_transpose<16>",
135140
"cse_convert<16>",
136-
"cse_pad<16>",
137141
"cse_dot_general<16>",
138142
"cse_reshape<16>",
139143
"cse_mul<16>",
@@ -168,9 +172,6 @@ def optimization_passes(
168172
"slice_dynamic_slice<16>",
169173
"shift_right_logical_simplify<16>",
170174
f"pad_simplify<16>({max_constant_threshold})",
171-
"select_pad_to_dus<1>",
172-
"and_pad_pad<1>",
173-
"negative_pad_to_slice<16>",
174175
"slice_simplify<16>",
175176
"convert_simplify<16>",
176177
"dynamic_slice_to_static<16>",
@@ -185,14 +186,10 @@ def optimization_passes(
185186
"dynamic_update_to_concat<1>",
186187
"slice_of_dynamic_update<1>",
187188
"slice_elementwise<1>",
188-
"slice_pad<1>",
189189
"dot_reshape_dot<1>",
190190
"concat_fuse<1>",
191-
"pad_reshape_pad<1>",
192-
"pad_pad<1>",
193191
"concat_push_binop_add<1>",
194192
"concat_push_binop_mul<1>",
195-
"scatter_to_dynamic_update_slice<1>",
196193
"reduce_concat<1>",
197194
"slice_concat<1>",
198195
"concat_slice<1>",
@@ -204,48 +201,21 @@ def optimization_passes(
204201
"dot_general_simplify<16>",
205202
"transpose_simplify<16>",
206203
"reshape_empty_broadcast<1>",
207-
"add_pad_pad_to_concat<1>",
208204
"broadcast_reshape<1>",
209-
"concat_pad<1>",
210-
"reduce_pad<1>",
211-
"broadcast_pad<1>",
212-
"zero_product_reshape_pad<1>",
213-
"mul_zero_pad<1>",
214-
"div_zero_pad<1>",
215-
"binop_const_reshape_pad<1>",
216-
"binop_const_pad_add<1>",
217-
"binop_const_pad_subtract<1>",
218-
"binop_const_pad_mul<1>",
219-
"binop_const_pad_div<1>",
220-
"clamp_const_prop<1>",
221-
"binop_binop_pad_pad_add<1>",
222-
"binop_binop_pad_pad_mul<1>",
223-
"binop_pad_pad_add<1>",
224-
"binop_pad_pad_subtract<1>",
225-
"binop_pad_pad_mul<1>",
226-
"binop_pad_pad_div<1>",
227-
"binop_pad_pad_min<1>",
228-
"binop_pad_pad_max<1>",
229-
"unary_pad_push_convert<1>",
230-
"unary_pad_push_tanh<1>",
231-
"unary_pad_push_exp<1>",
232205
"transpose_dot_reorder<1>",
233206
"dot_transpose<1>",
234207
"transpose_convolution<1>",
235208
"convolution_transpose<1>",
236209
"convert_convert_float<1>",
237210
"convert_convert_int<1>",
238-
"concat_to_pad<1>",
239211
"reshape_iota<1>",
240212
"broadcast_reduce<1>",
241213
"slice_dot_general<1>",
242214
"if_inline<1>",
243215
"if_to_select<1>",
244-
"dynamic_gather_op_is_not_dynamic<16>",
245216
"divide_sqrt_to_multiply_rsqrt<16>",
246217
"associative_binary_op_reordering<1>",
247218
"transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
248-
"scatter_indices_are_unique",
249219
"replace_neg_add_with_subtract",
250220
"binop_const_simplify",
251221
"not_select_simplify",
@@ -265,36 +235,23 @@ def optimization_passes(
265235
"slice_reduce_window<1>",
266236
"while_deadresult",
267237
"while_dus",
268-
"dus_licm(0)",
269238
"while_op_induction_replacement",
270-
"dus_pad",
271239
"dus_concat",
272240
"slice_dus_to_concat",
273241
"while_induction_reduction",
274-
"slice_licm(0)",
275-
"dot_general_licm(0)",
276-
"pad_licm(0)",
277-
"elementwise_licm(0)",
278-
"concatenate_licm(0)",
279242
"slice_broadcast",
280-
"while_pad_induction_reduction",
281-
"while_licm<1>(1)",
282243
"associative_common_mul_op_reordering",
283244
"slice_select_to_select_slice",
284-
"pad_concat_to_concat_pad",
285245
"slice_if",
286246
"dus_to_i32",
287-
"rotate_pad",
288247
"slice_extend",
289248
"concat_wrap",
290249
"cse_extend<16>",
291250
"cse_wrap<16>",
292251
"cse_rotate<16>",
293252
"cse_rotate<16>",
294253
"concat_concat_axis_swap",
295-
"concat_multipad",
296254
"concat_concat_to_dus",
297-
"speculate_if_pad_to_select",
298255
"broadcast_iota_simplify",
299256
"select_comp_iota_to_dus",
300257
"compare_cleanup",
@@ -322,8 +279,6 @@ def optimization_passes(
322279
"split_convolution_into_reverse_convolution",
323280
# TODO we want to enable but may cause an infinite compile time
324281
# "concat_to_onedim_dusslice",
325-
"scatter_multiply_simplify",
326-
"unary_elementwise_scatter_simplify",
327282
# "chained_multiply_to_power", # TODO: make it into an optional pass
328283
"power_multiply_to_power",
329284
"common_associative_commutative_op_reorder",
@@ -333,19 +288,28 @@ def optimization_passes(
333288
"reshape_deletions_broadcast_in_dim_simplify",
334289
"reshape_insertions_broadcast_in_dim_simplify",
335290
"dot_general_reshape",
336-
"diagonal_tensor_dot_general_rewrite",
337291
"widen_wrap",
338292
"widen_extend",
339-
"elementwise_pad",
340293
"compare_negate_const_simplify",
341294
"select_simplify",
342-
"concatenate_subtract_to_subtract_pad",
343295
"concatenate_broadcast_in_dim",
296+
"compare_abs",
297+
# "compare_mul",
298+
"compare_convert",
299+
"add_selects",
300+
# TODO: parameterize based on the device
301+
"self_subtract_to_convolution_like(0)",
302+
"self_add_to_convolution_like(0)",
303+
"self_mul_to_convolution_like(0)",
304+
"trivial_reduce_window_to_reduce_op",
344305
"case_to_if",
345-
"dus_to_dynamic_pad",
346-
"dynamic_pad_to_pad",
306+
"dot_general_add_distributive_simplify",
307+
"dot_general_subtract_distributive_simplify",
347308
"remove_no_ops_from_while_loop",
348309
"while_is_copy_simplify",
310+
"split_variadic_scatter_op",
311+
"dynamic_slice_simplify",
312+
"enzyme_hlo_unroll(4)",
349313
"dot_general_only_diagonal_access",
350314
]
351315

@@ -395,15 +359,8 @@ def optimization_passes(
395359
# other constant propagations
396360
"const_prop_through_barrier<16>",
397361
f"concat_const_prop<1>({max_constant_threshold})",
398-
f"scatter_const_fold({max_constant_threshold})",
399362
f"dynamic_update_slice_const_prop({max_constant_threshold})",
400-
"scatter_update_computation_const_prop",
401-
"gather_const_prop",
402-
# TODO: parameterize based on the device
403-
"self_subtract_to_convolution_like(0)",
404-
"self_add_to_convolution_like(0)",
405-
"self_mul_to_convolution_like(0)",
406-
"trivial_reduce_window_to_reduce_op",
363+
"clamp_const_prop",
407364
]
408365

409366
if enable_batching_passes:
@@ -428,6 +385,85 @@ def optimization_passes(
428385
"broadcastindim_slice_to_batch",
429386
"reducewindow_slice_to_batch",
430387
"elementwise_slice_to_batch",
388+
"greedy_while_loop_batch_fission",
389+
]
390+
391+
if enable_licm_optimization_passes:
392+
transform_passes_list += [
393+
"dus_licm(0)",
394+
"slice_licm(0)",
395+
"elementwise_licm(0)",
396+
"concatenate_licm(0)",
397+
"while_licm<1>(1)",
398+
"transpose_licm(0)",
399+
"broadcastindim_licm(0)",
400+
"reshape_licm(0)",
401+
"dot_general_licm(0)",
402+
"reduce_licm(0)",
403+
"reduce_window_licm(0)",
404+
"reverse_licm(0)",
405+
]
406+
407+
if enable_scatter_gather_optimization_passes:
408+
transform_passes_list += [
409+
"scatter_to_dynamic_update_slice<1>",
410+
"scatter_multiply_simplify",
411+
"unary_elementwise_scatter_simplify",
412+
"scatter_indices_are_unique",
413+
"diagonal_tensor_dot_general_rewrite",
414+
## const prop patterns
415+
"scatter_update_computation_const_prop",
416+
# gather patterns
417+
"dynamic_gather_op_is_not_dynamic<16>",
418+
"gather_op_canon<16>",
419+
"gather_elementwise",
420+
## const prop patterns
421+
"gather_const_prop",
422+
f"scatter_const_fold({max_constant_threshold})",
423+
]
424+
425+
if enable_pad_optimization_passes:
426+
transform_passes_list += [
427+
"dus_pad",
428+
"cse_pad<16>",
429+
f"pad_simplify<16>({max_constant_threshold})",
430+
"select_pad_to_dus<1>",
431+
"and_pad_pad<1>",
432+
"negative_pad_to_slice<16>",
433+
"slice_pad<1>",
434+
"pad_reshape_pad<1>",
435+
"pad_pad<1>",
436+
"add_pad_pad_to_concat<1>",
437+
"concat_pad<1>",
438+
"reduce_pad<1>",
439+
"broadcast_pad<1>",
440+
"zero_product_reshape_pad<1>",
441+
"mul_zero_pad<1>",
442+
"div_zero_pad<1>",
443+
"binop_const_reshape_pad<1>",
444+
"binop_const_pad_add<1>",
445+
"binop_const_pad_subtract<1>",
446+
"binop_const_pad_mul<1>",
447+
"binop_const_pad_div<1>",
448+
"binop_binop_pad_pad_add<1>",
449+
"binop_binop_pad_pad_mul<1>",
450+
"binop_pad_pad_add<1>",
451+
"binop_pad_pad_subtract<1>",
452+
"binop_pad_pad_mul<1>",
453+
"binop_pad_pad_div<1>",
454+
"binop_pad_pad_min<1>",
455+
"binop_pad_pad_max<1>",
456+
"unary_pad_push_convert<1>",
457+
"unary_pad_push_tanh<1>",
458+
"unary_pad_push_exp<1>",
459+
"concat_to_pad<1>",
460+
"while_pad_induction_reduction",
461+
"pad_concat_to_concat_pad",
462+
"rotate_pad",
463+
"concat_multipad",
464+
"speculate_if_pad_to_select",
465+
"dus_to_dynamic_pad",
466+
"dynamic_pad_to_pad",
431467
]
432468

433469
if reshape_propagate == "up":
@@ -881,6 +917,22 @@ def ret_activity_from_pipeline(pass_pipeline):
881917
return pre_act, acts, post_act
882918

883919

920+
def _dump_mlir_to_file(source, pass_pipeline):
921+
# bazel will zip up the outputs in this directory
922+
dump_mlir_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None)
923+
if dump_mlir_dir is None:
924+
dump_mlir_dir = tempfile.gettempdir()
925+
926+
tmpfile = tempfile.NamedTemporaryFile(
927+
suffix=".mlir", dir=dump_mlir_dir, delete=False
928+
)
929+
with open(tmpfile.name, "w") as f:
930+
f.write("# " + pass_pipeline + "\n")
931+
f.write(str(source))
932+
933+
return tmpfile.name
934+
935+
884936
def _enzyme_primal_lowering(
885937
ctx: jax_mlir.LoweringRuleContext,
886938
*args_flat: ir.Value,
@@ -978,12 +1030,20 @@ def _enzyme_primal_lowering(
9781030

9791031
if len(pass_pipeline) > 0:
9801032
pass_pipeline = pass_pipeline + ",tensor-empty-raise"
981-
name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
1033+
1034+
try:
1035+
name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
1036+
except Exception as e:
1037+
filename = _dump_mlir_to_file(source, pass_pipeline)
1038+
logging.exception("Enzyme MLIR dumped to %s", filename)
1039+
raise e
1040+
9821041
if print_mlir:
983-
if type(print_mlir) != type(True):
1042+
if not isinstance(print_mlir, bool):
9841043
print_mlir.write(nmod)
9851044
else:
9861045
print(str(nmod), flush=True)
1046+
9871047
nmod = ir.Module.parse(nmod)
9881048
fn = None
9891049
pushtop = []

0 commit comments

Comments (0)