@@ -137,35 +137,43 @@ void lowerToAllgather(
137137 backend));
138138}
139139
140- // Adds one or zero Broadcast communication to the vector 'comms'
140+ // Either of the following cases is happening:
141+ // 1. Same mesh: a broadcast-based allgather in a host for loop. `root` is the
142+ // for-loop index.
143+ // 2. Different meshes: we pick the first device in the sender mesh as root.
141144void lowerToBroadcast (
142145 TensorView* input_tv,
143146 TensorView* output_tv,
144147 const CommunicatorBackend backend,
148+ Val* root,
145149 std::vector<Expr*>& comms) {
146- // Either of the following two cases is happening.
147- // 1. `sender_mesh` contains only one device. In this case, we broadcast
148- // from that device.
149- // 2. `sender_mesh` contains multiple devices but the input is not sharded.
150- // In this case, we arbitrarily choose the first device of the sender mesh
151- // to be the root.
152150 const DeviceMesh& sender_mesh = input_tv->getDeviceMesh ();
153151 const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh ();
154152
155- NVF_ERROR_EQ (sender_mesh.rank (), 1 , " sender: " , input_tv);
156- NVF_ERROR_EQ (receiver_mesh.rank (), 1 , " receiver: " , output_tv);
157-
158- DeviceIdxType root = sender_mesh.at (0 );
159153 Team team = receiver_mesh.vector ();
160- if (!receiver_mesh.has (root)) {
161- team.push_back (root);
154+
155+ if (sender_mesh == receiver_mesh) {
156+ NVF_ERROR (
157+ root != nullptr ,
158+ " Root must be provided for broadcast-based allgather in a host for "
159+ " loop." );
160+ } else {
161+ NVF_ERROR_EQ (sender_mesh.rank (), 1 , " sender: " , input_tv);
162+ NVF_ERROR_EQ (receiver_mesh.rank (), 1 , " receiver: " , output_tv);
163+ DeviceIdxType root_device = sender_mesh.at (0 );
164+ if (!receiver_mesh.has (root_device)) {
165+ team.push_back (root_device);
166+ }
167+ root = IrBuilder::create<Val>(
168+ getRelativeIndex (team, root_device), DataType::Index);
162169 }
170+
163171 comms.push_back (IrBuilder::create<Communication>(
164172 CommunicationType::Broadcast,
165173 output_tv,
166174 input_tv,
167175 team,
168- getRelativeIndex (team, root) ,
176+ root,
169177 c10d::ReduceOp::RedOpType::UNUSED,
170178 backend));
171179}
@@ -356,11 +364,15 @@ std::optional<CommunicationInfo> getCommunicationInfoForParallelType(
356364 pairwise_map.mapConsumerToProducer ();
357365
358366 IterDomain* p_loop_id = getShardedIterDomain (producer, pt, DomainType::kLoop );
359- IterDomain* c_loop_id = getShardedIterDomain (consumer, pt, DomainType::kLoop );
360367 IterDomain* p_logical_id =
361368 p_loop_id ? getLogicalFromLoopId (producer, p_loop_id) : nullptr ;
369+ IterDomain* c_loop_id = getShardedIterDomain (consumer, pt, DomainType::kLoop );
362370 IterDomain* c_logical_id =
363371 c_loop_id ? getLogicalFromLoopId (consumer, c_loop_id) : nullptr ;
372+ IterDomain* c_stream_id =
373+ getShardedIterDomain (consumer, ParallelType::Stream, DomainType::kLoop );
374+ IterDomain* c_logical_stream_id =
375+ c_stream_id ? getLogicalFromLoopId (consumer, c_stream_id) : nullptr ;
364376
365377 const DeviceMesh& producer_mesh = producer->getDeviceMesh ();
366378 const DeviceMesh& consumer_mesh = consumer->getDeviceMesh ();
@@ -381,6 +393,19 @@ std::optional<CommunicationInfo> getCommunicationInfoForParallelType(
381393 }
382394
383395 if (p_loop_id && !c_loop_id) {
396+ // Check if we are going from DID -> Stream, which is a ring allgather.
397+ // This can be executed as a broadcast or send recvs, which is decided
398+ // by the presence of a swizzle in the stream id definition.
399+ if (c_logical_stream_id == p2c.at (p_logical_id)) {
400+ NVF_CHECK (
401+ same_mesh,
402+ " Broadcast based allgather in stream parallel requires same "
403+ " mesh." )
404+ return CommunicationInfo{
405+ .type = CommunicationType::Broadcast,
406+ .p_sharded_id = p_logical_id,
407+ .c_sharded_id = c_logical_stream_id};
408+ }
384409 CommunicationType type =
385410 same_mesh ? CommunicationType::Allgather : CommunicationType::Gather;
386411 return CommunicationInfo{
@@ -563,6 +588,7 @@ bool isCommunicationLayoutCompliant(Expr* e) {
563588std::vector<Expr*> convertSingleOpToCommunication (
564589 Expr* e,
565590 DeviceIdxType my_device_idx,
591+ Val* root,
566592 const CommunicatorBackend backend) {
567593 FusionGuard fg (e->fusion ());
568594
@@ -617,7 +643,7 @@ std::vector<Expr*> convertSingleOpToCommunication(
617643 lowerToAllgather (input_tv, output_tv, backend, comms, my_device_idx);
618644 break ;
619645 case CommunicationType::Broadcast:
620- lowerToBroadcast (input_tv, output_tv, backend, comms);
646+ lowerToBroadcast (input_tv, output_tv, backend, root, comms);
621647 break ;
622648 case CommunicationType::SendRecv:
623649 lowerToSendRecv (input_tv, output_tv, backend, comms);
0 commit comments