@@ -5,6 +5,9 @@
 from typing import Any
 import itertools
 import sys
+import os
+import tempfile
+from absl import logging
 
 import jax
 from jax import lax
@@ -107,6 +110,9 @@ def optimization_passes(
     reshape_propagate: str = "up",
     max_constant_threshold: int = 1024,
     enable_batching_passes: bool = True,
+    enable_licm_optimization_passes: bool = True,
+    enable_scatter_gather_optimization_passes: bool = True,
+    enable_pad_optimization_passes: bool = True,
 ):
     transform_passes_list = [
         "compare_op_canon<16>",
@@ -124,7 +130,6 @@ def optimization_passes(
         "imag_op_canon<16>",
         "conj_complex_negate<16>",
         "get_dimension_size_op_canon<16>",
-        "gather_op_canon<16>",
         "reshape_op_canon<16>",
         "merge_consecutive_reshapes<16>",
         "transpose_is_reshape<16>",
@@ -133,7 +138,6 @@ def optimization_passes(
         "cse_slice<16>",
         "cse_transpose<16>",
         "cse_convert<16>",
-        "cse_pad<16>",
         "cse_dot_general<16>",
         "cse_reshape<16>",
         "cse_mul<16>",
@@ -168,9 +172,6 @@ def optimization_passes(
         "slice_dynamic_slice<16>",
         "shift_right_logical_simplify<16>",
         f"pad_simplify<16>({max_constant_threshold})",
-        "select_pad_to_dus<1>",
-        "and_pad_pad<1>",
-        "negative_pad_to_slice<16>",
         "slice_simplify<16>",
         "convert_simplify<16>",
         "dynamic_slice_to_static<16>",
@@ -185,14 +186,10 @@ def optimization_passes(
         "dynamic_update_to_concat<1>",
         "slice_of_dynamic_update<1>",
         "slice_elementwise<1>",
-        "slice_pad<1>",
         "dot_reshape_dot<1>",
         "concat_fuse<1>",
-        "pad_reshape_pad<1>",
-        "pad_pad<1>",
         "concat_push_binop_add<1>",
         "concat_push_binop_mul<1>",
-        "scatter_to_dynamic_update_slice<1>",
         "reduce_concat<1>",
         "slice_concat<1>",
         "concat_slice<1>",
@@ -204,48 +201,21 @@ def optimization_passes(
         "dot_general_simplify<16>",
         "transpose_simplify<16>",
         "reshape_empty_broadcast<1>",
-        "add_pad_pad_to_concat<1>",
         "broadcast_reshape<1>",
-        "concat_pad<1>",
-        "reduce_pad<1>",
-        "broadcast_pad<1>",
-        "zero_product_reshape_pad<1>",
-        "mul_zero_pad<1>",
-        "div_zero_pad<1>",
-        "binop_const_reshape_pad<1>",
-        "binop_const_pad_add<1>",
-        "binop_const_pad_subtract<1>",
-        "binop_const_pad_mul<1>",
-        "binop_const_pad_div<1>",
-        "clamp_const_prop<1>",
-        "binop_binop_pad_pad_add<1>",
-        "binop_binop_pad_pad_mul<1>",
-        "binop_pad_pad_add<1>",
-        "binop_pad_pad_subtract<1>",
-        "binop_pad_pad_mul<1>",
-        "binop_pad_pad_div<1>",
-        "binop_pad_pad_min<1>",
-        "binop_pad_pad_max<1>",
-        "unary_pad_push_convert<1>",
-        "unary_pad_push_tanh<1>",
-        "unary_pad_push_exp<1>",
         "transpose_dot_reorder<1>",
         "dot_transpose<1>",
         "transpose_convolution<1>",
         "convolution_transpose<1>",
         "convert_convert_float<1>",
         "convert_convert_int<1>",
-        "concat_to_pad<1>",
         "reshape_iota<1>",
         "broadcast_reduce<1>",
         "slice_dot_general<1>",
         "if_inline<1>",
         "if_to_select<1>",
-        "dynamic_gather_op_is_not_dynamic<16>",
         "divide_sqrt_to_multiply_rsqrt<16>",
         "associative_binary_op_reordering<1>",
         "transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
-        "scatter_indices_are_unique",
         "replace_neg_add_with_subtract",
         "binop_const_simplify",
         "not_select_simplify",
@@ -265,36 +235,23 @@ def optimization_passes(
         "slice_reduce_window<1>",
         "while_deadresult",
         "while_dus",
-        "dus_licm(0)",
         "while_op_induction_replacement",
-        "dus_pad",
         "dus_concat",
         "slice_dus_to_concat",
         "while_induction_reduction",
-        "slice_licm(0)",
-        "dot_general_licm(0)",
-        "pad_licm(0)",
-        "elementwise_licm(0)",
-        "concatenate_licm(0)",
         "slice_broadcast",
-        "while_pad_induction_reduction",
-        "while_licm<1>(1)",
         "associative_common_mul_op_reordering",
         "slice_select_to_select_slice",
-        "pad_concat_to_concat_pad",
         "slice_if",
         "dus_to_i32",
-        "rotate_pad",
         "slice_extend",
         "concat_wrap",
         "cse_extend<16>",
         "cse_wrap<16>",
         "cse_rotate<16>",
         "cse_rotate<16>",
         "concat_concat_axis_swap",
-        "concat_multipad",
         "concat_concat_to_dus",
-        "speculate_if_pad_to_select",
         "broadcast_iota_simplify",
         "select_comp_iota_to_dus",
         "compare_cleanup",
@@ -322,8 +279,6 @@ def optimization_passes(
         "split_convolution_into_reverse_convolution",
         # TODO we want to enable but may cause an infinite compile time
         # "concat_to_onedim_dusslice",
-        "scatter_multiply_simplify",
-        "unary_elementwise_scatter_simplify",
         # "chained_multiply_to_power", # TODO: make it into an optional pass
         "power_multiply_to_power",
         "common_associative_commutative_op_reorder",
@@ -333,19 +288,28 @@ def optimization_passes(
         "reshape_deletions_broadcast_in_dim_simplify",
         "reshape_insertions_broadcast_in_dim_simplify",
         "dot_general_reshape",
-        "diagonal_tensor_dot_general_rewrite",
         "widen_wrap",
         "widen_extend",
-        "elementwise_pad",
         "compare_negate_const_simplify",
         "select_simplify",
-        "concatenate_subtract_to_subtract_pad",
         "concatenate_broadcast_in_dim",
+        "compare_abs",
+        # "compare_mul",
+        "compare_convert",
+        "add_selects",
+        # TODO: parameterize based on the device
+        "self_subtract_to_convolution_like(0)",
+        "self_add_to_convolution_like(0)",
+        "self_mul_to_convolution_like(0)",
+        "trivial_reduce_window_to_reduce_op",
         "case_to_if",
-        "dus_to_dynamic_pad",
-        "dynamic_pad_to_pad",
+        "dot_general_add_distributive_simplify",
+        "dot_general_subtract_distributive_simplify",
         "remove_no_ops_from_while_loop",
         "while_is_copy_simplify",
+        "split_variadic_scatter_op",
+        "dynamic_slice_simplify",
+        "enzyme_hlo_unroll(4)",
         "dot_general_only_diagonal_access",
     ]
 
@@ -395,15 +359,8 @@ def optimization_passes(
         # other constant propagations
         "const_prop_through_barrier<16>",
         f"concat_const_prop<1>({max_constant_threshold})",
-        f"scatter_const_fold({max_constant_threshold})",
         f"dynamic_update_slice_const_prop({max_constant_threshold})",
-        "scatter_update_computation_const_prop",
-        "gather_const_prop",
-        # TODO: parameterize based on the device
-        "self_subtract_to_convolution_like(0)",
-        "self_add_to_convolution_like(0)",
-        "self_mul_to_convolution_like(0)",
-        "trivial_reduce_window_to_reduce_op",
+        "clamp_const_prop",
     ]
 
     if enable_batching_passes:
@@ -428,6 +385,85 @@ def optimization_passes(
             "broadcastindim_slice_to_batch",
             "reducewindow_slice_to_batch",
             "elementwise_slice_to_batch",
+            "greedy_while_loop_batch_fission",
+        ]
+
+    if enable_licm_optimization_passes:
+        transform_passes_list += [
+            "dus_licm(0)",
+            "slice_licm(0)",
+            "elementwise_licm(0)",
+            "concatenate_licm(0)",
+            "while_licm<1>(1)",
+            "transpose_licm(0)",
+            "broadcastindim_licm(0)",
+            "reshape_licm(0)",
+            "dot_general_licm(0)",
+            "reduce_licm(0)",
+            "reduce_window_licm(0)",
+            "reverse_licm(0)",
+        ]
+
+    if enable_scatter_gather_optimization_passes:
+        transform_passes_list += [
+            "scatter_to_dynamic_update_slice<1>",
+            "scatter_multiply_simplify",
+            "unary_elementwise_scatter_simplify",
+            "scatter_indices_are_unique",
+            "diagonal_tensor_dot_general_rewrite",
+            ## const prop patterns
+            "scatter_update_computation_const_prop",
+            # gather patterns
+            "dynamic_gather_op_is_not_dynamic<16>",
+            "gather_op_canon<16>",
+            "gather_elementwise",
+            ## const prop patterns
+            "gather_const_prop",
+            f"scatter_const_fold({max_constant_threshold})",
+        ]
+
+    if enable_pad_optimization_passes:
+        transform_passes_list += [
+            "dus_pad",
+            "cse_pad<16>",
+            f"pad_simplify<16>({max_constant_threshold})",
+            "select_pad_to_dus<1>",
+            "and_pad_pad<1>",
+            "negative_pad_to_slice<16>",
+            "slice_pad<1>",
+            "pad_reshape_pad<1>",
+            "pad_pad<1>",
+            "add_pad_pad_to_concat<1>",
+            "concat_pad<1>",
+            "reduce_pad<1>",
+            "broadcast_pad<1>",
+            "zero_product_reshape_pad<1>",
+            "mul_zero_pad<1>",
+            "div_zero_pad<1>",
+            "binop_const_reshape_pad<1>",
+            "binop_const_pad_add<1>",
+            "binop_const_pad_subtract<1>",
+            "binop_const_pad_mul<1>",
+            "binop_const_pad_div<1>",
+            "binop_binop_pad_pad_add<1>",
+            "binop_binop_pad_pad_mul<1>",
+            "binop_pad_pad_add<1>",
+            "binop_pad_pad_subtract<1>",
+            "binop_pad_pad_mul<1>",
+            "binop_pad_pad_div<1>",
+            "binop_pad_pad_min<1>",
+            "binop_pad_pad_max<1>",
+            "unary_pad_push_convert<1>",
+            "unary_pad_push_tanh<1>",
+            "unary_pad_push_exp<1>",
+            "concat_to_pad<1>",
+            "while_pad_induction_reduction",
+            "pad_concat_to_concat_pad",
+            "rotate_pad",
+            "concat_multipad",
+            "speculate_if_pad_to_select",
+            "dus_to_dynamic_pad",
+            "dynamic_pad_to_pad",
         ]
 
     if reshape_propagate == "up":
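
For context, a minimal sketch of how the three new toggles introduced above might be called. It assumes optimization_passes is importable from enzyme_ad.jax.primitives, that all of its other parameters have defaults, and that it returns a pass-pipeline string; none of these assumptions are confirmed by the diff itself.

# Sketch only: the import path and return type are assumptions.
from enzyme_ad.jax.primitives import optimization_passes

# Disable the pad- and LICM-related rewrites while keeping the scatter/gather
# and batching passes, e.g. to bisect a miscompile or a compile-time blowup.
pipeline = optimization_passes(
    max_constant_threshold=1024,
    enable_batching_passes=True,
    enable_licm_optimization_passes=False,
    enable_scatter_gather_optimization_passes=True,
    enable_pad_optimization_passes=False,
)
print(pipeline)
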
@@ -881,6 +917,22 @@ def ret_activity_from_pipeline(pass_pipeline):
     return pre_act, acts, post_act
 
 
+def _dump_mlir_to_file(source, pass_pipeline):
+    # bazel will zip up the outputs in this directory
+    dump_mlir_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None)
+    if dump_mlir_dir is None:
+        dump_mlir_dir = tempfile.gettempdir()
+
+    tmpfile = tempfile.NamedTemporaryFile(
+        suffix=".mlir", dir=dump_mlir_dir, delete=False
+    )
+    with open(tmpfile.name, "w") as f:
+        f.write("# " + pass_pipeline + "\n")
+        f.write(str(source))
+
+    return tmpfile.name
+
+
 def _enzyme_primal_lowering(
     ctx: jax_mlir.LoweringRuleContext,
     *args_flat: ir.Value,
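
A minimal usage sketch of the _dump_mlir_to_file helper added above, e.g. from an interactive debugging session; the module text and pipeline string are placeholder inputs, not values taken from this diff.

# Placeholder inputs; the helper writes the pipeline on a "# " header line
# followed by the module text, and returns the path of the new .mlir file.
path = _dump_mlir_to_file("module {}", "canonicalize,cse")
print("dumped to", path)
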
@@ -978,12 +1030,20 @@ def _enzyme_primal_lowering(
 
         if len(pass_pipeline) > 0:
             pass_pipeline = pass_pipeline + ",tensor-empty-raise"
-            name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
+
+            try:
+                name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
+            except Exception as e:
+                filename = _dump_mlir_to_file(source, pass_pipeline)
+                logging.exception("Enzyme MLIR dumped to %s", filename)
+                raise e
+
             if print_mlir:
-                if type(print_mlir) != type(True):
+                if not isinstance(print_mlir, bool):
                     print_mlir.write(nmod)
                 else:
                     print(str(nmod), flush=True)
+
             nmod = ir.Module.parse(nmod)
             fn = None
             pushtop = []
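
Given the dump format used above (a first line of the form "# <pass_pipeline>" followed by the module text), a hedged sketch of reading a dump back when reproducing a pipeline failure; the file path below is a placeholder.

# Placeholder path; split the dump back into the pipeline and the module text.
with open("/tmp/tmpabc123.mlir") as f:
    header = f.readline()  # "# <pass_pipeline>\n"
    pipeline = header.removeprefix("# ").rstrip("\n")
    module_text = f.read()  # the MLIR module that failed
print("failing pipeline:", pipeline)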