1
- # TODO : make the other optimize options into a struct as well
1
+ # TODO : document these options at some point
2
+ """
3
+ OptimizeCommunicationOptions
4
+
5
+ Fine-grained control over the optimization passes that rewrite ops to minimize collective
6
+ communication.
7
+ """
2
8
@kwdef struct OptimizeCommunicationOptions
3
9
periodic_concat:: Int = 0
4
10
rotate_comm:: Int = 0
@@ -20,3 +26,285 @@ function Base.String(options::OptimizeCommunicationOptions)
20
26
" }"
21
27
)
22
28
end
29
+
30
+ """
31
+ ShardyPropagationOptions
32
+
33
+ Fine-grained control over the sharding propagation pipeline. For more information on
34
+ sharding propagation, see the
35
+ [Shardy Docs](https://openxla.org/shardy/sdy_propagation_passes).
36
+
37
+ ## Options
38
+
39
+ - `keep_sharding_rules::Bool`: whether to keep existing and created op sharding rules.
40
+ - `conservative_propagation::Bool`: whether to disallow split axes and non-divisible
41
+ sharding axes during propagation.
42
+ - `debug_sharding_origins::Bool`: whether to save information about the origin of a
43
+ sharding on the MLIR module. These would be the shardings on the function inputs,
44
+ outputs, sharding constraints and manual computations before propagation.
45
+ - `debug_propagation_edge_sharding::Bool`: whether to save information about the edge
46
+ source of a sharding on the MLIR module. These are what operand/result introduced a
47
+ sharding on some op result.
48
+ - `skip_convert_to_reshard::Bool`
49
+ - `skip_inline::Bool`
50
+ - `enable_insert_explicit_collectives::Bool`: whether to insert explicit collectives
51
+ for sharding propagation. This is useful for debugging and checking the location of
52
+ the communication ops.
53
+ """
54
+ @kwdef struct ShardyPropagationOptions
55
+ keep_sharding_rules:: Bool = false
56
+ conservative_propagation:: Bool = false
57
+ debug_sharding_origins:: Bool = false
58
+ debug_propagation_edge_sharding:: Bool = false
59
+ skip_convert_to_reshard:: Bool = false
60
+ skip_inline:: Bool = false
61
+ enable_insert_explicit_collectives:: Bool = false
62
+ end
63
+
64
+ """
65
+ CompileOptions
66
+
67
+ Fine-grained control over the compilation options for the Reactant compiler.
68
+
69
+ ## Controlling Optimization Passes
70
+
71
+ - `optimization_passes`: Optimizations passes to run on the traced MLIR code. Valid types
72
+ of values are:
73
+ - Bool (true/false): whether to run the optimization passes or not. Defaults to `true`.
74
+ - String: a custom string with the passes to run. The string should be a comma-separated
75
+ list of MLIR passes. For example, `"canonicalize,enzyme-hlo-opt"`.
76
+ - Symbol: a predefined set of passes to run. Valid options are:
77
+ 1. `:all`: Default set of optimization passes. The exact set of passes are not fixed
78
+ and may change in future versions of Reactant. It is recommended to use this
79
+ option for most users.
80
+ 2. `:none`: No optimization passes will be run.
81
+ 3. Other predefined options are: `:before_kernel`, `:before_jit`, `:before_raise`,
82
+ `:before_enzyme`, `:after_enzyme`, `:just_batch`, `:canonicalize`, `:only_enzyme`.
83
+ - `no_nan`: If `true`, the optimization passes will assume that the function does not
84
+ produce NaN values. This can lead to more aggressive optimizations **(and potentially
85
+ incorrect results if the function does produce NaN values)**.
86
+ - `all_finite`: If `true`, the optimization passes will assume that the function does not
87
+ produce Inf or -Inf values. This can lead to more aggressive optimizations **(and
88
+ potentially incorrect results if the function does produce Inf or -Inf values)**.
89
+ - `transpose_propagate`: If `:up`, `stablehlo.transpose` operations will be
90
+ propagated up the computation graph. If `:down`, they will be propagated down. Defaults
91
+ to `:up`.
92
+ - `reshape_propagate`: If `:up`, `stablehlo.reshape` operations will be propagated up
93
+ the computation graph. If `:down`, they will be propagated down. Defaults to `:up`.
94
+ - `max_constant_threshold`: If the number of elements in a constant is greater than this
95
+ threshold (for a non-splatted constant), we will throw an error.
96
+ - `inline`: If `true`, all functions will be inlined. This is `true` by default.
97
+
98
+ ## Raising Options
99
+
100
+ - `raise`: If `true`, the function will be compiled with the raising pass, which raises
101
+ CUDA and KernelAbstractions kernels to HLO. Defaults to `false`, but is automatically
102
+ activated if the inputs are sharded.
103
+ - `raise_first`: If `true`, the raising pass will be run before the optimization passes.
104
+ Defaults to `false`.
105
+
106
+ ## Dialect Specific Options
107
+
108
+ - `legalize_chlo_to_stablehlo`: If `true`, `chlo` dialect ops will be converted to
109
+ `stablehlo` ops. This is `false` by default.
110
+
111
+ ## Backend Specific Options
112
+
113
+ ### Only for CUDA backend
114
+
115
+ - `cudnn_hlo_optimize`: Run cuDNN specific HLO optimizations. This is only relevant for
116
+ GPU backends and is `false` by default. **Experimental and not heavily tested.**
117
+
118
+ ## Sharding Options
119
+
120
+ - `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
121
+ - `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
122
+ - `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
123
+ handled by XLA.
124
+ - [`ShardyPropagationOptions`](@ref): Custom sharding propagation options.
125
+ MHLO shardings are handled by XLA.
126
+ - `:to_mhlo_shardings`: Runs the Shardy propagation passes and then exports the
127
+ shardings to MHLO. All passes are run via MLIR pass pipeline and don't involve XLA.
128
+ - `optimize_then_pad`: If `true`, the function will be optimized before padding (for
129
+ non-divisible sharding axes) is applied. Defaults to `true`. _(Only for Sharded Inputs)_
130
+ - `optimize_communications`: If `true`, additional passes for optimizing communication
131
+ in sharded computations will be run. Defaults to `true`. _(Only for Sharded Inputs)_
132
+
133
+ ## Julia Codegen Options
134
+
135
+ - `donated_args`: If `:auto`, the function will automatically donate the arguments that
136
+ are not preserved in the function body. If `:none`, no arguments will be donated.
137
+ Defaults to `:auto`.
138
+ - `assert_nonallocating`: If `true`, we make sure that no new buffers are
139
+ returned by the function. Any buffer returned must be donated from the inputs. Defaults
140
+ to `false`.
141
+ - `sync`: Reactant computations are asynchronous by default. If `true`, the computation
142
+ will be executed synchronously, blocking till the computation is complete. This is
143
+ recommended when benchmarking.
144
+
145
+ # Extended Help
146
+
147
+ ## Private Options
148
+
149
+ !!! warning
150
+
151
+ These options are not part of the public API and are subject to change without any
152
+ notice or deprecation cycle.
153
+
154
+ - `disable_scatter_gather_optimization_passes`: Disables the scatter-gather
155
+ optimization passes. This is `false` by default.
156
+ - `disable_pad_optimization_passes`: Disables the pad optimization passes. This is
157
+ `false` by default.
158
+ """
159
+ struct CompileOptions
160
+ optimization_passes:: Union{Symbol,String}
161
+ no_nan:: Bool
162
+ all_finite:: Bool
163
+ inline:: Bool
164
+ transpose_propagate:: Symbol
165
+ reshape_propagate:: Symbol
166
+ max_constant_threshold:: Int
167
+ # Raising options
168
+ raise:: Union{Bool,String}
169
+ raise_first:: Bool
170
+ # dialect specific options
171
+ legalize_chlo_to_stablehlo:: Bool
172
+ # backend specific options
173
+ cudnn_hlo_optimize:: Bool
174
+ # sharding options
175
+ shardy_passes:: Union{Symbol,ShardyPropagationOptions}
176
+ optimize_then_pad:: Bool
177
+ optimize_communications:: Union{Bool,OptimizeCommunicationOptions}
178
+ # julia codegen options
179
+ assert_nonallocating:: Bool
180
+ donated_args:: Symbol
181
+ sync:: Bool
182
+ # # private options for ablation studies
183
+ disable_scatter_gather_optimization_passes:: Bool
184
+ disable_pad_optimization_passes:: Bool
185
+ end
186
+
187
+ function CompileOptions (;
188
+ optimization_passes:: Union{Bool,Symbol,String} = :all ,
189
+ no_nan:: Bool = false ,
190
+ all_finite:: Bool = false ,
191
+ inline:: Bool = true ,
192
+ transpose_propagate:: Symbol = :up ,
193
+ reshape_propagate:: Symbol = :up ,
194
+ max_constant_threshold:: Int = 1024 ,
195
+ raise:: Union{Bool,String} = false ,
196
+ raise_first:: Bool = false ,
197
+ legalize_chlo_to_stablehlo:: Bool = false ,
198
+ cudnn_hlo_optimize:: Bool = false ,
199
+ shardy_passes:: Union{Symbol,ShardyPropagationOptions} = :to_mhlo_shardings ,
200
+ optimize_then_pad:: Bool = true ,
201
+ optimize_communications:: Union{Bool,OptimizeCommunicationOptions} = true ,
202
+ assert_nonallocating:: Bool = false ,
203
+ donated_args:: Symbol = :auto ,
204
+ sync:: Bool = false ,
205
+ disable_scatter_gather_optimization_passes:: Bool = false ,
206
+ disable_pad_optimization_passes:: Bool = false ,
207
+ )
208
+ optimization_passes isa Bool &&
209
+ (optimization_passes = ifelse (optimization_passes, :all , :none ))
210
+
211
+ if optimization_passes isa Symbol
212
+ @assert optimization_passes in [
213
+ :all ,
214
+ :before_kernel ,
215
+ :before_jit ,
216
+ :before_raise ,
217
+ :no_enzyme ,
218
+ :only_enzyme ,
219
+ :after_enzyme ,
220
+ :before_enzyme ,
221
+ :canonicalize ,
222
+ :just_batch ,
223
+ :none ,
224
+ ]
225
+ end
226
+
227
+ @assert transpose_propagate in [:up , :down , :none ]
228
+ @assert reshape_propagate in [:up , :down , :none ]
229
+
230
+ if shardy_passes isa Symbol
231
+ @assert shardy_passes in [:none , :to_mhlo_shardings , :post_sdy_propagation ]
232
+ end
233
+
234
+ return CompileOptions (
235
+ optimization_passes,
236
+ no_nan,
237
+ all_finite,
238
+ inline,
239
+ transpose_propagate,
240
+ reshape_propagate,
241
+ max_constant_threshold,
242
+ raise,
243
+ raise_first,
244
+ legalize_chlo_to_stablehlo,
245
+ cudnn_hlo_optimize,
246
+ shardy_passes,
247
+ optimize_then_pad,
248
+ optimize_communications,
249
+ assert_nonallocating,
250
+ donated_args,
251
+ sync,
252
+ disable_scatter_gather_optimization_passes,
253
+ disable_pad_optimization_passes,
254
+ )
255
+ end
256
+
257
+ function __compile_options_from_kwags (;
258
+ compile_options:: Union{Missing,CompileOptions} = missing ,
259
+ optimize:: Union{Bool,Symbol,String} = true ,
260
+ kwargs... ,
261
+ )
262
+ compile_options isa CompileOptions && return compile_options
263
+ return CompileOptions (; optimization_passes= optimize, kwargs... )
264
+ end
265
+
266
+ function __reverse_propagation (sym:: Symbol )
267
+ sym == :up && return :down
268
+ sym === :down && return :up
269
+ sym == :none && return :none
270
+ return error (" Invalid value: $sym . Expected :up or :down or :none" )
271
+ end
272
+
273
+ function __compile_options_with_reversed_propagation (compile_options:: CompileOptions )
274
+ return CompileOptions (
275
+ compile_options. optimization_passes,
276
+ compile_options. no_nan,
277
+ compile_options. all_finite,
278
+ compile_options. inline,
279
+ __reverse_propagation (compile_options. transpose_propagate),
280
+ __reverse_propagation (compile_options. reshape_propagate),
281
+ compile_options. max_constant_threshold,
282
+ compile_options. raise,
283
+ compile_options. raise_first,
284
+ compile_options. legalize_chlo_to_stablehlo,
285
+ compile_options. cudnn_hlo_optimize,
286
+ compile_options. shardy_passes,
287
+ compile_options. optimize_then_pad,
288
+ compile_options. optimize_communications,
289
+ compile_options. assert_nonallocating,
290
+ compile_options. donated_args,
291
+ compile_options. sync,
292
+ compile_options. disable_scatter_gather_optimization_passes,
293
+ compile_options. disable_pad_optimization_passes,
294
+ )
295
+ end
296
+
297
+ """
298
+ DefaultXLACompileOptions()
299
+
300
+ Runs specific Enzyme-JAX passes to ensure that the generated code is compatible with
301
+ XLA compilation.
302
+
303
+ !!! warning
304
+
305
+ This is mostly a benchmarking option, and the default [`CompileOptions`](@ref) is almost
306
+ certainly a better option.
307
+ """
308
+ function DefaultXLACompileOptions ()
309
+ return CompileOptions (; optimization_passes= :only_enzyme , inline= false )
310
+ end
0 commit comments