Skip to content

Commit 3388d93

Browse files
authored
refactor: centralize compile options (#1407)
* refactor: centralize compile options * fix: use sync from compile_options
1 parent 13a08fa commit 3388d93

File tree

6 files changed

+548
-255
lines changed

6 files changed

+548
-255
lines changed

docs/src/api/api.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ within_compile
2929
@code_xla
3030
```
3131

32+
## Compile Options
33+
34+
```@docs
35+
CompileOptions
36+
Reactant.DefaultXLACompileOptions
37+
```
38+
39+
### Sharding Specific Options
40+
41+
```@docs
42+
OptimizeCommunicationOptions
43+
ShardyPropagationOptions
44+
```
45+
3246
## Tracing customization
3347

3448
```@docs

src/CompileOptions.jl

Lines changed: 289 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
# TODO: make the other optimize options into a struct as well
1+
# TODO: document these options at some point
2+
"""
3+
OptimizeCommunicationOptions
4+
5+
Fine-grained control over the optimization passes that rewrite ops to minimize collective
6+
communication.
7+
"""
28
@kwdef struct OptimizeCommunicationOptions
39
periodic_concat::Int = 0
410
rotate_comm::Int = 0
@@ -20,3 +26,285 @@ function Base.String(options::OptimizeCommunicationOptions)
2026
"}"
2127
)
2228
end
29+
30+
"""
31+
ShardyPropagationOptions
32+
33+
Fine-grained control over the sharding propagation pipeline. For more information on
34+
sharding propagation, see the
35+
[Shardy Docs](https://openxla.org/shardy/sdy_propagation_passes).
36+
37+
## Options
38+
39+
- `keep_sharding_rules::Bool`: whether to keep existing and created op sharding rules.
40+
- `conservative_propagation::Bool`: whether to disallow split axes and non-divisible
41+
sharding axes during propagation.
42+
- `debug_sharding_origins::Bool`: whether to save information about the origin of a
43+
sharding on the MLIR module. These would be the shardings on the function inputs,
44+
outputs, sharding constraints and manual computations before propagation.
45+
- `debug_propagation_edge_sharding::Bool`: whether to save information about the edge
46+
source of a sharding on the MLIR module. These are what operand/result introduced a
47+
sharding on some op result.
48+
- `skip_convert_to_reshard::Bool`
49+
- `skip_inline::Bool`
50+
- `enable_insert_explicit_collectives::Bool`: whether to insert explicit collectives
51+
for sharding propagation. This is useful for debugging and checking the location of
52+
the communication ops.
53+
"""
54+
@kwdef struct ShardyPropagationOptions
    # Every flag is opt-in: the zero-argument constructor leaves all of them `false`.

    # Keep existing and newly created op sharding rules.
    keep_sharding_rules::Bool = false
    # Disallow split axes and non-divisible sharding axes during propagation.
    conservative_propagation::Bool = false

    # Record where a sharding originated on the MLIR module.
    debug_sharding_origins::Bool = false
    # Record which operand/result edge introduced a sharding.
    debug_propagation_edge_sharding::Bool = false

    # NOTE(review): presumably these skip the corresponding pipeline stages; their exact
    # effect is not documented in this file -- confirm against the Shardy docs.
    skip_convert_to_reshard::Bool = false
    skip_inline::Bool = false

    # Insert explicit collectives (useful for locating communication ops).
    enable_insert_explicit_collectives::Bool = false
