Skip to content

Commit 632964e

Browse files
feat: add an option for chlo to stablehlo (#1394)
* feat: add an option for chlo to stablehlo * fix: call * test: chlo legalize * Update test/basic.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 89369a8 commit 632964e

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ PythonCall = "0.9"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.12"
93-
Reactant_jll = "0.0.198"
93+
Reactant_jll = "0.0.200"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

src/Compiler.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,7 @@ function compile_mlir!(
13881388
donated_args::Symbol=:auto, # :auto | :none
13891389
optimize_then_pad::Bool=true,
13901390
runtime::Union{Val{:PJRT},Val{:IFRT}},
1391+
legalize_chlo_to_stablehlo::Bool=false,
13911392
kwargs...,
13921393
)
13931394
@assert donated_args (:auto, :none)
@@ -1552,6 +1553,13 @@ function compile_mlir!(
15521553
"canonicalize",
15531554
"remove-unnecessary-enzyme-ops",
15541555
"enzyme-simplify-math",
1556+
(
1557+
if legalize_chlo_to_stablehlo
1558+
["func.func(chlo-legalize-to-stablehlo)"]
1559+
else
1560+
[]
1561+
end
1562+
)...,
15551563
opt_passes2,
15561564
lower_enzymexla_linalg_pass,
15571565
jit,
@@ -1567,6 +1575,13 @@ function compile_mlir!(
15671575
"canonicalize",
15681576
"remove-unnecessary-enzyme-ops",
15691577
"enzyme-simplify-math",
1578+
(
1579+
if legalize_chlo_to_stablehlo
1580+
["func.func(chlo-legalize-to-stablehlo)"]
1581+
else
1582+
[]
1583+
end
1584+
)...,
15701585
opt_passes2,
15711586
kern,
15721587
raise_passes,
@@ -1595,6 +1610,13 @@ function compile_mlir!(
15951610
"canonicalize",
15961611
"remove-unnecessary-enzyme-ops",
15971612
"enzyme-simplify-math",
1613+
(
1614+
if legalize_chlo_to_stablehlo
1615+
["func.func(chlo-legalize-to-stablehlo)"]
1616+
else
1617+
[]
1618+
end
1619+
)...,
15981620
opt_passes2,
15991621
]
16001622
end,
@@ -1619,6 +1641,13 @@ function compile_mlir!(
16191641
"canonicalize",
16201642
"remove-unnecessary-enzyme-ops",
16211643
"enzyme-simplify-math",
1644+
(
1645+
if legalize_chlo_to_stablehlo
1646+
["func.func(chlo-legalize-to-stablehlo)"]
1647+
else
1648+
[]
1649+
end
1650+
)...,
16221651
opt_passes2,
16231652
]
16241653
else
@@ -1632,6 +1661,13 @@ function compile_mlir!(
16321661
"canonicalize",
16331662
"remove-unnecessary-enzyme-ops",
16341663
"enzyme-simplify-math",
1664+
(
1665+
if legalize_chlo_to_stablehlo
1666+
["func.func(chlo-legalize-to-stablehlo)"]
1667+
else
1668+
[]
1669+
end
1670+
)...,
16351671
opt_passes2,
16361672
kern,
16371673
raise_passes,
@@ -1658,6 +1694,13 @@ function compile_mlir!(
16581694
"canonicalize",
16591695
"remove-unnecessary-enzyme-ops",
16601696
"enzyme-simplify-math",
1697+
(
1698+
if legalize_chlo_to_stablehlo
1699+
["func.func(chlo-legalize-to-stablehlo)"]
1700+
else
1701+
[]
1702+
end
1703+
)...,
16611704
opt_passes2,
16621705
kern,
16631706
]
@@ -1680,6 +1723,13 @@ function compile_mlir!(
16801723
"canonicalize",
16811724
"remove-unnecessary-enzyme-ops",
16821725
"enzyme-simplify-math",
1726+
(
1727+
if legalize_chlo_to_stablehlo
1728+
["func.func(chlo-legalize-to-stablehlo)"]
1729+
else
1730+
[]
1731+
end
1732+
)...,
16831733
opt_passes2,
16841734
],
16851735
',',
@@ -1716,6 +1766,13 @@ function compile_mlir!(
17161766
"canonicalize",
17171767
"remove-unnecessary-enzyme-ops",
17181768
"enzyme-simplify-math",
1769+
(
1770+
if legalize_chlo_to_stablehlo
1771+
["func.func(chlo-legalize-to-stablehlo)"]
1772+
else
1773+
[]
1774+
end
1775+
)...,
17191776
opt_passes2,
17201777
lower_enzymexla_linalg_pass,
17211778
jit,
@@ -1728,6 +1785,13 @@ function compile_mlir!(
17281785
"canonicalize",
17291786
"remove-unnecessary-enzyme-ops",
17301787
"enzyme-simplify-math",
1788+
(
1789+
if legalize_chlo_to_stablehlo
1790+
["func.func(chlo-legalize-to-stablehlo)"]
1791+
else
1792+
[]
1793+
end
1794+
)...,
17311795
opt_passes2,
17321796
kern,
17331797
raise_passes,
@@ -2191,6 +2255,7 @@ function get_common_compile_options()
21912255
:optimize_then_pad => true,
21922256
:optimize_communications => true,
21932257
:cudnn_hlo_optimize => false,
2258+
:legalize_chlo_to_stablehlo => false,
21942259
)
21952260
end
21962261

@@ -2242,6 +2307,8 @@ const COMMON_COMPILE_OPTIONS_DOCS = """
22422307
- `cudnn_hlo_optimize`: Run cuDNN specific HLO optimizations. This is only relevant for
22432308
GPU backends and is `false` by default. **Experimental and not heavily tested.**
22442309
_(Only for CUDA backend)_
2310+
- `legalize_chlo_to_stablehlo`: If `true`, `chlo` dialect ops will be converted to
2311+
`stablehlo` ops. This is `false` by default.
22452312
"""
22462313

22472314
const SYNC_DOCS = """

test/basic.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,3 +1360,14 @@ linrange_mat(x1, x2) = Reactant.materialize_traced_array(LinRange(x1, x2, 10024)
13601360
hlo = repr(@code_hlo(linrange_mat(x1_ra, x2_ra)))
13611361
@test contains(hlo, "stablehlo.iota")
13621362
end
1363+
1364+
@testset "chlo legalize to stablehlo" begin
1365+
x = rand(ComplexF32, 4, 4)
1366+
x_ra = Reactant.to_rarray(x)
1367+
1368+
hlo1 = repr(@code_hlo Reactant.Ops.conj(x_ra))
1369+
hlo2 = repr(@code_hlo legalize_chlo_to_stablehlo = true Reactant.Ops.conj(x_ra))
1370+
1371+
@test contains(hlo1, "chlo.conj")
1372+
@test !contains(hlo2, "chlo")
1373+
end

0 commit comments

Comments
 (0)