@@ -393,23 +393,21 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
393393
394394 const DeviceMesh& producer_mesh = producer->getDeviceMesh ();
395395 const DeviceMesh& consumer_mesh = consumer->getDeviceMesh ();
396- const bool p_sharded = p_loop_did != nullptr && producer_mesh.size () > 1 ;
397- const bool c_sharded = c_loop_did != nullptr && consumer_mesh.size () > 1 ;
398396 const bool same_mesh = producer_mesh == consumer_mesh;
399397
400398 if (e->isA <LoadStoreOp>()) {
401- if (p_sharded && !c_sharded ) {
399+ if (p_loop_did && !c_loop_did ) {
402400 IterDomain* p_logical_id = getLogicalFromLoopId (producer, p_loop_did);
403401 CommunicationType type = same_mesh ? CommunicationType::Allgather
404402 : CommunicationType::Gather;
405403 fill_communication_info (type, p_logical_id, p2c_map.at (p_logical_id));
406404 }
407- if (!p_sharded && c_sharded ) {
405+ if (!p_loop_did && c_loop_did ) {
408406 IterDomain* c_logical_id = getLogicalFromLoopId (consumer, c_loop_did);
409407 fill_communication_info (
410408 CommunicationType::Scatter, c2p_map.at (c_logical_id), c_logical_id);
411409 }
412- if (p_sharded && c_sharded ) {
410+ if (p_loop_did && c_loop_did ) {
413411 IterDomain* p_logical_id = getLogicalFromLoopId (producer, p_loop_did);
414412 IterDomain* c_logical_id = getLogicalFromLoopId (consumer, c_loop_did);
415413 // TODO(#4604): This is problematic for 2D sharding.
@@ -424,12 +422,12 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
424422 }
425423 } else {
426424 NVF_ERROR (e->isA <ReductionOp>() || e->isA <SqueezeOp>());
427- if (!p_sharded ) {
425+ if (!p_loop_did ) {
428426 // Not a reduction based communication.
429427 continue ;
430428 }
431429
432- if (!c_sharded ) {
430+ if (!c_loop_did ) {
433431 IterDomain* p_logical_id = getLogicalFromLoopId (producer, p_loop_did);
434432 CommunicationType type = same_mesh ? CommunicationType::Allreduce
435433 : CommunicationType::Reduce;
0 commit comments