@@ -2127,29 +2127,92 @@ function compile_mlir!(
2127
2127
)
2128
2128
end
2129
2129
2130
+ const COMMON_COMPILE_OPTIONS = Dict {Symbol,Any} (
2131
+ :optimize => true ,
2132
+ :no_nan => false ,
2133
+ :client => nothing ,
2134
+ :raise => false ,
2135
+ :raise_first => false ,
2136
+ :shardy_passes => :(:to_mhlo_shardings ),
2137
+ :assert_nonallocating => false ,
2138
+ :donated_args => :(:auto ),
2139
+ :transpose_propagate => :(:up ),
2140
+ :reshape_propagate => :(:up ),
2141
+ :optimize_then_pad => true ,
2142
+ :optimize_communications => true ,
2143
+ :cudnn_hlo_optimize => false ,
2144
+ )
2145
+
2146
+ const COMMON_COMPILE_OPTIONS_DOCS = """
2147
+ - `optimize`: Optimizations passes to run on the traced MLIR code. Valid types of values
2148
+ are:
2149
+ - Bool (true/false): whether to run the optimization passes or not. Defaults to `true`.
2150
+ - String: a custom string with the passes to run. The string should be a comma-separated
2151
+ list of MLIR passes. For example, `"canonicalize,enzyme-hlo-opt"`.
2152
+ - Symbol: a predefined set of passes to run. Valid options are:
2153
+ 1. `:all`: Default set of optimization passes. The exact set of passes are not fixed
2154
+ and may change in future versions of Reactant. It is recommended to use this
2155
+ option for most users.
2156
+ 2. `:none`: No optimization passes will be run.
2157
+ 3. Other predefined options are: `:before_kernel`, `:before_jit`, `:before_raise`,
2158
+ `:before_enzyme`, `:after_enzyme`, `:just_batch`, `:canonicalize`, `:only_enzyme`.
2159
+ - `no_nan`: If `true`, the optimization passes will assume that the function does not
2160
+ produce NaN values. This can lead to more aggressive optimizations **(and potentially
2161
+ incorrect results if the function does produce NaN values)**.
2162
+ - `client`: XLA Client used for compilation. If not specified, the default client is used.
2163
+ - `raise`: If `true`, the function will be compiled with the raising pass, which raises
2164
+ CUDA and KernelAbstractions kernels to HLO. Defaults to `false`, but is automatically
2165
+ activated if the inputs are sharded.
2166
+ - `raise_first`: If `true`, the raising pass will be run before the optimization passes.
2167
+ Defaults to `false`.
2168
+ - `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
2169
+ - `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
2170
+ - `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
2171
+ handled by XLA.
2172
+ - [`Sharding.ShardyPropagationOptions`](@ref): Custom sharding propagation options.
2173
+ MHLO shardings are handled by XLA.
2174
+ - `:to_mhlo_shardings`: Runs the Shardy propagation passes and then exports the
2175
+ shardings to MHLO. All passes are run via MLIR pass pipeline and don't involve XLA.
2176
+ - `assert_nonallocating`: If `true`, we make sure that no new buffers are
2177
+ returned by the function. Any buffer returned must be donated from the inputs. Defaults
2178
+ to `false`.
2179
+ - `donated_args`: If `:auto`, the function will automatically donate the arguments that
2180
+ are not preserved in the function body. If `:none`, no arguments will be donated.
2181
+ Defaults to `:auto`.
2182
+ - `transpose_propagate`: If `:up`, `stablehlo.transpose` operations will be
2183
+ propagated up the computation graph. If `:down`, they will be propagated down. Defaults
2184
+ to `:up`.
2185
+ - `reshape_propagate`: If `:up`, `stablehlo.reshape` operations will be propagated up
2186
+ the computation graph. If `:down`, they will be propagated down. Defaults to `:up`.
2187
+ - `optimize_then_pad`: If `true`, the function will be optimized before padding (for
2188
+ non-divisible sharding axes) is applied. Defaults to `true`. _(Only for Sharded Inputs)_
2189
+ - `optimize_communications`: If `true`, additional passes for optimizing communication
2190
+ in sharded computations will be run. Defaults to `true`. _(Only for Sharded Inputs)_
2191
+ - `cudnn_hlo_optimize`: Run cuDNN specific HLO optimizations. This is only relevant for
2192
+ GPU backends and is `false` by default. **Experimental and not heavily tested.**
2193
+ _(Only for CUDA backend)_
2194
+ """
2195
+
2196
+ const SYNC_DOCS = """
2197
+ - `sync`: Reactant computations are asynchronous by default. If `true`, the computation
2198
+ will be executed synchronously, blocking till the computation is complete. This is
2199
+ recommended when benchmarking.
2200
+ """
2201
+
2130
2202
"""
2131
2203
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
2132
2204
2205
+ Prints the compiled MLIR module for the function `f` with arguments `args`.
2206
+
2207
+ ## Options
2208
+
2209
+ $(COMMON_COMPILE_OPTIONS_DOCS)
2210
+
2133
2211
See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
2134
2212
"""
2135
2213
macro code_hlo (args... )
2136
- default_options = Dict {Symbol,Any} (
2137
- :optimize => true ,
2138
- :no_nan => false ,
2139
- :client => nothing ,
2140
- :raise => false ,
2141
- :raise_first => false ,
2142
- :shardy_passes => :(:to_mhlo_shardings ),
2143
- :assert_nonallocating => false ,
2144
- :donated_args => :(:auto ),
2145
- :transpose_propagate => :(:up ),
2146
- :reshape_propagate => :(:up ),
2147
- :optimize_then_pad => true ,
2148
- :optimize_communications => true ,
2149
- :cudnn_hlo_optimize => false ,
2150
- )
2151
2214
compile_expr, (; compiled) = compile_call_expr (
2152
- __module__, compile_mlir, default_options , args...
2215
+ __module__, compile_mlir, COMMON_COMPILE_OPTIONS , args...
2153
2216
)
2154
2217
# ! format: off
2155
2218
return esc (
@@ -2164,28 +2227,17 @@ end
2164
2227
"""
2165
2228
@code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)
2166
2229
2167
- Similar to `@code_hlo`, but prints the module after running the XLA compiler.
2230
+ Similar to `@code_hlo`, but runs additional passes to export the stablehlo module to MHLO.
2231
+
2232
+ ## Options
2233
+
2234
+ $(COMMON_COMPILE_OPTIONS_DOCS)
2168
2235
2169
2236
See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
2170
2237
"""
2171
2238
macro code_mhlo (args... )
2172
- default_options = Dict {Symbol,Any} (
2173
- :optimize => true ,
2174
- :no_nan => false ,
2175
- :client => nothing ,
2176
- :raise => false ,
2177
- :raise_first => false ,
2178
- :shardy_passes => :(:to_mhlo_shardings ),
2179
- :assert_nonallocating => false ,
2180
- :donated_args => :(:auto ),
2181
- :transpose_propagate => :(:up ),
2182
- :reshape_propagate => :(:up ),
2183
- :optimize_then_pad => true ,
2184
- :optimize_communications => true ,
2185
- :cudnn_hlo_optimize => false ,
2186
- )
2187
2239
compile_expr, (; compiled) = compile_call_expr (
2188
- __module__, compile_xla, default_options , args...
2240
+ __module__, compile_xla, COMMON_COMPILE_OPTIONS , args...
2189
2241
)
2190
2242
# ! format: off
2191
2243
return esc (
@@ -2200,28 +2252,18 @@ end
2200
2252
"""
2201
2253
@code_xla [optimize = ...] [no_nan = <true/false>] f(args...)
2202
2254
2203
- Similar to `@code_hlo`, but prints the HLO module.
2255
+ Similar to [`@code_hlo`](@ref), but runs additional XLA passes and exports MLIR to XLA HLO.
2256
+ This is the post optimizations XLA HLO module.
2257
+
2258
+ ## Options
2259
+
2260
+ $(COMMON_COMPILE_OPTIONS_DOCS)
2204
2261
2205
2262
See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
2206
2263
"""
2207
2264
macro code_xla (args... )
2208
- default_options = Dict {Symbol,Any} (
2209
- :optimize => true ,
2210
- :no_nan => false ,
2211
- :client => nothing ,
2212
- :raise => false ,
2213
- :raise_first => false ,
2214
- :shardy_passes => :(:to_mhlo_shardings ),
2215
- :assert_nonallocating => false ,
2216
- :donated_args => :(:auto ),
2217
- :transpose_propagate => :(:up ),
2218
- :reshape_propagate => :(:up ),
2219
- :optimize_then_pad => true ,
2220
- :optimize_communications => true ,
2221
- :cudnn_hlo_optimize => false ,
2222
- )
2223
2265
compile_expr, (; compiled) = compile_call_expr (
2224
- __module__, compile_xla, default_options , args...
2266
+ __module__, compile_xla, COMMON_COMPILE_OPTIONS , args...
2225
2267
)
2226
2268
# ! format: off
2227
2269
return esc (
@@ -2237,50 +2279,36 @@ end
2237
2279
2238
2280
"""
2239
2281
@compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
2282
+
2283
+ Compile the function `f` with arguments `args` and return the compiled function.
2284
+
2285
+ ## Options
2286
+
2287
+ $(COMMON_COMPILE_OPTIONS_DOCS)
2288
+ $(SYNC_DOCS)
2289
+
2290
+ See also [`@jit`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
2240
2291
"""
2241
2292
macro compile (args... )
2242
- default_options = Dict {Symbol,Any} (
2243
- :optimize => true ,
2244
- :sync => false ,
2245
- :no_nan => false ,
2246
- :client => nothing ,
2247
- :raise => false ,
2248
- :raise_first => false ,
2249
- :shardy_passes => :(:to_mhlo_shardings ),
2250
- :assert_nonallocating => false ,
2251
- :serializable => false ,
2252
- :donated_args => :(:auto ),
2253
- :transpose_propagate => :(:up ),
2254
- :reshape_propagate => :(:up ),
2255
- :optimize_then_pad => true ,
2256
- :optimize_communications => true ,
2257
- :cudnn_hlo_optimize => false ,
2258
- )
2293
+ default_options = merge (COMMON_COMPILE_OPTIONS, Dict {Symbol,Any} (:sync => false ))
2259
2294
return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
2260
2295
end
2261
2296
2262
2297
"""
2263
2298
@jit [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
2264
2299
2265
- Run @compile f(args..) then immediately execute it
2300
+ Run @compile f(args..) then immediately execute it. Most users should use [`@compile`](@ref)
2301
+ instead to cache the compiled function and execute it later.
2302
+
2303
+ ## Options
2304
+
2305
+ $(COMMON_COMPILE_OPTIONS_DOCS)
2306
+ $(SYNC_DOCS)
2307
+
2308
+ See also [`@compile`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
2266
2309
"""
2267
2310
macro jit (args... )
2268
- default_options = Dict {Symbol,Any} (
2269
- :optimize => true ,
2270
- :sync => false ,
2271
- :no_nan => false ,
2272
- :client => nothing ,
2273
- :raise => false ,
2274
- :raise_first => false ,
2275
- :shardy_passes => :(:to_mhlo_shardings ),
2276
- :assert_nonallocating => false ,
2277
- :donated_args => :(:auto ),
2278
- :transpose_propagate => :(:up ),
2279
- :reshape_propagate => :(:up ),
2280
- :optimize_then_pad => true ,
2281
- :optimize_communications => true ,
2282
- :cudnn_hlo_optimize => false ,
2283
- )
2311
+ default_options = merge (COMMON_COMPILE_OPTIONS, Dict {Symbol,Any} (:sync => false ))
2284
2312
compile_expr, (; compiled, args) = compile_call_expr (
2285
2313
__module__, compile, default_options, args...
2286
2314
)
0 commit comments