Skip to content

Commit f57817f

Browse files
authored
docs: document the options to core macros (#1348)
* docs: document the options to core macros * docs: cleanup optimize docs
1 parent 74d6cb7 commit f57817f

File tree

2 files changed

+132
-85
lines changed

2 files changed

+132
-85
lines changed

src/Compiler.jl

Lines changed: 112 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,29 +2127,92 @@ function compile_mlir!(
21272127
)
21282128
end
21292129

2130+
const COMMON_COMPILE_OPTIONS = Dict{Symbol,Any}(
2131+
:optimize => true,
2132+
:no_nan => false,
2133+
:client => nothing,
2134+
:raise => false,
2135+
:raise_first => false,
2136+
:shardy_passes => :(:to_mhlo_shardings),
2137+
:assert_nonallocating => false,
2138+
:donated_args => :(:auto),
2139+
:transpose_propagate => :(:up),
2140+
:reshape_propagate => :(:up),
2141+
:optimize_then_pad => true,
2142+
:optimize_communications => true,
2143+
:cudnn_hlo_optimize => false,
2144+
)
2145+
2146+
const COMMON_COMPILE_OPTIONS_DOCS = """
2147+
- `optimize`: Optimizations passes to run on the traced MLIR code. Valid types of values
2148+
are:
2149+
- Bool (true/false): whether to run the optimization passes or not. Defaults to `true`.
2150+
- String: a custom string with the passes to run. The string should be a comma-separated
2151+
list of MLIR passes. For example, `"canonicalize,enzyme-hlo-opt"`.
2152+
- Symbol: a predefined set of passes to run. Valid options are:
2153+
1. `:all`: Default set of optimization passes. The exact set of passes are not fixed
2154+
and may change in future versions of Reactant. It is recommended to use this
2155+
option for most users.
2156+
2. `:none`: No optimization passes will be run.
2157+
3. Other predefined options are: `:before_kernel`, `:before_jit`, `:before_raise`,
2158+
`:before_enzyme`, `:after_enzyme`, `:just_batch`, `:canonicalize`, `:only_enzyme`.
2159+
- `no_nan`: If `true`, the optimization passes will assume that the function does not
2160+
produce NaN values. This can lead to more aggressive optimizations **(and potentially
2161+
incorrect results if the function does produce NaN values)**.
2162+
- `client`: XLA Client used for compilation. If not specified, the default client is used.
2163+
- `raise`: If `true`, the function will be compiled with the raising pass, which raises
2164+
CUDA and KernelAbstractions kernels to HLO. Defaults to `false`, but is automatically
2165+
activated if the inputs are sharded.
2166+
- `raise_first`: If `true`, the raising pass will be run before the optimization passes.
2167+
Defaults to `false`.
2168+
- `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
2169+
- `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
2170+
- `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
2171+
handled by XLA.
2172+
- [`Sharding.ShardyPropagationOptions`](@ref): Custom sharding propagation options.
2173+
MHLO shardings are handled by XLA.
2174+
- `:to_mhlo_shardings`: Runs the Shardy propagation passes and then exports the
2175+
shardings to MHLO. All passes are run via MLIR pass pipeline and don't involve XLA.
2176+
- `assert_nonallocating`: If `true`, we make sure that no new buffers are
2177+
returned by the function. Any buffer returned must be donated from the inputs. Defaults
2178+
to `false`.
2179+
- `donated_args`: If `:auto`, the function will automatically donate the arguments that
2180+
are not preserved in the function body. If `:none`, no arguments will be donated.
2181+
Defaults to `:auto`.
2182+
- `transpose_propagate`: If `:up`, `stablehlo.transpose` operations will be
2183+
propagated up the computation graph. If `:down`, they will be propagated down. Defaults
2184+
to `:up`.
2185+
- `reshape_propagate`: If `:up`, `stablehlo.reshape` operations will be propagated up
2186+
the computation graph. If `:down`, they will be propagated down. Defaults to `:up`.
2187+
- `optimize_then_pad`: If `true`, the function will be optimized before padding (for
2188+
non-divisible sharding axes) is applied. Defaults to `true`. _(Only for Sharded Inputs)_
2189+
- `optimize_communications`: If `true`, additional passes for optimizing communication
2190+
in sharded computations will be run. Defaults to `true`. _(Only for Sharded Inputs)_
2191+
- `cudnn_hlo_optimize`: Run cuDNN specific HLO optimizations. This is only relevant for
2192+
GPU backends and is `false` by default. **Experimental and not heavily tested.**
2193+
_(Only for CUDA backend)_
2194+
"""
2195+
2196+
const SYNC_DOCS = """
2197+
- `sync`: Reactant computations are asynchronous by default. If `true`, the computation
2198+
will be executed synchronously, blocking till the computation is complete. This is
2199+
recommended when benchmarking.
2200+
"""
2201+
21302202
"""
21312203
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
21322204
2205+
Prints the compiled MLIR module for the function `f` with arguments `args`.
2206+
2207+
## Options
2208+
2209+
$(COMMON_COMPILE_OPTIONS_DOCS)
2210+
21332211
See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
21342212
"""
21352213
macro code_hlo(args...)
2136-
default_options = Dict{Symbol,Any}(
2137-
:optimize => true,
2138-
:no_nan => false,
2139-
:client => nothing,
2140-
:raise => false,
2141-
:raise_first => false,
2142-
:shardy_passes => :(:to_mhlo_shardings),
2143-
:assert_nonallocating => false,
2144-
:donated_args => :(:auto),
2145-
:transpose_propagate => :(:up),
2146-
:reshape_propagate => :(:up),
2147-
:optimize_then_pad => true,
2148-
:optimize_communications => true,
2149-
:cudnn_hlo_optimize => false,
2150-
)
21512214
compile_expr, (; compiled) = compile_call_expr(
2152-
__module__, compile_mlir, default_options, args...
2215+
__module__, compile_mlir, COMMON_COMPILE_OPTIONS, args...
21532216
)
21542217
#! format: off
21552218
return esc(
@@ -2164,28 +2227,17 @@ end
21642227
"""
21652228
@code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)
21662229
2167-
Similar to `@code_hlo`, but prints the module after running the XLA compiler.
2230+
Similar to `@code_hlo`, but runs additional passes to export the stablehlo module to MHLO.
2231+
2232+
## Options
2233+
2234+
$(COMMON_COMPILE_OPTIONS_DOCS)
21682235
21692236
See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
21702237
"""
21712238
macro code_mhlo(args...)
2172-
default_options = Dict{Symbol,Any}(
2173-
:optimize => true,
2174-
:no_nan => false,
2175-
:client => nothing,
2176-
:raise => false,
2177-
:raise_first => false,
2178-
:shardy_passes => :(:to_mhlo_shardings),
2179-
:assert_nonallocating => false,
2180-
:donated_args => :(:auto),
2181-
:transpose_propagate => :(:up),
2182-
:reshape_propagate => :(:up),
2183-
:optimize_then_pad => true,
2184-
:optimize_communications => true,
2185-
:cudnn_hlo_optimize => false,
2186-
)
21872239
compile_expr, (; compiled) = compile_call_expr(
2188-
__module__, compile_xla, default_options, args...
2240+
__module__, compile_xla, COMMON_COMPILE_OPTIONS, args...
21892241
)
21902242
#! format: off
21912243
return esc(
@@ -2200,28 +2252,18 @@ end
22002252
"""
22012253
@code_xla [optimize = ...] [no_nan = <true/false>] f(args...)
22022254
2203-
Similar to `@code_hlo`, but prints the HLO module.
2255+
Similar to [`@code_hlo`](@ref), but runs additional XLA passes and exports MLIR to XLA HLO.
2256+
This is the post optimizations XLA HLO module.
2257+
2258+
## Options
2259+
2260+
$(COMMON_COMPILE_OPTIONS_DOCS)
22042261
22052262
See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
22062263
"""
22072264
macro code_xla(args...)
2208-
default_options = Dict{Symbol,Any}(
2209-
:optimize => true,
2210-
:no_nan => false,
2211-
:client => nothing,
2212-
:raise => false,
2213-
:raise_first => false,
2214-
:shardy_passes => :(:to_mhlo_shardings),
2215-
:assert_nonallocating => false,
2216-
:donated_args => :(:auto),
2217-
:transpose_propagate => :(:up),
2218-
:reshape_propagate => :(:up),
2219-
:optimize_then_pad => true,
2220-
:optimize_communications => true,
2221-
:cudnn_hlo_optimize => false,
2222-
)
22232265
compile_expr, (; compiled) = compile_call_expr(
2224-
__module__, compile_xla, default_options, args...
2266+
__module__, compile_xla, COMMON_COMPILE_OPTIONS, args...
22252267
)
22262268
#! format: off
22272269
return esc(
@@ -2237,50 +2279,36 @@ end
22372279

22382280
"""
22392281
@compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
2282+
2283+
Compile the function `f` with arguments `args` and return the compiled function.
2284+
2285+
## Options
2286+
2287+
$(COMMON_COMPILE_OPTIONS_DOCS)
2288+
$(SYNC_DOCS)
2289+
2290+
See also [`@jit`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
22402291
"""
22412292
macro compile(args...)
2242-
default_options = Dict{Symbol,Any}(
2243-
:optimize => true,
2244-
:sync => false,
2245-
:no_nan => false,
2246-
:client => nothing,
2247-
:raise => false,
2248-
:raise_first => false,
2249-
:shardy_passes => :(:to_mhlo_shardings),
2250-
:assert_nonallocating => false,
2251-
:serializable => false,
2252-
:donated_args => :(:auto),
2253-
:transpose_propagate => :(:up),
2254-
:reshape_propagate => :(:up),
2255-
:optimize_then_pad => true,
2256-
:optimize_communications => true,
2257-
:cudnn_hlo_optimize => false,
2258-
)
2293+
default_options = merge(COMMON_COMPILE_OPTIONS, Dict{Symbol,Any}(:sync => false))
22592294
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
22602295
end
22612296

22622297
"""
22632298
@jit [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
22642299
2265-
Run @compile f(args..) then immediately execute it
2300+
Run @compile f(args..) then immediately execute it. Most users should use [`@compile`](@ref)
2301+
instead to cache the compiled function and execute it later.
2302+
2303+
## Options
2304+
2305+
$(COMMON_COMPILE_OPTIONS_DOCS)
2306+
$(SYNC_DOCS)
2307+
2308+
See also [`@compile`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
22662309
"""
22672310
macro jit(args...)
2268-
default_options = Dict{Symbol,Any}(
2269-
:optimize => true,
2270-
:sync => false,
2271-
:no_nan => false,
2272-
:client => nothing,
2273-
:raise => false,
2274-
:raise_first => false,
2275-
:shardy_passes => :(:to_mhlo_shardings),
2276-
:assert_nonallocating => false,
2277-
:donated_args => :(:auto),
2278-
:transpose_propagate => :(:up),
2279-
:reshape_propagate => :(:up),
2280-
:optimize_then_pad => true,
2281-
:optimize_communications => true,
2282-
:cudnn_hlo_optimize => false,
2283-
)
2311+
default_options = merge(COMMON_COMPILE_OPTIONS, Dict{Symbol,Any}(:sync => false))
22842312
compile_expr, (; compiled, args) = compile_call_expr(
22852313
__module__, compile, default_options, args...
22862314
)

src/Sharding.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,26 @@ end
11131113
"""
11141114
ShardyPropagationOptions
11151115
1116-
Fine-grained control over the sharding propagation pipeline.
1116+
Fine-grained control over the sharding propagation pipeline. For more information on
1117+
sharding propagation, see the
1118+
[Shardy Docs](https://openxla.org/shardy/sdy_propagation_passes).
1119+
1120+
## Options
1121+
1122+
- `keep_sharding_rules::Bool`: whether to keep existing and created op sharding rules.
1123+
- `conservative_propagation::Bool`: whether to disallow split axes and non-divisible
1124+
sharding axes during propagation.
1125+
- `debug_sharding_origins::Bool`: whether to save information about the origin of a
1126+
sharding on the MLIR module. These would be the shardings on the function inputs,
1127+
outputs, sharding constraints and manual computations before propagation.
1128+
- `debug_propagation_edge_sharding::Bool`: whether to save information about the edge
1129+
source of a sharding on the MLIR module. These are what operand/result introduced a
1130+
sharding on some op result.
1131+
- `skip_convert_to_reshard::Bool`
1132+
- `skip_inline::Bool`
1133+
- `enable_insert_explicit_collectives::Bool`: whether to insert explicit collectives
1134+
for sharding propagation. This is useful for debugging and checking the location of
1135+
the communication ops.
11171136
"""
11181137
@kwdef struct ShardyPropagationOptions
11191138
keep_sharding_rules::Bool = false

0 commit comments

Comments
 (0)