Commit 96b746a

lower_to_communication: includes, error macros, and validation (#5973)
1 parent 64e66ae commit 96b746a

File tree

1 file changed (43 additions, 35 deletions):

csrc/host_ir/lower_to_communication.cpp
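The bulk of this change replaces open-coded comparisons such as NVF_ERROR(receiver_mesh.rank() == 1, ...) with the dedicated comparison macros NVF_ERROR_EQ and NVF_ERROR_GE, and passes IR nodes (output_tv, e, p_logical_id) to the macros directly instead of calling ->toString() first, which suggests the macros stream their arguments. A minimal sketch of why comparison macros give better failure messages follows; MY_ERROR and MY_ERROR_EQ are hypothetical stand-ins for illustration, not nvFuser's actual macros:

#include <cstdlib>
#include <iostream>

// Generic check: on failure we only learn that the condition was false.
#define MY_ERROR(cond, msg)                         \
  do {                                              \
    if (!(cond)) {                                  \
      std::cerr << "Error: " << (msg) << std::endl; \
      std::abort();                                 \
    }                                               \
  } while (0)

// Comparison check: on failure both operands are printed, so the message
// shows the mismatching values without repeating the expressions.
#define MY_ERROR_EQ(lhs, rhs, msg)                        \
  do {                                                    \
    if ((lhs) != (rhs)) {                                 \
      std::cerr << "Error: " << (msg) << " Expected "     \
                << (lhs) << " == " << (rhs) << std::endl; \
      std::abort();                                       \
    }                                                     \
  } while (0)

int main() {
  int mesh_rank = 2;
  // Reports the failing values ("Expected 2 == 1") before aborting.
  MY_ERROR_EQ(mesh_rank, 1, "Scatter only supported on a 1D mesh.");
  return 0;
}

With MY_ERROR the same check would only report the message; the EQ variant also shows both sides of the comparison, which is what the diff buys for the mesh-rank and domain-size checks below.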
@@ -8,18 +8,19 @@
 
 #include "host_ir/lower_to_communication.h"
 
-#include "host_ir/container.h"
-#include "ir/all_nodes.h"
-#include "ir/allocation_utils.h"
+#include <algorithm>
+#include <iterator>
+#include <optional>
+#include <vector>
+
 #include "ir/builder.h"
 #include "ir/internal_base_nodes.h"
 #include "ir/iostream.h"
 #include "ir/utils.h"
-#include "kernel_ir.h"
+#include "logical_domain_map.h"
 #include "multidevice/post_communication.h"
 #include "multidevice/resharding.h"
 #include "multidevice/utils.h"
-#include "ops/all_ops.h"
 
 namespace nvfuser {
 
@@ -57,10 +58,11 @@ void lowerToScatter(
     const CommunicatorBackend backend,
     std::vector<Expr*>& comms) {
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
-  NVF_ERROR(
-      receiver_mesh.rank() == 1,
-      "Gather only supported on a 1D mesh. Given ",
-      receiver_mesh);
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
+      "Scatter only supported on a 1D mesh. Given ",
+      output_tv);
 
   // Find a common device between input and receiver meshes to be the root
   std::vector<DeviceIdxType> input_devices = input_tv->getDeviceMesh().vector();
@@ -149,8 +151,8 @@ void lowerToBroadcast(
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
 
-  NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv->toString());
-  NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv->toString());
+  NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv);
+  NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv);
 
   DeviceIdxType root = sender_mesh.at(0);
   Team team = receiver_mesh.vector();
@@ -214,12 +216,14 @@ void lowerToReduce(
     std::vector<Expr*>& comms) {
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
-  NVF_ERROR(
-      sender_mesh.rank() == 1,
+  NVF_ERROR_EQ(
+      sender_mesh.rank(),
+      1,
       "Reduce only supported a 1D mesh. Given ",
       sender_mesh);
-  NVF_ERROR(
-      receiver_mesh.rank() == 1,
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
       "Reduce only supported a 1D mesh. Given ",
       receiver_mesh);
   const auto reduce_op_type = getC10dReduceOpType(op_type);
@@ -323,8 +327,9 @@ void lowerToAllToAll(
 IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
   std::vector<IterDomain*> logical_ids =
       ir_utils::getReachableIds(tv->getLogicalDomain(), {loop_id});
-  NVF_ERROR(
-      logical_ids.size() == 1,
+  NVF_ERROR_EQ(
+      logical_ids.size(),
+      1,
       "Expected exactly one logical ID producing the device dimension ",
       loop_id);
   return logical_ids.front();
@@ -356,9 +361,11 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       "getCommunicationInfo should only be called when `e` is known to be a "
       "communication. Given: ",
       e);
-
+  NVF_ERROR_EQ(e->inputs().size(), 1, "Expected 1 input, but got ", e);
   auto* producer = e->inputs().at(0)->as<TensorView>();
+  NVF_ERROR_EQ(e->outputs().size(), 1, "Expected 1 output, but got ", e);
   auto* consumer = e->outputs().at(0)->as<TensorView>();
+
   std::optional<CommunicationInfo> communication_info = std::nullopt;
 
   // Fill `communication_info` instead of returning the result, so we can catch
@@ -369,13 +376,15 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     NVF_ERROR(
         !communication_info.has_value(),
         "Expected at most one sharding change: ",
-        e->toString());
+        e);
     communication_info = CommunicationInfo{type, p_sharded_id, c_sharded_id};
   };
 
-  const auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer);
-  const auto p2c_map = pairwise_map.mapProducerToConsumer();
-  const auto c2p_map = pairwise_map.mapConsumerToProducer();
+  const PairwiseLogicalDomainMap pairwise_map(producer, consumer);
+  const std::unordered_map<IterDomain*, IterDomain*> p2c =
+      pairwise_map.mapProducerToConsumer();
+  const std::unordered_map<IterDomain*, IterDomain*> c2p =
+      pairwise_map.mapConsumerToProducer();
 
   // This ignores device dimensions on reduction axis.
   auto producer_pt_to_did =
@@ -401,19 +410,19 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
     CommunicationType type = same_mesh ? CommunicationType::Allgather
                                        : CommunicationType::Gather;
-    fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
+    fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
   }
   if (!p_loop_did && c_loop_did) {
     IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
     fill_communication_info(
-        CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id);
+        CommunicationType::Scatter, c2p.at(c_logical_id), c_logical_id);
   }
   if (p_loop_did && c_loop_did) {
     IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
     IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
     // TODO(#4604): This is problematic for 2D sharding.
 
-    if (c_logical_id == p2c_map.at(p_logical_id)) {
+    if (c_logical_id == p2c.at(p_logical_id)) {
       fill_communication_info(
           CommunicationType::SendRecv, p_logical_id, c_logical_id);
     } else {
@@ -432,27 +441,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
      IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
      CommunicationType type = same_mesh ? CommunicationType::Allreduce
                                         : CommunicationType::Reduce;
-     fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
+     fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
      continue;
    }
 
    // Check if the p_logical_ids is reduced in the output.
    IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
    IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
 
-   auto c_it = p2c_map.find(p_logical_id);
+   auto c_it = p2c.find(p_logical_id);
    NVF_ERROR(
-       c_it != p2c_map.end(),
+       c_it != p2c.end(),
        "Cannot find the mapped consumer logical ID for the producer logical "
        "ID ",
-       p_logical_id->toString());
+       p_logical_id);
    if (!c_it->second->isReduction()) {
      continue;
    }
    fill_communication_info(
-       CommunicationType::ReduceScatter,
-       c2p_map.at(c_logical_id),
-       c_logical_id);
+       CommunicationType::ReduceScatter, c2p.at(c_logical_id), c_logical_id);
  }
 }
 
@@ -492,8 +499,9 @@ Layout getCommunicationLayout(
 
   const int64_t sharded_id_pos =
       posInDomain(layout.allocation_domain(), sharded_id);
-  NVF_ERROR(
-      sharded_id_pos >= 0,
+  NVF_ERROR_GE(
+      sharded_id_pos,
+      0,
       "Sharded ID (",
       sharded_id,
       ") not found in the allocation domain of the tensor view: ",
@@ -566,7 +574,7 @@ std::vector<Expr*> convertSingleOpToCommunication(
   NVF_ERROR(
       isCommunicationLayoutCompliant(e),
       "Resharding on an inner axis is not lowerable ",
-      e->toString());
+      e);
 
   CommunicationInfo communication_info = getCommunicationInfo(e);
 
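A note on the validation added to getCommunicationInfo: the two new NVF_ERROR_EQ calls promote an implicit assumption into a checked precondition, namely that a lowerable communication reads exactly one TensorView and writes exactly one, so arity is verified before indexing inputs().at(0) and outputs().at(0). A standalone sketch of the pattern, with Expr and checkUnaryCommunication as illustrative stand-ins rather than nvFuser's real types:

#include <cassert>
#include <string>
#include <vector>

// Illustrative stand-in for an IR expression with arbitrary arity.
struct Expr {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Mirrors the shape of the new checks: validate arity up front so a
// malformed expression fails with a clear message rather than an
// out-of-range access when the single input/output is dereferenced.
void checkUnaryCommunication(const Expr& e) {
  assert(e.inputs.size() == 1 && "Expected 1 input");
  assert(e.outputs.size() == 1 && "Expected 1 output");
}

int main() {
  Expr e{{"producer_tv"}, {"consumer_tv"}};
  checkUnaryCommunication(e);
  return 0;
}

Failing fast here turns a malformed expression into a readable error instead of a misleading failure deeper in lowering.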