Skip to content

Commit a46511e

Browse files
authored
feat: xla/mhlo export passes (#1463)
* feat: run mhlo passes before legalize * feat: allow direct conversion of shlo to mhlo * feat: before optimizations HLO * chore: run fmt * chore: bump jll * fix: tests
1 parent 66c200b commit a46511e

File tree

5 files changed

+107
-88
lines changed

5 files changed

+107
-88
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.145"
4+
version = "0.2.146"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -90,7 +90,7 @@ PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.218"
93+
Reactant_jll = "0.0.219"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

src/Compiler.jl

Lines changed: 78 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,22 @@ function get_optimize_comms_passes(options::OptimizeCommunicationOptions)
14751475
return res
14761476
end
14771477

1478+
function get_stablehlo_to_hlo_passes(; stablehlo_to_mhlo::Bool=true)
1479+
passes = (
1480+
"func.func(stablehlo-ext-chlo-recompose-ops)",
1481+
"symbol-dce",
1482+
"func.func(chlo-legalize-to-high-level-mhlo)",
1483+
"func.func(chlo-legalize-to-stablehlo)",
1484+
)
1485+
if stablehlo_to_mhlo
1486+
passes = (passes..., "stablehlo-legalize-to-hlo")
1487+
end
1488+
passes = (
1489+
passes..., "canonicalize", "func.func(stablehlo-ext-sink-constants-to-control-flow)"
1490+
)
1491+
return passes
1492+
end
1493+
14781494
function compile_mlir!(
14791495
mod,
14801496
f,
@@ -1485,6 +1501,7 @@ function compile_mlir!(
14851501
fn_kwargs=(),
14861502
backend="gpu",
14871503
runtime::Union{Val{:PJRT},Val{:IFRT}},
1504+
legalize_stablehlo_to_mhlo::Bool=false,
14881505
kwargs...,
14891506
)
14901507
# Explicitly don't use block! to avoid creating a closure, which creates
@@ -1624,6 +1641,13 @@ function compile_mlir!(
16241641
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
16251642
blas_int_width=$blas_int_width}"
16261643

1644+
legalize_chlo_to_stablehlo =
1645+
if legalize_stablehlo_to_mhlo || compile_options.legalize_chlo_to_stablehlo
1646+
get_stablehlo_to_hlo_passes(; stablehlo_to_mhlo=legalize_stablehlo_to_mhlo)
1647+
else
1648+
()
1649+
end
1650+
16271651
if compile_options.optimization_passes === :all
16281652
run_pass_pipeline!(
16291653
mod,
@@ -1641,13 +1665,7 @@ function compile_mlir!(
16411665
"canonicalize",
16421666
"remove-unnecessary-enzyme-ops",
16431667
"enzyme-simplify-math",
1644-
(
1645-
if compile_options.legalize_chlo_to_stablehlo
1646-
["func.func(chlo-legalize-to-stablehlo)"]
1647-
else
1648-
[]
1649-
end
1650-
)...,
1668+
legalize_chlo_to_stablehlo...,
16511669
opt_passes2,
16521670
lower_enzymexla_linalg_pass,
16531671
jit,
@@ -1663,13 +1681,7 @@ function compile_mlir!(
16631681
"canonicalize",
16641682
"remove-unnecessary-enzyme-ops",
16651683
"enzyme-simplify-math",
1666-
(
1667-
if compile_options.legalize_chlo_to_stablehlo
1668-
["func.func(chlo-legalize-to-stablehlo)"]
1669-
else
1670-
[]
1671-
end
1672-
)...,
1684+
legalize_chlo_to_stablehlo...,
16731685
opt_passes2,
16741686
kern,
16751687
raise_passes,
@@ -1698,13 +1710,7 @@ function compile_mlir!(
16981710
"canonicalize",
16991711
"remove-unnecessary-enzyme-ops",
17001712
"enzyme-simplify-math",
1701-
(
1702-
if compile_options.legalize_chlo_to_stablehlo
1703-
["func.func(chlo-legalize-to-stablehlo)"]
1704-
else
1705-
[]
1706-
end
1707-
)...,
1713+
legalize_chlo_to_stablehlo...,
17081714
opt_passes2,
17091715
]
17101716
end,
@@ -1729,13 +1735,7 @@ function compile_mlir!(
17291735
"canonicalize",
17301736
"remove-unnecessary-enzyme-ops",
17311737
"enzyme-simplify-math",
1732-
(
1733-
if compile_options.legalize_chlo_to_stablehlo
1734-
["func.func(chlo-legalize-to-stablehlo)"]
1735-
else
1736-
[]
1737-
end
1738-
)...,
1738+
legalize_chlo_to_stablehlo...,
17391739
opt_passes2,
17401740
]
17411741
else
@@ -1749,13 +1749,7 @@ function compile_mlir!(
17491749
"canonicalize",
17501750
"remove-unnecessary-enzyme-ops",
17511751
"enzyme-simplify-math",
1752-
(
1753-
if compile_options.legalize_chlo_to_stablehlo
1754-
["func.func(chlo-legalize-to-stablehlo)"]
1755-
else
1756-
[]
1757-
end
1758-
)...,
1752+
legalize_chlo_to_stablehlo...,
17591753
opt_passes2,
17601754
kern,
17611755
raise_passes,
@@ -1782,13 +1776,7 @@ function compile_mlir!(
17821776
"canonicalize",
17831777
"remove-unnecessary-enzyme-ops",
17841778
"enzyme-simplify-math",
1785-
(
1786-
if compile_options.legalize_chlo_to_stablehlo
1787-
["func.func(chlo-legalize-to-stablehlo)"]
1788-
else
1789-
[]
1790-
end
1791-
)...,
1779+
legalize_chlo_to_stablehlo...,
17921780
opt_passes2,
17931781
kern,
17941782
]
@@ -1811,13 +1799,7 @@ function compile_mlir!(
18111799
"canonicalize",
18121800
"remove-unnecessary-enzyme-ops",
18131801
"enzyme-simplify-math",
1814-
(
1815-
if compile_options.legalize_chlo_to_stablehlo
1816-
["func.func(chlo-legalize-to-stablehlo)"]
1817-
else
1818-
[]
1819-
end
1820-
)...,
1802+
legalize_chlo_to_stablehlo...,
18211803
opt_passes2,
18221804
],
18231805
',',
@@ -1854,13 +1836,7 @@ function compile_mlir!(
18541836
"canonicalize",
18551837
"remove-unnecessary-enzyme-ops",
18561838
"enzyme-simplify-math",
1857-
(
1858-
if compile_options.legalize_chlo_to_stablehlo
1859-
["func.func(chlo-legalize-to-stablehlo)"]
1860-
else
1861-
[]
1862-
end
1863-
)...,
1839+
legalize_chlo_to_stablehlo...,
18641840
opt_passes2,
18651841
lower_enzymexla_linalg_pass,
18661842
jit,
@@ -1873,13 +1849,7 @@ function compile_mlir!(
18731849
"canonicalize",
18741850
"remove-unnecessary-enzyme-ops",
18751851
"enzyme-simplify-math",
1876-
(
1877-
if compile_options.legalize_chlo_to_stablehlo
1878-
["func.func(chlo-legalize-to-stablehlo)"]
1879-
else
1880-
[]
1881-
end
1882-
)...,
1852+
legalize_chlo_to_stablehlo...,
18831853
opt_passes2,
18841854
kern,
18851855
raise_passes,
@@ -2406,7 +2376,13 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
24062376
"""
24072377
macro code_mhlo(args...)
24082378
compile_expr, (; compiled) = compile_call_expr(
2409-
__module__, compile_xla, get_common_compile_options(), args...
2379+
__module__,
2380+
compile_mlir,
2381+
merge(
2382+
get_common_compile_options(),
2383+
Dict{Symbol,Any}(:legalize_stablehlo_to_mhlo => true),
2384+
),
2385+
args...,
24102386
)
24112387
#! format: off
24122388
return esc(
@@ -2427,20 +2403,25 @@ This is the post optimizations XLA HLO module.
24272403
## Options
24282404
24292405
$(COMMON_COMPILE_OPTIONS_DOCS)
2406+
- `before_xla_optimizations`: If `true`, return the `before_optimizations` HLO module.
24302407
24312408
See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
24322409
"""
24332410
macro code_xla(args...)
24342411
compile_expr, (; compiled) = compile_call_expr(
2435-
__module__, compile_xla, get_common_compile_options(), args...
2412+
__module__,
2413+
compile_xla,
2414+
merge(
2415+
get_common_compile_options(),
2416+
Dict{Symbol,Any}(:before_xla_optimizations => false),
2417+
),
2418+
args...,
24362419
)
24372420
#! format: off
24382421
return esc(
24392422
:(
24402423
$(compile_expr);
2441-
exec = $(compiled)[2];
2442-
hlo_modules = $(XLA.get_hlo_modules)(exec);
2443-
length(hlo_modules) == 1 ? only(hlo_modules) : hlo_modules
2424+
$(compiled)[3]
24442425
)
24452426
)
24462427
#! format: on
@@ -3374,7 +3355,14 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
33743355
return (client, device)
33753356
end
33763357

3377-
function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs...)
3358+
function compile_xla(
3359+
f,
3360+
args;
3361+
before_xla_optimizations::Bool=false,
3362+
client=nothing,
3363+
serializable::Bool=false,
3364+
kwargs...,
3365+
)
33783366
# register MLIR dialects
33793367
ctx = MLIR.IR.Context(Reactant.registry[], false)
33803368
context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
@@ -3430,20 +3418,27 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
34303418
module_string = ""
34313419
end
34323420

3433-
exec = XLA.compile(
3434-
client,
3435-
device,
3436-
mod;
3437-
num_outputs=length(mlir_fn_res.linear_results),
3438-
num_parameters=length(mlir_fn_res.linear_args),
3439-
mlir_fn_res.is_sharded,
3440-
global_device_ids,
3441-
mlir_fn_res.num_replicas,
3442-
mlir_fn_res.num_partitions,
3443-
mlir_fn_res.use_shardy_partitioner,
3444-
)
3421+
if before_xla_optimizations
3422+
exec = nothing
3423+
hlo_modules = XLA.HloModule(mod)
3424+
else
3425+
exec = XLA.compile(
3426+
client,
3427+
device,
3428+
mod;
3429+
num_outputs=length(mlir_fn_res.linear_results),
3430+
num_parameters=length(mlir_fn_res.linear_args),
3431+
mlir_fn_res.is_sharded,
3432+
global_device_ids,
3433+
mlir_fn_res.num_replicas,
3434+
mlir_fn_res.num_partitions,
3435+
mlir_fn_res.use_shardy_partitioner,
3436+
)
3437+
hlo_modules = XLA.get_hlo_modules(exec)
3438+
hlo_modules = length(hlo_modules) == 1 ? only(hlo_modules) : hlo_modules
3439+
end
34453440

3446-
return mod, exec, mlir_fn_res, device, client, module_string
3441+
return mod, exec, hlo_modules, mlir_fn_res, device, client, module_string
34473442
finally
34483443
MLIR.IR.deactivate!(ctx)
34493444
end
@@ -3459,7 +3454,7 @@ const __thunk_rev_body_cache = Dict{Expr,Symbol}()
34593454
function compile(f, args; kwargs...)
34603455
compile_options, kwargs = __get_compile_options_and_kwargs(; kwargs...)
34613456

3462-
_, exec, mlir_fn_res, device, client, str = compile_xla(
3457+
_, exec, _, mlir_fn_res, device, client, str = compile_xla(
34633458
f, args; compile_options, kwargs...
34643459
)
34653460
(;

src/xla/HloModule.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ function free_hlo_module(hlo_module)
1111
@ccall MLIR.API.mlir_c.FreeHloModule(hlo_module.ptr::Ptr{Cvoid})::Cvoid
1212
end
1313

14+
function HloModule(mod::MLIR.IR.Module)
15+
return HloModule(
16+
@ccall MLIR.API.mlir_c.convertMlirModuleToHloModule(
17+
mod::MLIR.API.MlirModule
18+
)::Ptr{Cvoid}
19+
)
20+
end
21+
1422
function Base.show(io::IO, hlo_module::HloModule)
1523
GC.@preserve hlo_module begin
1624
str = @ccall MLIR.API.mlir_c.HloModuleToString(hlo_module.ptr::Ptr{Cvoid})::Cstring

test/buffer_donation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
hlo = @code_hlo(donate_fill_x_with_2(a, b))
2626
@test length(findall("tf.aliasing_output = 0", repr(hlo))) == 1
2727

28-
(; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[3]
28+
(; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[4]
2929
preserved_args_idx = last.(preserved_args)
3030
@test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved
3131

@@ -36,7 +36,7 @@ end
3636
hlo = @code_hlo(donate_inplace_mul(a, b))
3737
@test length(findall("tf.aliasing_output = 0", repr(hlo))) == 1
3838

39-
(; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[3]
39+
(; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[4]
4040
preserved_args_idx = last.(preserved_args)
4141
@test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved
4242

@@ -71,7 +71,7 @@ end
7171
z = Reactant.to_rarray(ones(3))
7272

7373
@code_hlo assert_nonallocating = true update_inplace!(x, y, z)
74-
(; preserved_args) = Reactant.Compiler.compile_xla(update_inplace!, (x, y, z))[3]
74+
(; preserved_args) = Reactant.Compiler.compile_xla(update_inplace!, (x, y, z))[4]
7575
preserved_args_idx = last.(preserved_args)
7676
@test preserved_args_idx == [1, 2] # y and z are both preserved (preserved_args is 0-indexed)
7777

test/compile.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,19 @@ end
227227
@test y.x isa Reactant.RArray
228228
@test y.x == fcustom_path(x).x
229229
end
230+
231+
# CHLO legalize options
232+
# test that we are running some mhlo passes first before legalizing, else we will end up
233+
# decomposing some necessary ops
234+
function fn_test(x)
235+
y = Reactant.Ops.top_k(x, 16).values
236+
y_complex = Complex.(y, -y .+ 1)
237+
conj!(y_complex)
238+
return y_complex
239+
end
240+
241+
@testset "chlo legalize" begin
242+
x_ra = Reactant.to_rarray(rand(Float32, 128))
243+
hlo = @code_hlo legalize_chlo_to_stablehlo = true fn_test(x_ra)
244+
@test occursin("mhlo.topk", repr(hlo))
245+
end

0 commit comments

Comments
 (0)