@@ -2127,21 +2127,23 @@ function compile_mlir!(
2127
2127
)
2128
2128
end
2129
2129
2130
- const COMMON_COMPILE_OPTIONS = Dict {Symbol,Any} (
2131
- :optimize => true ,
2132
- :no_nan => false ,
2133
- :client => nothing ,
2134
- :raise => false ,
2135
- :raise_first => false ,
2136
- :shardy_passes => :(:to_mhlo_shardings ),
2137
- :assert_nonallocating => false ,
2138
- :donated_args => :(:auto ),
2139
- :transpose_propagate => :(:up ),
2140
- :reshape_propagate => :(:up ),
2141
- :optimize_then_pad => true ,
2142
- :optimize_communications => true ,
2143
- :cudnn_hlo_optimize => false ,
2144
- )
2130
+ function get_common_compile_options ()
2131
+ return Dict {Symbol,Any} (
2132
+ :optimize => true ,
2133
+ :no_nan => false ,
2134
+ :client => nothing ,
2135
+ :raise => false ,
2136
+ :raise_first => false ,
2137
+ :shardy_passes => :(:to_mhlo_shardings ),
2138
+ :assert_nonallocating => false ,
2139
+ :donated_args => :(:auto ),
2140
+ :transpose_propagate => :(:up ),
2141
+ :reshape_propagate => :(:up ),
2142
+ :optimize_then_pad => true ,
2143
+ :optimize_communications => true ,
2144
+ :cudnn_hlo_optimize => false ,
2145
+ )
2146
+ end
2145
2147
2146
2148
const COMMON_COMPILE_OPTIONS_DOCS = """
2147
2149
- `optimize`: Optimizations passes to run on the traced MLIR code. Valid types of values
@@ -2212,7 +2214,7 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
2212
2214
"""
2213
2215
macro code_hlo (args... )
2214
2216
compile_expr, (; compiled) = compile_call_expr (
2215
- __module__, compile_mlir, COMMON_COMPILE_OPTIONS , args...
2217
+ __module__, compile_mlir, get_common_compile_options () , args...
2216
2218
)
2217
2219
# ! format: off
2218
2220
return esc (
@@ -2237,7 +2239,7 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
2237
2239
"""
2238
2240
macro code_mhlo (args... )
2239
2241
compile_expr, (; compiled) = compile_call_expr (
2240
- __module__, compile_xla, COMMON_COMPILE_OPTIONS , args...
2242
+ __module__, compile_xla, get_common_compile_options () , args...
2241
2243
)
2242
2244
# ! format: off
2243
2245
return esc (
@@ -2263,7 +2265,7 @@ See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
2263
2265
"""
2264
2266
macro code_xla (args... )
2265
2267
compile_expr, (; compiled) = compile_call_expr (
2266
- __module__, compile_xla, COMMON_COMPILE_OPTIONS , args...
2268
+ __module__, compile_xla, get_common_compile_options () , args...
2267
2269
)
2268
2270
# ! format: off
2269
2271
return esc (
@@ -2290,7 +2292,7 @@ $(SYNC_DOCS)
2290
2292
See also [`@jit`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
2291
2293
"""
2292
2294
macro compile (args... )
2293
- default_options = merge (COMMON_COMPILE_OPTIONS , Dict {Symbol,Any} (:sync => false ))
2295
+ default_options = merge (get_common_compile_options () , Dict {Symbol,Any} (:sync => false ))
2294
2296
return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
2295
2297
end
2296
2298
@@ -2308,7 +2310,7 @@ $(SYNC_DOCS)
2308
2310
See also [`@compile`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
2309
2311
"""
2310
2312
macro jit (args... )
2311
- default_options = merge (COMMON_COMPILE_OPTIONS , Dict {Symbol,Any} (:sync => false ))
2313
+ default_options = merge (get_common_compile_options () , Dict {Symbol,Any} (:sync => false ))
2312
2314
compile_expr, (; compiled, args) = compile_call_expr (
2313
2315
__module__, compile, default_options, args...
2314
2316
)
0 commit comments