Skip to content

Commit 54d48ae

Browse files
authored
Broadcast-based allgather in host for-loop (#5925)
<img width="1680" height="250" alt="Screenshot 2026-02-09 at 1 24 11 PM" src="https://github.com/user-attachments/assets/f439517d-3533-4d05-b15f-6c02fea731bf" /> The broadcast version is very slow so I am not comparing timings until we integrate this with multicast
1 parent dd866a7 commit 54d48ae

File tree

9 files changed

+241
-55
lines changed

9 files changed

+241
-55
lines changed

csrc/host_ir/lower_to_communication.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -137,35 +137,43 @@ void lowerToAllgather(
137137
backend));
138138
}
139139

140-
// Adds one or zero Broadcast communication to the vector 'comms'
140+
// Either of the following cases is happening:
141+
// 1. Same mesh: a broadcast-based allgather in a host for loop. `root` is the
142+
// for-loop index.
143+
// 2. Different meshes: we pick the first device in the sender mesh as root.
141144
void lowerToBroadcast(
142145
TensorView* input_tv,
143146
TensorView* output_tv,
144147
const CommunicatorBackend backend,
148+
Val* root,
145149
std::vector<Expr*>& comms) {
146-
// Either of the following two cases is happening.
147-
// 1. `sender_mesh` contains only one device. In this case, we broadcast
148-
// from that device.
149-
// 2. `sender_mesh` contains multiple devices but the input is not sharded.
150-
// In this case, we arbitrarily choose the first device of the sender mesh
151-
// to be the root.
152150
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
153151
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
154152

155-
NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv);
156-
NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv);
157-
158-
DeviceIdxType root = sender_mesh.at(0);
159153
Team team = receiver_mesh.vector();
160-
if (!receiver_mesh.has(root)) {
161-
team.push_back(root);
154+
155+
if (sender_mesh == receiver_mesh) {
156+
NVF_ERROR(
157+
root != nullptr,
158+
"Root must be provided for broadcast-based allgather in a host for "
159+
"loop.");
160+
} else {
161+
NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv);
162+
NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv);
163+
DeviceIdxType root_device = sender_mesh.at(0);
164+
if (!receiver_mesh.has(root_device)) {
165+
team.push_back(root_device);
166+
}
167+
root = IrBuilder::create<Val>(
168+
getRelativeIndex(team, root_device), DataType::Index);
162169
}
170+
163171
comms.push_back(IrBuilder::create<Communication>(
164172
CommunicationType::Broadcast,
165173
output_tv,
166174
input_tv,
167175
team,
168-
getRelativeIndex(team, root),
176+
root,
169177
c10d::ReduceOp::RedOpType::UNUSED,
170178
backend));
171179
}
@@ -356,11 +364,15 @@ std::optional<CommunicationInfo> getCommunicationInfoForParallelType(
356364
pairwise_map.mapConsumerToProducer();
357365

358366
IterDomain* p_loop_id = getShardedIterDomain(producer, pt, DomainType::kLoop);
359-
IterDomain* c_loop_id = getShardedIterDomain(consumer, pt, DomainType::kLoop);
360367
IterDomain* p_logical_id =
361368
p_loop_id ? getLogicalFromLoopId(producer, p_loop_id) : nullptr;
369+
IterDomain* c_loop_id = getShardedIterDomain(consumer, pt, DomainType::kLoop);
362370
IterDomain* c_logical_id =
363371
c_loop_id ? getLogicalFromLoopId(consumer, c_loop_id) : nullptr;
372+
IterDomain* c_stream_id =
373+
getShardedIterDomain(consumer, ParallelType::Stream, DomainType::kLoop);
374+
IterDomain* c_logical_stream_id =
375+
c_stream_id ? getLogicalFromLoopId(consumer, c_stream_id) : nullptr;
364376

365377
const DeviceMesh& producer_mesh = producer->getDeviceMesh();
366378
const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
@@ -381,6 +393,19 @@ std::optional<CommunicationInfo> getCommunicationInfoForParallelType(
381393
}
382394

383395
if (p_loop_id && !c_loop_id) {
396+
// Check if we are going from DID -> Stream, which is a ring allgather.
397+
// This can be executed as a broadcast or send recvs, which is decided
398+
// by the presence of a swizzle in the stream id definition.
399+
if (c_logical_stream_id == p2c.at(p_logical_id)) {
400+
NVF_CHECK(
401+
same_mesh,
402+
"Broadcast based allgather in stream parallel requires same "
403+
"mesh.")
404+
return CommunicationInfo{
405+
.type = CommunicationType::Broadcast,
406+
.p_sharded_id = p_logical_id,
407+
.c_sharded_id = c_logical_stream_id};
408+
}
384409
CommunicationType type =
385410
same_mesh ? CommunicationType::Allgather : CommunicationType::Gather;
386411
return CommunicationInfo{
@@ -563,6 +588,7 @@ bool isCommunicationLayoutCompliant(Expr* e) {
563588
std::vector<Expr*> convertSingleOpToCommunication(
564589
Expr* e,
565590
DeviceIdxType my_device_idx,
591+
Val* root,
566592
const CommunicatorBackend backend) {
567593
FusionGuard fg(e->fusion());
568594

@@ -617,7 +643,7 @@ std::vector<Expr*> convertSingleOpToCommunication(
617643
lowerToAllgather(input_tv, output_tv, backend, comms, my_device_idx);
618644
break;
619645
case CommunicationType::Broadcast:
620-
lowerToBroadcast(input_tv, output_tv, backend, comms);
646+
lowerToBroadcast(input_tv, output_tv, backend, root, comms);
621647
break;
622648
case CommunicationType::SendRecv:
623649
lowerToSendRecv(input_tv, output_tv, backend, comms);

csrc/host_ir/lower_to_communication.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,15 @@ Layout getCommunicationLayout(
5555
const CommunicationType type,
5656
IterDomain* sharded_id);
5757

58+
// Creates a communication expr corresponding to the given
59+
// resharding expr. In most cases, `root` is inferred based
60+
// on communication type. However, in some cases, for e.g.
61+
// decomposing allgather as broadcast in a host for-loop, `root`
62+
// may be passed in through lowering.
5863
std::vector<Expr*> convertSingleOpToCommunication(
5964
Expr* c,
6065
DeviceIdxType my_device_idx,
66+
Val* root = nullptr,
6167
const CommunicatorBackend backend = CommunicatorBackend::kNccl);
6268

6369
} // namespace nvfuser

csrc/host_ir/lowering.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,16 @@ void lowerSegment(
192192
out,
193193
DomainType::kLoop,
194194
{ParallelType::Stream})) {
195-
Val*& sharded_in = replacement_map[in];
196-
if (sharded_in == nullptr) {
197-
sharded_in = hir::shardByStream(in, innermost.loop->index(), e);
198-
innermost_scope.pushBack(sharded_in->definition());
195+
if (!replacement_map.contains(in)) {
196+
TensorView* sharded_in =
197+
hir::shardByStream(in, innermost.loop->index(), e);
198+
if (sharded_in != nullptr) {
199+
// `sharded_in` is nullptr if the input cannot be sharded by
200+
// stream such as in broadcast or collective-permute based
201+
// decomposition of allgather.
202+
replacement_map[in] = sharded_in;
203+
innermost_scope.pushBack(sharded_in->definition());
204+
}
199205
}
200206
}
201207

@@ -207,15 +213,23 @@ void lowerSegment(
207213
out, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
208214
innermost.parent_scope->insert(
209215
innermost.parent_insertion_point, allocate);
210-
auto [i, inserted] = replacement_map.emplace(
211-
out, hir::shardByStream(out, innermost.loop->index(), e));
212-
NVF_ERROR(inserted, "The input segmented fusion should be SSA.");
213-
innermost_scope.pushBack(i->second->definition());
216+
NVF_ERROR(
217+
!replacement_map.contains(out),
218+
"The input segmented fusion should be SSA.");
219+
TensorView* sharded_out =
220+
hir::shardByStream(out, innermost.loop->index(), e);
221+
NVF_ERROR(
222+
sharded_out != nullptr,
223+
"Output could not be sharded by stream: ",
224+
out);
225+
replacement_map[out] = sharded_out;
226+
innermost_scope.pushBack(sharded_out->definition());
214227
} else {
215228
innermost_scope.pushBack(allocate);
216229
}
217230

218-
for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
231+
Val* root = loop_nest.empty() ? nullptr : innermost.loop->index();
232+
for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) {
219233
NVF_ERROR(
220234
c->isA<Communication>(),
221235
"Exprs in a Communication group should be Communication: ",
@@ -298,6 +312,10 @@ void lowerSegment(
298312
{ParallelType::Stream})) {
299313
TensorView* sharded_in =
300314
hir::shardByStream(in, innermost.loop->index(), e);
315+
NVF_ERROR(
316+
sharded_in != nullptr,
317+
"Input could not be sharded by stream: ",
318+
in);
301319
replacement_map[in] = sharded_in;
302320
innermost_scope.pushBack(sharded_in->definition());
303321
}
@@ -318,6 +336,10 @@ void lowerSegment(
318336
// `out` should be allocated outside the loop.
319337
TensorView* sharded_out =
320338
hir::shardByStream(out, innermost.loop->index(), e);
339+
NVF_ERROR(
340+
sharded_out != nullptr,
341+
"Output could not be sharded by stream: ",
342+
out);
321343
replacement_map[out] = sharded_out;
322344
innermost_scope.pushBack(sharded_out->definition());
323345
}

csrc/host_ir/ops.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
3535
ops::newValLike(source, source->getDataType())->as<TensorView>();
3636

3737
if (std::ranges::find(e->inputs(), source) != e->inputs().end()) {
38-
// Propagate the allocation domain from `source` to `destination`.
39-
// Consider adding a config to TransformReplay::selfReplay to control what
40-
// to propagate, so we don't have to reset the loop domain.
38+
// Propagate the domain from `source` to `destination`.
39+
// Unparallelize the destination on `ParallelType::Stream` which
40+
// will be inferred based on the output of the expression.
4141
TransformReplay::selfReplay(source->domain(), destination->domain());
42-
destination->setLoopDomain(destination->getLogicalDomain());
42+
unparallelize(destination, {ParallelType::Stream});
4343

44-
// Propagate the loop domain from `e` to `destination`. There are two
44+
// Propagate ParallelType::Stream from `e` to `destination`. There are two
4545
// technical challenges:
4646
// 1. Loop domains are associated with TensorViews, not Exprs. So we
4747
// find e's reference output, `ref_out`, and propagate its loop domain.
@@ -58,7 +58,7 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
5858
shardLoopLike(
5959
ref_out,
6060
destination,
61-
deviceAndStreamParallelTypes(),
61+
{ParallelType::Stream},
6262
PropagateDirection::kBackward);
6363
temp_e->fusion()->removeExpr(temp_e);
6464
// Fusion::removeExpr sets all outputs' definitions to nullptr, so we need
@@ -68,6 +68,14 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
6868
for (auto* out : e->outputs()) {
6969
out->setDefinition(e);
7070
}
71+
72+
// Destination's loop domain may not be stream-parallelized if the
73+
// corresponding id is already sharded such as in
74+
// broadcast/collective-permute based decomposition of allgather.
75+
if (getShardedIterDomain(
76+
destination, ParallelType::Stream, DomainType::kLoop) == nullptr) {
77+
return nullptr;
78+
}
7179
} else {
7280
NVF_ERROR(
7381
std::ranges::find(e->outputs(), source) != e->outputs().end(),
@@ -89,8 +97,10 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
8997
destination, ParallelType::Stream, DomainType::kAllocation) !=
9098
nullptr,
9199
"Destination allocation should be sharded on stream after "
92-
"shardAllocationAsLoop: ",
93-
destination);
100+
"shardAllocationAsLoop. ",
101+
destination->name(),
102+
":",
103+
destination->domain()->toString(0, /*loop_only=*/false));
94104

95105
// Refine the contiguity flags so `out` aliases `in`. This is done similar
96106
// to AliasFinder::handle(const SliceOp*). We scan through the allocation

csrc/host_ir/ops.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,20 @@
2121
namespace nvfuser::hir {
2222

2323
// Creates a ShardByStream without needing the destination TensorView. Returns
24-
// the destination TensorView. `e` is the Expr from which we propagate the loop
25-
// domain from. `source` must be either an input or an output of `e`. The
26-
// destination TensorView will have a loop domain that's consistent with `e` and
27-
// an allocation domain that's a shard of `source`.
24+
// the destination TensorView. `e` is the Expr from which we propagate
25+
// `ParallelType::Stream` domain from. `source` must be either an input or an
26+
// output of `e`. The destination TensorView will have a `ParallelType::Stream`
27+
// domain that's consistent with `e` and an allocation domain that's a shard of
28+
// `source`.
2829
//
2930
// Why is `e` unnecessary? I made a mistake previously to propagate `source`'s
3031
// loop domain to `destination`. This broke
3132
// test_stream.py::test_two_matmuls_not_inlinable because, when `source` is an
3233
// input of `e`, `source`'s loop domain reflects its producing Expr rather than
3334
// `e`.
35+
// If `destination` cannot be sharded by `ParallelType::Stream`, returns
36+
// nullptr. For e.g.: in decomposed allgather, we go from DIDx -> Stream.
37+
// `destination` is already sharded on `DIDx`
3438
TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e);
3539

3640
} // namespace nvfuser::hir

csrc/host_ir/pass/convert_op_to_communication.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) {
3535
return new_top_level_exprs.push_back(top_level_expr);
3636
}
3737
for (auto* expr : nvfuser::convertSingleOpToCommunication(
38-
top_level_expr, my_device_index, params_.communicator_backend)) {
38+
top_level_expr,
39+
my_device_index,
40+
/*root=*/nullptr,
41+
params_.communicator_backend)) {
3942
// Allocate the recv buffers of communications
4043
if (expr->isA<Communication>()) {
4144
auto* communication = expr->as<Communication>();

csrc/multidevice/propagation.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -350,26 +350,28 @@ void canonicalizeLoopDomain(TensorView* tv) {
350350
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
351351
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()}) |
352352
std::views::reverse) {
353-
auto* split = dynamic_cast<Split*>(transform);
354-
NVF_ERROR(
355-
split != nullptr,
356-
"Only splits are expected so far, but found: ",
357-
transform);
358-
359-
if (split->outer()->isParallelized() || split->inner()->isParallelized()) {
353+
if (std::ranges::any_of(
354+
ir_utils::filterByType<IterDomain>(transform->outputs()),
355+
[&loop](IterDomain* id) {
356+
return id->isParallelized() || !loop.contains(id);
357+
})) {
360358
continue;
361359
}
362-
363-
if (!loop.contains(split->outer()) || !loop.contains(split->inner())) {
360+
if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
361+
auto it = loop.erase(swizzle1d->out()).second;
362+
loop.insert(it, swizzle1d->in(), std::monostate());
364363
continue;
365364
}
366-
367-
loop.erase(split->outer());
368-
const auto inner_i = loop.erase(split->inner()).second;
369-
// `inner_i` is picked arbitrarily as the insertion point. Given `in`,
370-
// `outer` and `inner` are all serial, `in`'s position in the loop domain
371-
// doesn't matter.
372-
loop.insert(inner_i, split->in(), std::monostate());
365+
if (auto* split = dynamic_cast<Split*>(transform)) {
366+
loop.erase(split->outer());
367+
const auto inner_i = loop.erase(split->inner()).second;
368+
// `inner_i` is picked arbitrarily as the insertion point. Given `in`,
369+
// `outer` and `inner` are all serial, `in`'s position in the loop domain
370+
// doesn't matter.
371+
loop.insert(inner_i, split->in(), std::monostate());
372+
continue;
373+
}
374+
NVF_THROW("Expected a swizzle1d or split transform. Got: ", transform);
373375
}
374376

375377
auto new_loop = std::views::keys(loop);

0 commit comments

Comments
 (0)