Skip to content

Commit e7b0a5a

Browse files
fix: setup for new jll (#1538)
* feat: rotate patterns to reduce window [skip ci] * chore: add new pad [skip ci] * chore: optionally enalbe these passes * Update src/Compiler.jl * Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * chore: update jll version --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9795fa6 commit e7b0a5a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ PythonCall = "0.9.25"
100100
Random = "1.10"
101101
Random123 = "1.7"
102102
ReactantCore = "0.1.15"
103-
Reactant_jll = "0.0.233"
103+
Reactant_jll = "0.0.235"
104104
ScopedValues = "1.3.0"
105105
Scratch = "1.2"
106106
Sockets = "1.10"

src/Compiler.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ function optimization_passes(
701701
recognize_comms::Bool=true,
702702
lower_comms::Bool=true,
703703
max_constant_threshold::Int=1024,
704+
backend::String="gpu",
704705
)
705706
transform_passes_list = [
706707
"patterns=compare_op_canon<16>",
@@ -902,6 +903,10 @@ function optimization_passes(
902903
# "compare_mul",
903904
"compare_convert",
904905
"add_selects",
906+
"self_subtract_to_convolution_like($(Int(backend == "tpu")))",
907+
"self_add_to_convolution_like($(Int(backend == "tpu")))",
908+
"self_mul_to_convolution_like($(Int(backend == "tpu")))",
909+
"subtract_multiply_const_to_add_mul_const",
905910
]
906911

907912
if !compile_options.disable_scatter_gather_optimization_passes
@@ -1626,10 +1631,10 @@ function compile_mlir!(
16261631
end
16271632

16281633
opt_passes = optimization_passes(
1629-
compile_options; sroa=true, recognize_comms, lower_comms
1634+
compile_options; sroa=true, recognize_comms, lower_comms, backend
16301635
)
16311636
opt_passes2 = optimization_passes(
1632-
compile_options; sroa=false, recognize_comms, lower_comms
1637+
compile_options; sroa=false, recognize_comms, lower_comms, backend
16331638
)
16341639

16351640
raise_passes = if raise isa String
@@ -1650,6 +1655,7 @@ function compile_mlir!(
16501655
dus_to_concat=true,
16511656
recognize_comms,
16521657
lower_comms,
1658+
backend,
16531659
)
16541660
result = result * "," * opt_passes3
16551661
end
@@ -1939,6 +1945,7 @@ function compile_mlir!(
19391945
Reactant.__compile_options_with_reversed_propagation(compile_options);
19401946
recognize_comms,
19411947
lower_comms,
1948+
backend,
19421949
),
19431950
"post_op_transpose_reshape",
19441951
)

0 commit comments

Comments
 (0)