From b7660edf07aa2cb480f46ed261c42ffd07a36e67 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Aug 2025 08:15:01 -0400 Subject: [PATCH 1/3] feat: rotate patterns to reduce window [skip ci] --- src/Compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2dce2dac1b..49cf7f73fa 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -902,6 +902,9 @@ function optimization_passes( # "compare_mul", "compare_convert", "add_selects", + "self_subtract_to_convolution_like", + "self_add_to_convolution_like", + "self_mul_to_convolution_like", ] if !compile_options.disable_scatter_gather_optimization_passes From e5a7a503b15fb5e48e3b5bd2e2c8497242409351 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Aug 2025 18:02:09 -0400 Subject: [PATCH 2/3] chore: add new pad [skip ci] --- src/Compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index 49cf7f73fa..2672091ce5 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -905,6 +905,7 @@ function optimization_passes( "self_subtract_to_convolution_like", "self_add_to_convolution_like", "self_mul_to_convolution_like", + "subtract_multiply_const_to_add_mul_const", ] if !compile_options.disable_scatter_gather_optimization_passes From 955520ca1a07dca57326e8c596a5f22528fb7ab9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Aug 2025 07:43:54 -0400 Subject: [PATCH 3/3] chore: optionally enalbe these passes --- src/Compiler.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2672091ce5..f5fd9b46bf 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -701,6 +701,7 @@ function optimization_passes( recognize_comms::Bool=true, lower_comms::Bool=true, max_constant_threshold::Int=1024, + backend::String="gpu", ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -902,9 +903,9 @@ function optimization_passes( # "compare_mul", "compare_convert", "add_selects", - "self_subtract_to_convolution_like", - "self_add_to_convolution_like", - "self_mul_to_convolution_like", + "self_subtract_to_convolution_like($(backend == "tpu"))", + "self_add_to_convolution_like($(backend == "tpu"))", + "self_mul_to_convolution_like($(backend == "tpu"))", "subtract_multiply_const_to_add_mul_const", ] @@ -1630,10 +1631,10 @@ function compile_mlir!( end opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms + compile_options; sroa=true, recognize_comms, lower_comms, backend ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms + compile_options; sroa=false, recognize_comms, lower_comms, backend ) raise_passes = if raise isa String @@ -1654,6 +1655,7 @@ function compile_mlir!( dus_to_concat=true, recognize_comms, lower_comms, + backend ) result = result * "," * opt_passes3 end @@ -1943,6 +1945,7 @@ function compile_mlir!( Reactant.__compile_options_with_reversed_propagation(compile_options); recognize_comms, lower_comms, + backend ), "post_op_transpose_reshape", )