Skip to content

Commit 48ec28b

Browse files
authored
fix: default to shardy again (#1512)
* fix: default to shardy again * chore: run fmt * fix: reshard to collectives
1 parent 1e68728 commit 48ec28b

File tree

5 files changed

+18
-20
lines changed

5 files changed

+18
-20
lines changed

src/CompileOptions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Fine-grained control over the compilation options for the Reactant compiler.
117117
118118
## Sharding Options
119119
120-
- `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
120+
- `shardy_passes`: Defaults to `:post_sdy_propagation`. Other options are:
121121
- `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
122122
- `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
123123
handled by XLA.

src/Compiler.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,7 @@ function __get_compile_options_and_kwargs(;
12891289
raise_first::Bool=false,
12901290
legalize_chlo_to_stablehlo::Bool=false,
12911291
cudnn_hlo_optimize::Bool=false,
1292-
shardy_passes::Union{Symbol,ShardyPropagationOptions}=:to_mhlo_shardings,
1292+
shardy_passes::Union{Symbol,ShardyPropagationOptions}=:post_sdy_propagation,
12931293
optimize_then_pad::Bool=true,
12941294
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
12951295
assert_nonallocating::Bool=false,
@@ -2132,6 +2132,7 @@ function compile_mlir!(
21322132
get_optimize_comms_passes(
21332133
compile_options.optimize_communications
21342134
)...,
2135+
"func.func(sdy-reshard-to-collectives)",
21352136
],
21362137
",",
21372138
),
@@ -2165,6 +2166,7 @@ function compile_mlir!(
21652166
get_optimize_comms_passes(
21662167
compile_options.optimize_communications
21672168
)...,
2169+
"func.func(sdy-reshard-to-collectives)",
21682170
"xla-sdy-stablehlo-export-pipeline",
21692171
],
21702172
",",
@@ -2319,7 +2321,7 @@ function get_common_compile_options()
23192321
:client => nothing,
23202322
:raise => false,
23212323
:raise_first => false,
2322-
:shardy_passes => :(:to_mhlo_shardings),
2324+
:shardy_passes => :(:post_sdy_propagation),
23232325
:assert_nonallocating => false,
23242326
:donated_args => :(:auto),
23252327
:transpose_propagate => :(:up),
@@ -2362,7 +2364,10 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
23622364
"""
23632365
macro code_hlo(args...)
23642366
compile_expr, (; compiled) = compile_call_expr(
2365-
__module__, compile_mlir, get_common_compile_options(), args...
2367+
__module__,
2368+
compile_mlir,
2369+
merge(get_common_compile_options(), Dict{Symbol,Any}(:shardy_passes => :(:none))),
2370+
args...,
23662371
)
23672372
#! format: off
23682373
return esc(
@@ -2391,7 +2396,9 @@ macro code_mhlo(args...)
23912396
compile_mlir,
23922397
merge(
23932398
get_common_compile_options(),
2394-
Dict{Symbol,Any}(:legalize_stablehlo_to_mhlo => true),
2399+
Dict{Symbol,Any}(
2400+
:legalize_stablehlo_to_mhlo => true, :shardy_passes => :(:to_mhlo_shardings)
2401+
),
23952402
),
23962403
args...,
23972404
)

src/ConcreteRArray.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,8 @@ function write_to_host_buffer!(data::Array, X::ConcretePJRTArray{T,N}) where {T,
118118
completed = Set{eltype(X.sharding.device_to_array_slices)}()
119119
for idx in 1:length(X.data)
120120
slice = X.sharding.device_to_array_slices[idx]
121-
if slice ∉ completed
122-
push!(completed, slice)
123-
else
124-
continue
125-
end
121+
slice ∈ completed && continue
122+
push!(completed, slice)
126123
data_slice = data[slice...]
127124
XLA.to_host(X.data[idx], data_slice, Reactant.Sharding.NoSharding())
128125
data[slice...] .= data_slice

src/xla/IFRT/Array.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,7 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
354354
Reactant.Compiler.run_pass_pipeline!(
355355
mod,
356356
join(
357-
[
358-
"sdy-propagation-pipeline",
359-
"sdy-close-shardings",
360-
"xla-sdy-stablehlo-export-pipeline",
361-
"canonicalize",
362-
"cse",
363-
],
357+
["sdy-propagation-pipeline", "sdy-close-shardings", "canonicalize", "cse"],
364358
",",
365359
),
366360
)
@@ -375,7 +369,7 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
375369
num_partitions=length(mesh.device_ids),
376370
num_outputs=1, # unused
377371
num_parameters=1, # unused
378-
use_shardy_partitioner=false, # unused
372+
use_shardy_partitioner=true, # unused
379373
)
380374

381375
only(XLA.execute(exec, (array.buffer,), (UInt8(0),), Val(1)))

test/sharding.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ end
209209
@test contains(repr(hlo), "sharding_constraint")
210210
hlo = @code_hlo shardy_passes = :to_mhlo_shardings fn_with_constraint(x_ra)
211211
@test !contains(repr(hlo), "sharding_constraint")
212-
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 6
212+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5
213213

214214
z = Reactant.to_rarray(x; sharding=constraint)
215215
res = @jit fn_with_constraint(x_ra)
@@ -234,7 +234,7 @@ end
234234
x_ra_no_sharding
235235
)
236236
@test !contains(repr(hlo), "sharding_constraint")
237-
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 6
237+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5
238238

239239
res = @jit fn_with_constraint(x_ra_no_sharding)
240240
@test x .+ x ≈ Array(res)

0 commit comments

Comments
 (0)