Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 57 additions & 38 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,8 @@ function get_stablehlo_to_hlo_passes(; stablehlo_to_mhlo::Bool=true)
return passes
end

join_after_filtering_empty(list, delim) = join(filter(!isempty, list), delim)

function compile_mlir!(
mod,
f,
Expand Down Expand Up @@ -1903,10 +1905,38 @@ function compile_mlir!(

legal_to_run_shardy_passes = compile_options.optimization_passes === :all

layout_propagation_passes = ""
# We want to run these passes even before enzyme to raise loops
# which are generally nasty to differentiate and often times it
# hard to raise post differentiation
if (
compile_options.transpose_propagate === :up ||
compile_options.reshape_propagate === :up
)
# We tried propagating reshapes and transposes up. If at this point we are left
# with them, we propagate them down to minimize the number of Ops in the IR.
# Since this might enable certain raising, we do push down -> push up -> push down
common_kwargs = (;
recognize_comms,
lower_comms,
backend,
is_sharded,
raise_shlo_to_blas_lapack=false,
)
opt_passes_down = optimization_passes(
Reactant.__compile_options_with_reversed_propagation(compile_options);
common_kwargs...,
)
opt_passes_up = optimization_passes(compile_options; common_kwargs...)
layout_propagation_passes = join_after_filtering_empty(
[opt_passes_down, opt_passes_up, opt_passes_down, opt_passes_up], ","
)
end

if compile_options.optimization_passes === :all
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
[
"mark-func-memory-effects",
Expand All @@ -1915,6 +1945,7 @@ function compile_mlir!(
raise_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
Expand All @@ -1925,13 +1956,15 @@ function compile_mlir!(
lower_enzymexla_linalg_pass,
lower_enzymexla_mpi_pass,
jit,
layout_propagation_passes,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
Expand All @@ -1944,6 +1977,7 @@ function compile_mlir!(
lower_enzymexla_linalg_pass,
lower_enzymexla_mpi_pass,
jit,
layout_propagation_passes,
]
end,
",",
Expand All @@ -1953,7 +1987,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :before_kernel
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
["mark-func-memory-effects", opt_passes]
else
Expand All @@ -1962,13 +1996,15 @@ function compile_mlir!(
opt_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
legalize_chlo_to_stablehlo...,
opt_passes2,
layout_propagation_passes,
]
end,
',',
Expand All @@ -1978,7 +2014,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :before_jit
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
[
"mark-func-memory-effects",
Expand All @@ -1987,20 +2023,23 @@ function compile_mlir!(
raise_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
legalize_chlo_to_stablehlo...,
opt_passes2,
layout_propagation_passes,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
Expand All @@ -2010,6 +2049,7 @@ function compile_mlir!(
opt_passes2,
kern,
raise_passes,
layout_propagation_passes,
]
end,
',',
Expand All @@ -2019,7 +2059,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :before_raise
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
["mark-func-memory-effects", opt_passes]
else
Expand All @@ -2028,6 +2068,7 @@ function compile_mlir!(
opt_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
opt_passes2,
"canonicalize",
Expand All @@ -2036,6 +2077,7 @@ function compile_mlir!(
legalize_chlo_to_stablehlo...,
opt_passes2,
kern,
layout_propagation_passes,
]
end,
',',
Expand All @@ -2045,19 +2087,20 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :no_enzyme
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
layout_propagation_passes,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
legalize_chlo_to_stablehlo...,
opt_passes2,
layout_propagation_passes,
],
',',
),
Expand All @@ -2066,7 +2109,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
[
"mark-func-memory-effects",
"enzyme-batch",
Expand All @@ -2082,7 +2125,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :after_enzyme
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
[
"mark-func-memory-effects",
Expand All @@ -2098,6 +2141,7 @@ function compile_mlir!(
lower_enzymexla_linalg_pass,
lower_enzymexla_mpi_pass,
jit,
layout_propagation_passes,
]
else
[
Expand All @@ -2114,6 +2158,7 @@ function compile_mlir!(
lower_enzymexla_linalg_pass,
lower_enzymexla_mpi_pass,
jit,
layout_propagation_passes,
]
end,
',',
Expand All @@ -2123,7 +2168,7 @@ function compile_mlir!(
elseif compile_options.optimization_passes === :before_enzyme
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
if compile_options.raise_first
[
"mark-func-memory-effects",
Expand All @@ -2132,6 +2177,7 @@ function compile_mlir!(
raise_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
lower_enzymexla_linalg_pass,
Expand All @@ -2144,6 +2190,7 @@ function compile_mlir!(
opt_passes,
"enzyme-batch",
opt_passes2,
layout_propagation_passes,
enzyme_pass,
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
kern,
Expand All @@ -2165,34 +2212,6 @@ function compile_mlir!(
run_pass_pipeline!(mod, compile_options.optimization_passes, "custom_pass")
end

if compile_options.optimization_passes isa Symbol &&
compile_options.optimization_passes === :all &&
(
compile_options.transpose_propagate === :up ||
compile_options.reshape_propagate === :up
)
# We tried propagating reshapes and transposes up. If at this point we are left
# with them, we propagate them down to minimize the number of Ops in the IR.
# Since this might enable certain raising, we do push down -> push up -> push down
common_kwargs = (;
recognize_comms,
lower_comms,
backend,
is_sharded,
raise_shlo_to_blas_lapack=false,
)
opt_passes_down = optimization_passes(
Reactant.__compile_options_with_reversed_propagation(compile_options);
common_kwargs...,
)
opt_passes_up = optimization_passes(compile_options; common_kwargs...)
run_pass_pipeline!(
mod,
join([opt_passes_down, opt_passes_up, opt_passes_down], ","),
"post_op_transpose_reshape",
)
end

if backend == "cuda" && compile_options.cudnn_hlo_optimize
run_pass_pipeline!(mod, "enzymexla-cudnn-hlo-opt", "cudnn-hlo-opt")
end
Expand Down Expand Up @@ -2329,7 +2348,7 @@ function compile_mlir!(
if compile_options.optimization_passes === :all
run_pass_pipeline!(
mod,
join(
join_after_filtering_empty(
[
opt_passes,
"canonicalize",
Expand Down
Loading