@@ -1475,6 +1475,22 @@ function get_optimize_comms_passes(options::OptimizeCommunicationOptions)
1475
1475
return res
1476
1476
end
1477
1477
1478
+ function get_stablehlo_to_hlo_passes (; stablehlo_to_mhlo:: Bool = true )
1479
+ passes = (
1480
+ " func.func(stablehlo-ext-chlo-recompose-ops)" ,
1481
+ " symbol-dce" ,
1482
+ " func.func(chlo-legalize-to-high-level-mhlo)" ,
1483
+ " func.func(chlo-legalize-to-stablehlo)" ,
1484
+ )
1485
+ if stablehlo_to_mhlo
1486
+ passes = (passes... , " stablehlo-legalize-to-hlo" )
1487
+ end
1488
+ passes = (
1489
+ passes... , " canonicalize" , " func.func(stablehlo-ext-sink-constants-to-control-flow)"
1490
+ )
1491
+ return passes
1492
+ end
1493
+
1478
1494
function compile_mlir! (
1479
1495
mod,
1480
1496
f,
@@ -1485,6 +1501,7 @@ function compile_mlir!(
1485
1501
fn_kwargs= (),
1486
1502
backend= " gpu" ,
1487
1503
runtime:: Union{Val{:PJRT},Val{:IFRT}} ,
1504
+ legalize_stablehlo_to_mhlo:: Bool = false ,
1488
1505
kwargs... ,
1489
1506
)
1490
1507
# Explicitly don't use block! to avoid creating a closure, which creates
@@ -1624,6 +1641,13 @@ function compile_mlir!(
1624
1641
lower_enzymexla_linalg_pass = " lower-enzymexla-linalg{backend=$backend \
1625
1642
blas_int_width=$blas_int_width }"
1626
1643
1644
+ legalize_chlo_to_stablehlo =
1645
+ if legalize_stablehlo_to_mhlo || compile_options. legalize_chlo_to_stablehlo
1646
+ get_stablehlo_to_hlo_passes (; stablehlo_to_mhlo= legalize_stablehlo_to_mhlo)
1647
+ else
1648
+ ()
1649
+ end
1650
+
1627
1651
if compile_options. optimization_passes === :all
1628
1652
run_pass_pipeline! (
1629
1653
mod,
@@ -1641,13 +1665,7 @@ function compile_mlir!(
1641
1665
" canonicalize" ,
1642
1666
" remove-unnecessary-enzyme-ops" ,
1643
1667
" enzyme-simplify-math" ,
1644
- (
1645
- if compile_options. legalize_chlo_to_stablehlo
1646
- [" func.func(chlo-legalize-to-stablehlo)" ]
1647
- else
1648
- []
1649
- end
1650
- ). .. ,
1668
+ legalize_chlo_to_stablehlo... ,
1651
1669
opt_passes2,
1652
1670
lower_enzymexla_linalg_pass,
1653
1671
jit,
@@ -1663,13 +1681,7 @@ function compile_mlir!(
1663
1681
" canonicalize" ,
1664
1682
" remove-unnecessary-enzyme-ops" ,
1665
1683
" enzyme-simplify-math" ,
1666
- (
1667
- if compile_options. legalize_chlo_to_stablehlo
1668
- [" func.func(chlo-legalize-to-stablehlo)" ]
1669
- else
1670
- []
1671
- end
1672
- ). .. ,
1684
+ legalize_chlo_to_stablehlo... ,
1673
1685
opt_passes2,
1674
1686
kern,
1675
1687
raise_passes,
@@ -1698,13 +1710,7 @@ function compile_mlir!(
1698
1710
" canonicalize" ,
1699
1711
" remove-unnecessary-enzyme-ops" ,
1700
1712
" enzyme-simplify-math" ,
1701
- (
1702
- if compile_options. legalize_chlo_to_stablehlo
1703
- [" func.func(chlo-legalize-to-stablehlo)" ]
1704
- else
1705
- []
1706
- end
1707
- ). .. ,
1713
+ legalize_chlo_to_stablehlo... ,
1708
1714
opt_passes2,
1709
1715
]
1710
1716
end ,
@@ -1729,13 +1735,7 @@ function compile_mlir!(
1729
1735
" canonicalize" ,
1730
1736
" remove-unnecessary-enzyme-ops" ,
1731
1737
" enzyme-simplify-math" ,
1732
- (
1733
- if compile_options. legalize_chlo_to_stablehlo
1734
- [" func.func(chlo-legalize-to-stablehlo)" ]
1735
- else
1736
- []
1737
- end
1738
- ). .. ,
1738
+ legalize_chlo_to_stablehlo... ,
1739
1739
opt_passes2,
1740
1740
]
1741
1741
else
@@ -1749,13 +1749,7 @@ function compile_mlir!(
1749
1749
" canonicalize" ,
1750
1750
" remove-unnecessary-enzyme-ops" ,
1751
1751
" enzyme-simplify-math" ,
1752
- (
1753
- if compile_options. legalize_chlo_to_stablehlo
1754
- [" func.func(chlo-legalize-to-stablehlo)" ]
1755
- else
1756
- []
1757
- end
1758
- ). .. ,
1752
+ legalize_chlo_to_stablehlo... ,
1759
1753
opt_passes2,
1760
1754
kern,
1761
1755
raise_passes,
@@ -1782,13 +1776,7 @@ function compile_mlir!(
1782
1776
" canonicalize" ,
1783
1777
" remove-unnecessary-enzyme-ops" ,
1784
1778
" enzyme-simplify-math" ,
1785
- (
1786
- if compile_options. legalize_chlo_to_stablehlo
1787
- [" func.func(chlo-legalize-to-stablehlo)" ]
1788
- else
1789
- []
1790
- end
1791
- ). .. ,
1779
+ legalize_chlo_to_stablehlo... ,
1792
1780
opt_passes2,
1793
1781
kern,
1794
1782
]
@@ -1811,13 +1799,7 @@ function compile_mlir!(
1811
1799
" canonicalize" ,
1812
1800
" remove-unnecessary-enzyme-ops" ,
1813
1801
" enzyme-simplify-math" ,
1814
- (
1815
- if compile_options. legalize_chlo_to_stablehlo
1816
- [" func.func(chlo-legalize-to-stablehlo)" ]
1817
- else
1818
- []
1819
- end
1820
- ). .. ,
1802
+ legalize_chlo_to_stablehlo... ,
1821
1803
opt_passes2,
1822
1804
],
1823
1805
' ,' ,
@@ -1854,13 +1836,7 @@ function compile_mlir!(
1854
1836
" canonicalize" ,
1855
1837
" remove-unnecessary-enzyme-ops" ,
1856
1838
" enzyme-simplify-math" ,
1857
- (
1858
- if compile_options. legalize_chlo_to_stablehlo
1859
- [" func.func(chlo-legalize-to-stablehlo)" ]
1860
- else
1861
- []
1862
- end
1863
- ). .. ,
1839
+ legalize_chlo_to_stablehlo... ,
1864
1840
opt_passes2,
1865
1841
lower_enzymexla_linalg_pass,
1866
1842
jit,
@@ -1873,13 +1849,7 @@ function compile_mlir!(
1873
1849
" canonicalize" ,
1874
1850
" remove-unnecessary-enzyme-ops" ,
1875
1851
" enzyme-simplify-math" ,
1876
- (
1877
- if compile_options. legalize_chlo_to_stablehlo
1878
- [" func.func(chlo-legalize-to-stablehlo)" ]
1879
- else
1880
- []
1881
- end
1882
- ). .. ,
1852
+ legalize_chlo_to_stablehlo... ,
1883
1853
opt_passes2,
1884
1854
kern,
1885
1855
raise_passes,
@@ -2406,7 +2376,13 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
2406
2376
"""
2407
2377
macro code_mhlo (args... )
2408
2378
compile_expr, (; compiled) = compile_call_expr (
2409
- __module__, compile_xla, get_common_compile_options (), args...
2379
+ __module__,
2380
+ compile_mlir,
2381
+ merge (
2382
+ get_common_compile_options (),
2383
+ Dict {Symbol,Any} (:legalize_stablehlo_to_mhlo => true ),
2384
+ ),
2385
+ args... ,
2410
2386
)
2411
2387
# ! format: off
2412
2388
return esc (
@@ -2427,20 +2403,25 @@ This is the post optimizations XLA HLO module.
2427
2403
## Options
2428
2404
2429
2405
$(COMMON_COMPILE_OPTIONS_DOCS)
2406
+ - `before_xla_optimizations`: If `true`, return the `before_optimizations` HLO module.
2430
2407
2431
2408
See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
2432
2409
"""
2433
2410
macro code_xla (args... )
2434
2411
compile_expr, (; compiled) = compile_call_expr (
2435
- __module__, compile_xla, get_common_compile_options (), args...
2412
+ __module__,
2413
+ compile_xla,
2414
+ merge (
2415
+ get_common_compile_options (),
2416
+ Dict {Symbol,Any} (:before_xla_optimizations => false ),
2417
+ ),
2418
+ args... ,
2436
2419
)
2437
2420
# ! format: off
2438
2421
return esc (
2439
2422
:(
2440
2423
$ (compile_expr);
2441
- exec = $ (compiled)[2 ];
2442
- hlo_modules = $ (XLA. get_hlo_modules)(exec);
2443
- length (hlo_modules) == 1 ? only (hlo_modules) : hlo_modules
2424
+ $ (compiled)[3 ]
2444
2425
)
2445
2426
)
2446
2427
# ! format: on
@@ -3374,7 +3355,14 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
3374
3355
return (client, device)
3375
3356
end
3376
3357
3377
- function compile_xla (f, args; client= nothing , serializable:: Bool = false , kwargs... )
3358
+ function compile_xla (
3359
+ f,
3360
+ args;
3361
+ before_xla_optimizations:: Bool = false ,
3362
+ client= nothing ,
3363
+ serializable:: Bool = false ,
3364
+ kwargs... ,
3365
+ )
3378
3366
# register MLIR dialects
3379
3367
ctx = MLIR. IR. Context (Reactant. registry[], false )
3380
3368
context_gc_vector[ctx] = Vector {Union{TracedRArray,TracedRNumber}} (undef, 0 )
@@ -3430,20 +3418,27 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
3430
3418
module_string = " "
3431
3419
end
3432
3420
3433
- exec = XLA. compile (
3434
- client,
3435
- device,
3436
- mod;
3437
- num_outputs= length (mlir_fn_res. linear_results),
3438
- num_parameters= length (mlir_fn_res. linear_args),
3439
- mlir_fn_res. is_sharded,
3440
- global_device_ids,
3441
- mlir_fn_res. num_replicas,
3442
- mlir_fn_res. num_partitions,
3443
- mlir_fn_res. use_shardy_partitioner,
3444
- )
3421
+ if before_xla_optimizations
3422
+ exec = nothing
3423
+ hlo_modules = XLA. HloModule (mod)
3424
+ else
3425
+ exec = XLA. compile (
3426
+ client,
3427
+ device,
3428
+ mod;
3429
+ num_outputs= length (mlir_fn_res. linear_results),
3430
+ num_parameters= length (mlir_fn_res. linear_args),
3431
+ mlir_fn_res. is_sharded,
3432
+ global_device_ids,
3433
+ mlir_fn_res. num_replicas,
3434
+ mlir_fn_res. num_partitions,
3435
+ mlir_fn_res. use_shardy_partitioner,
3436
+ )
3437
+ hlo_modules = XLA. get_hlo_modules (exec)
3438
+ hlo_modules = length (hlo_modules) == 1 ? only (hlo_modules) : hlo_modules
3439
+ end
3445
3440
3446
- return mod, exec, mlir_fn_res, device, client, module_string
3441
+ return mod, exec, hlo_modules, mlir_fn_res, device, client, module_string
3447
3442
finally
3448
3443
MLIR. IR. deactivate! (ctx)
3449
3444
end
@@ -3459,7 +3454,7 @@ const __thunk_rev_body_cache = Dict{Expr,Symbol}()
3459
3454
function compile (f, args; kwargs... )
3460
3455
compile_options, kwargs = __get_compile_options_and_kwargs (; kwargs... )
3461
3456
3462
- _, exec, mlir_fn_res, device, client, str = compile_xla (
3457
+ _, exec, _, mlir_fn_res, device, client, str = compile_xla (
3463
3458
f, args; compile_options, kwargs...
3464
3459
)
3465
3460
(;
0 commit comments