Skip to content

Commit 0a6adb1

Browse files
authored
Avoid mesh size checks (#5904)
Prefer simplicity over unnecessary performance optimizations. Single-GPU performance for a distributed program is largely irrelevant.
1 parent 23c4134 commit 0a6adb1

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

csrc/host_ir/lower_to_communication.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,23 +393,21 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
393393

394394
const DeviceMesh& producer_mesh = producer->getDeviceMesh();
395395
const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
396-
const bool p_sharded = p_loop_did != nullptr && producer_mesh.size() > 1;
397-
const bool c_sharded = c_loop_did != nullptr && consumer_mesh.size() > 1;
398396
const bool same_mesh = producer_mesh == consumer_mesh;
399397

400398
if (e->isA<LoadStoreOp>()) {
401-
if (p_sharded && !c_sharded) {
399+
if (p_loop_did && !c_loop_did) {
402400
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
403401
CommunicationType type = same_mesh ? CommunicationType::Allgather
404402
: CommunicationType::Gather;
405403
fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
406404
}
407-
if (!p_sharded && c_sharded) {
405+
if (!p_loop_did && c_loop_did) {
408406
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
409407
fill_communication_info(
410408
CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id);
411409
}
412-
if (p_sharded && c_sharded) {
410+
if (p_loop_did && c_loop_did) {
413411
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
414412
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
415413
// TODO(#4604): This is problematic for 2D sharding.
@@ -424,12 +422,12 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
424422
}
425423
} else {
426424
NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
427-
if (!p_sharded) {
425+
if (!p_loop_did) {
428426
// Not a reduction based communication.
429427
continue;
430428
}
431429

432-
if (!c_sharded) {
430+
if (!c_loop_did) {
433431
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
434432
CommunicationType type = same_mesh ? CommunicationType::Allreduce
435433
: CommunicationType::Reduce;

0 commit comments

Comments
 (0)