end
63+
64+
"""
65+
CompileOptions
66+
67+
Fine-grained control over the compilation options for the Reactant compiler.
68+
69+
## Controlling Optimization Passes
70+
71+
- `optimization_passes`: Optimization passes to run on the traced MLIR code. Valid types
72+
of values are:
73+
- Bool (true/false): whether to run the optimization passes or not. Defaults to `true`.
74+
- String: a custom string with the passes to run. The string should be a comma-separated
75+
list of MLIR passes. For example, `"canonicalize,enzyme-hlo-opt"`.
76+
- Symbol: a predefined set of passes to run. Valid options are:
77+
1. `:all`: Default set of optimization passes. The exact set of passes is not fixed
78+
and may change in future versions of Reactant. It is recommended to use this
79+
option for most users.
80+
2. `:none`: No optimization passes will be run.
81+
3. Other predefined options are: `:before_kernel`, `:before_jit`, `:before_raise`,
82+
`:before_enzyme`, `:after_enzyme`, `:no_enzyme`, `:only_enzyme`, `:just_batch`, `:canonicalize`.
83+
- `no_nan`: If `true`, the optimization passes will assume that the function does not
84+
produce NaN values. This can lead to more aggressive optimizations **(and potentially
85+
incorrect results if the function does produce NaN values)**.
86+
- `all_finite`: If `true`, the optimization passes will assume that the function does not
87+
produce Inf or -Inf values. This can lead to more aggressive optimizations **(and
88+
potentially incorrect results if the function does produce Inf or -Inf values)**.
89+
- `transpose_propagate`: If `:up`, `stablehlo.transpose` operations will be
90+
propagated up the computation graph. If `:down`, they will be propagated down. Defaults
91+
to `:up`. `:none` is also an accepted value.
92+
- `reshape_propagate`: If `:up`, `stablehlo.reshape` operations will be propagated up
93+
the computation graph. If `:down`, they will be propagated down. Defaults to `:up`. `:none` is also an accepted value.
94+
- `max_constant_threshold`: If the number of elements in a constant is greater than this
95+
threshold (for a non-splatted constant), we will throw an error. Defaults to `1024`.
96+
- `inline`: If `true`, all functions will be inlined. This is `true` by default.
97+
98+
## Raising Options
99+
100+
- `raise`: If `true`, the function will be compiled with the raising pass, which raises
101+
CUDA and KernelAbstractions kernels to HLO. Defaults to `false`, but is automatically
102+
activated if the inputs are sharded.
103+
- `raise_first`: If `true`, the raising pass will be run before the optimization passes.
104+
Defaults to `false`.
105+
106+
## Dialect Specific Options
107+
108+
- `legalize_chlo_to_stablehlo`: If `true`, `chlo` dialect ops will be converted to
109+
`stablehlo` ops. This is `false` by default.
110+
111+
## Backend Specific Options
112+
113+
### Only for CUDA backend
114+
115+
- `cudnn_hlo_optimize`: Run cuDNN specific HLO optimizations. This is only relevant for
116+
GPU backends and is `false` by default. **Experimental and not heavily tested.**
117+
118+
## Sharding Options
119+
120+
- `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
121+
- `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
122+
- `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
123+
handled by XLA.
124+
- [`ShardyPropagationOptions`](@ref): Custom sharding propagation options.
125+
MHLO shardings are handled by XLA.
126+
- `:to_mhlo_shardings`: Runs the Shardy propagation passes and then exports the
127+
shardings to MHLO. All passes are run via MLIR pass pipeline and don't involve XLA.
128+
- `optimize_then_pad`: If `true`, the function will be optimized before padding (for
129+
non-divisible sharding axes) is applied. Defaults to `true`. _(Only for Sharded Inputs)_
130+
- `optimize_communications`: If `true`, additional passes for optimizing communication
131+
in sharded computations will be run. Defaults to `true`. _(Only for Sharded Inputs)_
132+
133+
## Julia Codegen Options
134+
135+
- `donated_args`: If `:auto`, the function will automatically donate the arguments that
136+
are not preserved in the function body. If `:none`, no arguments will be donated.
137+
Defaults to `:auto`.
138+
- `assert_nonallocating`: If `true`, we make sure that no new buffers are
139+
returned by the function. Any buffer returned must be donated from the inputs. Defaults
140+
to `false`.
141+
- `sync`: Reactant computations are asynchronous by default. If `true`, the computation
142+
will be executed synchronously, blocking till the computation is complete. This is
143+
recommended when benchmarking.
144+
145+
# Extended Help
146+
147+
## Private Options
148+
149+
!!! warning
150+
151+
These options are not part of the public API and are subject to change without any
152+
notice or deprecation cycle.
153+
154+
- `disable_scatter_gather_optimization_passes`: Disables the scatter-gather
155+
optimization passes. This is `false` by default.
156+
- `disable_pad_optimization_passes`: Disables the pad optimization passes. This is
157+
`false` by default.
158+
"""
159+
struct CompileOptions
    # NOTE: field order is load-bearing -- the positional constructor is called with the
    # values in exactly this order elsewhere in this file.

    # -- optimization pipeline --
    optimization_passes::Union{Symbol,String}  # preset name or custom pass string
    no_nan::Bool
    all_finite::Bool
    inline::Bool
    transpose_propagate::Symbol
    reshape_propagate::Symbol
    max_constant_threshold::Int

    # -- raising --
    raise::Union{Bool,String}
    raise_first::Bool

    # -- dialect specific --
    legalize_chlo_to_stablehlo::Bool

    # -- backend specific --
    cudnn_hlo_optimize::Bool

    # -- sharding --
    shardy_passes::Union{Symbol,ShardyPropagationOptions}
    optimize_then_pad::Bool
    optimize_communications::Union{Bool,OptimizeCommunicationOptions}

    # -- julia codegen --
    assert_nonallocating::Bool
    donated_args::Symbol
    sync::Bool

    # -- private knobs for ablation studies (not public API) --
    disable_scatter_gather_optimization_passes::Bool
    disable_pad_optimization_passes::Bool
