@@ -701,6 +701,7 @@ function optimization_passes(
701
701
recognize_comms:: Bool = true ,
702
702
lower_comms:: Bool = true ,
703
703
max_constant_threshold:: Int = 1024 ,
704
+ backend:: String = " gpu" ,
704
705
)
705
706
transform_passes_list = [
706
707
" patterns=compare_op_canon<16>" ,
@@ -902,6 +903,10 @@ function optimization_passes(
902
903
# "compare_mul",
903
904
" compare_convert" ,
904
905
" add_selects" ,
906
+ " self_subtract_to_convolution_like($(Int (backend == " tpu" )) )" ,
907
+ " self_add_to_convolution_like($(Int (backend == " tpu" )) )" ,
908
+ " self_mul_to_convolution_like($(Int (backend == " tpu" )) )" ,
909
+ " subtract_multiply_const_to_add_mul_const" ,
905
910
]
906
911
907
912
if ! compile_options. disable_scatter_gather_optimization_passes
@@ -1626,10 +1631,10 @@ function compile_mlir!(
1626
1631
end
1627
1632
1628
1633
opt_passes = optimization_passes (
1629
- compile_options; sroa= true , recognize_comms, lower_comms
1634
+ compile_options; sroa= true , recognize_comms, lower_comms, backend
1630
1635
)
1631
1636
opt_passes2 = optimization_passes (
1632
- compile_options; sroa= false , recognize_comms, lower_comms
1637
+ compile_options; sroa= false , recognize_comms, lower_comms, backend
1633
1638
)
1634
1639
1635
1640
raise_passes = if raise isa String
@@ -1650,6 +1655,7 @@ function compile_mlir!(
1650
1655
dus_to_concat= true ,
1651
1656
recognize_comms,
1652
1657
lower_comms,
1658
+ backend,
1653
1659
)
1654
1660
result = result * " ," * opt_passes3
1655
1661
end
@@ -1939,6 +1945,7 @@ function compile_mlir!(
1939
1945
Reactant. __compile_options_with_reversed_propagation (compile_options);
1940
1946
recognize_comms,
1941
1947
lower_comms,
1948
+ backend,
1942
1949
),
1943
1950
" post_op_transpose_reshape" ,
1944
1951
)
0 commit comments