@@ -701,6 +701,7 @@ function optimization_passes(
701
701
recognize_comms:: Bool = true ,
702
702
lower_comms:: Bool = true ,
703
703
max_constant_threshold:: Int = 1024 ,
704
+ backend:: String = " gpu" ,
704
705
)
705
706
transform_passes_list = [
706
707
" patterns=compare_op_canon<16>" ,
@@ -902,9 +903,9 @@ function optimization_passes(
902
903
# "compare_mul",
903
904
" compare_convert" ,
904
905
" add_selects" ,
905
- " self_subtract_to_convolution_like" ,
906
- " self_add_to_convolution_like" ,
907
- " self_mul_to_convolution_like" ,
906
+ " self_subtract_to_convolution_like( $(backend == " tpu " ) ) " ,
907
+ " self_add_to_convolution_like( $(backend == " tpu " ) ) " ,
908
+ " self_mul_to_convolution_like( $(backend == " tpu " ) ) " ,
908
909
" subtract_multiply_const_to_add_mul_const" ,
909
910
]
910
911
@@ -1630,10 +1631,10 @@ function compile_mlir!(
1630
1631
end
1631
1632
1632
1633
opt_passes = optimization_passes (
1633
- compile_options; sroa= true , recognize_comms, lower_comms
1634
+ compile_options; sroa= true , recognize_comms, lower_comms, backend
1634
1635
)
1635
1636
opt_passes2 = optimization_passes (
1636
- compile_options; sroa= false , recognize_comms, lower_comms
1637
+ compile_options; sroa= false , recognize_comms, lower_comms, backend
1637
1638
)
1638
1639
1639
1640
raise_passes = if raise isa String
@@ -1654,6 +1655,7 @@ function compile_mlir!(
1654
1655
dus_to_concat= true ,
1655
1656
recognize_comms,
1656
1657
lower_comms,
1658
+ backend
1657
1659
)
1658
1660
result = result * " ," * opt_passes3
1659
1661
end
@@ -1943,6 +1945,7 @@ function compile_mlir!(
1943
1945
Reactant. __compile_options_with_reversed_propagation (compile_options);
1944
1946
recognize_comms,
1945
1947
lower_comms,
1948
+ backend
1946
1949
),
1947
1950
" post_op_transpose_reshape" ,
1948
1951
)
0 commit comments