end
186+
187+
# Keyword constructor for `CompileOptions`.
#
# Normalizes a `Bool` `optimization_passes` (`true` => `:all`, `false` => `:none`) and
# validates every Symbol-valued option that has a closed set of legal values before
# forwarding all values, in field order, to the positional constructor.
#
# Throws `ArgumentError` when `optimization_passes` (if a Symbol), `transpose_propagate`,
# `reshape_propagate`, or `shardy_passes` (if a Symbol) is not one of the accepted values.
# Validation uses explicit throws rather than `@assert`, because assertions are a debugging
# aid and may be disabled, which would let invalid options through silently.
function CompileOptions(;
    optimization_passes::Union{Bool,Symbol,String}=:all,
    no_nan::Bool=false,
    all_finite::Bool=false,
    inline::Bool=true,
    transpose_propagate::Symbol=:up,
    reshape_propagate::Symbol=:up,
    max_constant_threshold::Int=1024,
    raise::Union{Bool,String}=false,
    raise_first::Bool=false,
    legalize_chlo_to_stablehlo::Bool=false,
    cudnn_hlo_optimize::Bool=false,
    shardy_passes::Union{Symbol,ShardyPropagationOptions}=:to_mhlo_shardings,
    optimize_then_pad::Bool=true,
    optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
    assert_nonallocating::Bool=false,
    donated_args::Symbol=:auto,
    sync::Bool=false,
    disable_scatter_gather_optimization_passes::Bool=false,
    disable_pad_optimization_passes::Bool=false,
)
    valid_pass_presets = (
        :all,
        :before_kernel,
        :before_jit,
        :before_raise,
        :no_enzyme,
        :only_enzyme,
        :after_enzyme,
        :before_enzyme,
        :canonicalize,
        :just_batch,
        :none,
    )
    valid_propagation = (:up, :down, :none)
    valid_shardy_presets = (:none, :to_mhlo_shardings, :post_sdy_propagation)

    # A Bool is shorthand for the two extreme presets. A fresh local is used so the
    # keyword argument itself is never rebound to a different type.
    passes = if optimization_passes isa Bool
        optimization_passes ? :all : :none
    else
        optimization_passes
    end

    # Strings are custom pass pipelines and are passed through unchecked.
    if passes isa Symbol && !(passes in valid_pass_presets)
        throw(
            ArgumentError(
                "Invalid `optimization_passes`: $(repr(passes)). " *
                "Expected one of $(valid_pass_presets).",
            ),
        )
    end

    if !(transpose_propagate in valid_propagation)
        throw(
            ArgumentError(
                "Invalid `transpose_propagate`: $(repr(transpose_propagate)). " *
                "Expected one of $(valid_propagation).",
            ),
        )
    end

    if !(reshape_propagate in valid_propagation)
        throw(
            ArgumentError(
                "Invalid `reshape_propagate`: $(repr(reshape_propagate)). " *
                "Expected one of $(valid_propagation).",
            ),
        )
    end

    # A `ShardyPropagationOptions` instance needs no further checking here.
    if shardy_passes isa Symbol && !(shardy_passes in valid_shardy_presets)
        throw(
            ArgumentError(
                "Invalid `shardy_passes`: $(repr(shardy_passes)). " *
                "Expected one of $(valid_shardy_presets) or a `ShardyPropagationOptions`.",
            ),
        )
    end

    return CompileOptions(
        passes,
        no_nan,
        all_finite,
        inline,
        transpose_propagate,
        reshape_propagate,
        max_constant_threshold,
        raise,
        raise_first,
        legalize_chlo_to_stablehlo,
        cudnn_hlo_optimize,
        shardy_passes,
        optimize_then_pad,
        optimize_communications,
        assert_nonallocating,
        donated_args,
        sync,
        disable_scatter_gather_optimization_passes,
        disable_pad_optimization_passes,
    )
end
256+
257+
function __compile_options_from_kwags(;
    compile_options::Union{Missing,CompileOptions}=missing,
    optimize::Union{Bool,Symbol,String}=true,
    kwargs...,
)
    # Without an explicit options object, build one from the loose keywords; `optimize`
    # is forwarded under the name `optimization_passes`.
    if ismissing(compile_options)
        return CompileOptions(; optimization_passes=optimize, kwargs...)
    end

    # An explicit options object wins outright: `optimize` and every other keyword are
    # ignored in this case.
    return compile_options
end
265+
266+
function __reverse_propagation(sym::Symbol)
    # Swap the two directions; `:none` maps to itself. Anything else is rejected.
    if sym === :up
        return :down
    elseif sym === :down
        return :up
    elseif sym === :none
        return :none
    else
        return error("Invalid value: $sym. Expected :up or :down or :none")
    end
end
272+
273+
function __compile_options_with_reversed_propagation(compile_options::CompileOptions)
    # Rebuild the struct field by field in declaration order (the order the positional
    # constructor expects). Only the two propagation directions are changed; every other
    # field is copied over untouched.
    flipped = (:transpose_propagate, :reshape_propagate)
    args = map(fieldnames(CompileOptions)) do field
        current = getfield(compile_options, field)
        return field in flipped ? __reverse_propagation(current) : current
    end
    return CompileOptions(args...)
end
296+
297+
"""
298+
DefaultXLACompileOptions()
299+
300+
Runs specific Enzyme-JAX passes to ensure that the generated code is compatible with
301+
XLA compilation.
302+
303+
!!! warning
304+
305+
This is mostly a benchmarking option, and the default [`CompileOptions`](@ref) is almost
306+
certainly a better option.
307+
"""
308+
function DefaultXLACompileOptions()
    # Preset: the `:only_enzyme` pass set with inlining switched off; every other option
    # keeps the keyword constructor's default.
    return CompileOptions(; inline=false, optimization_passes=:only_enzyme)
end

0 commit comments

Comments
 (0)