diff --git a/src/Compiler.jl b/src/Compiler.jl index 2dce2dac1b..f5fd9b46bf 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -701,6 +701,7 @@ function optimization_passes( recognize_comms::Bool=true, lower_comms::Bool=true, max_constant_threshold::Int=1024, + backend::String="gpu", ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -902,6 +903,10 @@ function optimization_passes( # "compare_mul", "compare_convert", "add_selects", + "self_subtract_to_convolution_like($(backend == "tpu"))", + "self_add_to_convolution_like($(backend == "tpu"))", + "self_mul_to_convolution_like($(backend == "tpu"))", + "subtract_multiply_const_to_add_mul_const", ] if !compile_options.disable_scatter_gather_optimization_passes @@ -1626,10 +1631,10 @@ function compile_mlir!( end opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms + compile_options; sroa=true, recognize_comms, lower_comms, backend ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms + compile_options; sroa=false, recognize_comms, lower_comms, backend ) raise_passes = if raise isa String @@ -1650,6 +1655,7 @@ function compile_mlir!( dus_to_concat=true, recognize_comms, lower_comms, + backend ) result = result * "," * opt_passes3 end @@ -1939,6 +1945,7 @@ function compile_mlir!( Reactant.__compile_options_with_reversed_propagation(compile_options); recognize_comms, lower_comms, + backend ), "post_op_transpose_reshape", )