Skip to content

Commit 35e6e01

Browse files
fix: warn on using single device sharding (#1274)
* fix: warn on using single device sharding

* fix: warn if users are creating single device mesh

* Update src/Sharding.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8b0ba6c commit 35e6e01

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 8 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -832,8 +832,14 @@ GenerateCompileOptions(int64_t device_id, const int64_t *mesh_ids,
832832
options.executable_build_options.set_num_replicas(num_replicas);
833833
options.executable_build_options.set_num_partitions(num_partitions);
834834

835-
if (num_replicas > 1 || num_partitions > 1) {
836-
assert(device_id < 0);
835+
if (device_id < 0) {
836+
if (num_replicas == 1 && num_partitions == 1) {
837+
llvm::errs()
838+
<< "[libReactantExtra] num_replicas & num_partitions are both 1, but "
839+
"device_id is negative. This can happen if you are sharding with "
840+
"a single device.\n";
841+
}
842+
837843
assert(num_replicas * num_partitions == num_mesh_ids);
838844

839845
options.executable_build_options.set_use_spmd_partitioning(

src/Sharding.jl

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -54,6 +54,15 @@ struct Mesh{D,ID<:AbstractVector{Int}}
5454
axis_names::NTuple{D,Union{String,Symbol}},
5555
axis_sizes::Dims{D},
5656
) where {D}
57+
@assert length(logical_device_ids) ≥ 1
58+
if length(logical_device_ids) == 1
59+
@warn "Constructing a single device mesh is not well supported and is \
60+
equivalent to not specifying any sharding. If you want to mock \
61+
multi-device setup on a single cpu host, set the environment variable \
62+
XLA_FLAGS=\"--xla_force_host_platform_device_count=12\" before loading \
63+
Reactant.jl and force reactant to use `cpu` devices using \
64+
`Reactant.set_default_backend(\"cpu\")`." maxlog = 1
65+
end
5766
return new{D,typeof(logical_device_ids)}(
5867
sorted_device_ids, logical_device_ids, Symbol.(axis_names), axis_sizes
5968
)

0 commit comments

Comments (0)