@@ -8,18 +8,19 @@
 
 #include "host_ir/lower_to_communication.h"
 
-#include "host_ir/container.h"
-#include "ir/all_nodes.h"
-#include "ir/allocation_utils.h"
+#include <algorithm>
+#include <iterator>
+#include <optional>
+#include <vector>
+
 #include "ir/builder.h"
 #include "ir/internal_base_nodes.h"
 #include "ir/iostream.h"
 #include "ir/utils.h"
-#include "kernel_ir.h"
+#include "logical_domain_map.h"
 #include "multidevice/post_communication.h"
 #include "multidevice/resharding.h"
 #include "multidevice/utils.h"
-#include "ops/all_ops.h"
 
 namespace nvfuser {
 
@@ -57,10 +58,11 @@ void lowerToScatter(
     const CommunicatorBackend backend,
     std::vector<Expr*>& comms) {
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
-  NVF_ERROR(
-      receiver_mesh.rank() == 1,
-      "Gather only supported on a 1D mesh. Given ",
-      receiver_mesh);
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
+      "Scatter only supported on a 1D mesh. Given ",
+      output_tv);
 
   // Find a common device between input and receiver meshes to be the root
   std::vector<DeviceIdxType> input_devices = input_tv->getDeviceMesh().vector();
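The hunk above replaces a hand-rolled `rank() == 1` check with `NVF_ERROR_EQ`, which takes the two operands separately and can therefore report both observed values when the assertion fails. As a rough, hypothetical sketch (not nvfuser's actual macro; `MY_ERROR_EQ` and `concatToString` are invented names for illustration), an equality-asserting macro of this general shape might look like:

```cpp
// Hypothetical sketch of an equality-asserting macro; this is NOT the real
// NVF_ERROR_EQ implementation, only an illustration of the general shape.
#include <sstream>
#include <stdexcept>
#include <string>

template <typename... Args>
std::string concatToString(const Args&... args) {
  std::ostringstream oss;
  (oss << ... << args);  // fold every argument into one message string
  return oss.str();
}

// `##__VA_ARGS__` is a widely supported compiler extension that drops the
// trailing comma when no extra message arguments are given.
#define MY_ERROR_EQ(lhs, rhs, ...)                                      \
  do {                                                                  \
    const auto& lhs_val = (lhs);                                        \
    const auto& rhs_val = (rhs);                                        \
    if (!(lhs_val == rhs_val)) {                                        \
      throw std::runtime_error(concatToString(                          \
          "Expected " #lhs " == " #rhs " but got ",                     \
          lhs_val, " vs ", rhs_val, ". ", ##__VA_ARGS__));              \
    }                                                                   \
  } while (0)
```

With that shape, the updated call sites read `NVF_ERROR_EQ(receiver_mesh.rank(), 1, ...)` and a failure prints the actual rank rather than just a failed boolean expression.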
@@ -149,8 +151,8 @@ void lowerToBroadcast(
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
 
-  NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv->toString());
-  NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv->toString());
+  NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv);
+  NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv);
 
   DeviceIdxType root = sender_mesh.at(0);
   Team team = receiver_mesh.vector();
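This hunk, like several below, also drops explicit `->toString()` calls and passes the `TensorView*` straight to the error macro. That relies on the macro streaming its arguments and on a suitable `operator<<` overload being visible, presumably via the `ir/iostream.h` include above. A minimal, self-contained sketch of the pattern, using an invented `FakeTensorView` stand-in rather than the real class:

```cpp
#include <iostream>
#include <ostream>
#include <string>

// Stand-in type; invented for illustration only.
struct FakeTensorView {
  std::string name;
  std::string toString() const { return "TV(" + name + ")"; }
};

// With these overloads, a streaming error helper can accept the object (or a
// pointer to it) directly instead of requiring an explicit toString() call.
std::ostream& operator<<(std::ostream& os, const FakeTensorView& tv) {
  return os << tv.toString();
}

std::ostream& operator<<(std::ostream& os, const FakeTensorView* tv) {
  return tv ? (os << *tv) : (os << "nullptr");
}

int main() {
  FakeTensorView tv{"input"};
  std::cout << "sender: " << &tv << '\n';  // prints "sender: TV(input)"
}
```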
@@ -214,12 +216,14 @@ void lowerToReduce(
     std::vector<Expr*>& comms) {
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
-  NVF_ERROR(
-      sender_mesh.rank() == 1,
+  NVF_ERROR_EQ(
+      sender_mesh.rank(),
+      1,
       "Reduce only supported a 1D mesh. Given ",
       sender_mesh);
-  NVF_ERROR(
-      receiver_mesh.rank() == 1,
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
       "Reduce only supported a 1D mesh. Given ",
       receiver_mesh);
   const auto reduce_op_type = getC10dReduceOpType(op_type);
@@ -323,8 +327,9 @@ void lowerToAllToAll(
 IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
   std::vector<IterDomain*> logical_ids =
       ir_utils::getReachableIds(tv->getLogicalDomain(), {loop_id});
-  NVF_ERROR(
-      logical_ids.size() == 1,
+  NVF_ERROR_EQ(
+      logical_ids.size(),
+      1,
       "Expected exactly one logical ID producing the device dimension ",
       loop_id);
   return logical_ids.front();
@@ -356,9 +361,11 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       "getCommunicationInfo should only be called when `e` is known to be a "
       "communication. Given: ",
       e);
-
+  NVF_ERROR_EQ(e->inputs().size(), 1, "Expected 1 input, but got ", e);
   auto* producer = e->inputs().at(0)->as<TensorView>();
+  NVF_ERROR_EQ(e->outputs().size(), 1, "Expected 1 output, but got ", e);
   auto* consumer = e->outputs().at(0)->as<TensorView>();
+
   std::optional<CommunicationInfo> communication_info = std::nullopt;
 
   // Fill `communication_info` instead of returning the result, so we can catch
@@ -369,13 +376,15 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     NVF_ERROR(
         !communication_info.has_value(),
         "Expected at most one sharding change: ",
-        e->toString());
+        e);
     communication_info = CommunicationInfo{type, p_sharded_id, c_sharded_id};
   };
 
-  const auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer);
-  const auto p2c_map = pairwise_map.mapProducerToConsumer();
-  const auto c2p_map = pairwise_map.mapConsumerToProducer();
+  const PairwiseLogicalDomainMap pairwise_map(producer, consumer);
+  const std::unordered_map<IterDomain*, IterDomain*> p2c =
+      pairwise_map.mapProducerToConsumer();
+  const std::unordered_map<IterDomain*, IterDomain*> c2p =
+      pairwise_map.mapConsumerToProducer();
 
   // This ignores device dimensions on reduction axis.
   auto producer_pt_to_did =
@@ -401,19 +410,19 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
       CommunicationType type = same_mesh ? CommunicationType::Allgather
                                          : CommunicationType::Gather;
-      fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
+      fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
     }
     if (!p_loop_did && c_loop_did) {
       IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
       fill_communication_info(
-          CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id);
+          CommunicationType::Scatter, c2p.at(c_logical_id), c_logical_id);
     }
     if (p_loop_did && c_loop_did) {
       IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
       IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
       // TODO(#4604): This is problematic for 2D sharding.
 
-      if (c_logical_id == p2c_map.at(p_logical_id)) {
+      if (c_logical_id == p2c.at(p_logical_id)) {
         fill_communication_info(
             CommunicationType::SendRecv, p_logical_id, c_logical_id);
       } else {
@@ -432,27 +441,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
       CommunicationType type = same_mesh ? CommunicationType::Allreduce
                                          : CommunicationType::Reduce;
-      fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
+      fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
       continue;
     }
 
     // Check if the p_logical_ids is reduced in the output.
     IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
     IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
 
-    auto c_it = p2c_map.find(p_logical_id);
+    auto c_it = p2c.find(p_logical_id);
     NVF_ERROR(
-        c_it != p2c_map.end(),
+        c_it != p2c.end(),
         "Cannot find the mapped consumer logical ID for the producer logical "
         "ID ",
-        p_logical_id->toString());
+        p_logical_id);
     if (!c_it->second->isReduction()) {
       continue;
     }
     fill_communication_info(
-        CommunicationType::ReduceScatter,
-        c2p_map.at(c_logical_id),
-        c_logical_id);
+        CommunicationType::ReduceScatter, c2p.at(c_logical_id), c_logical_id);
     }
   }
 
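Note that the reduce-scatter path above keeps the `find()` plus `NVF_ERROR` lookup instead of calling `.at()` directly: a failed `.at()` throws a generic `std::out_of_range` with no context, while the explicit check attaches a domain-specific message. A minimal sketch of that trade-off, with `std::string` keys standing in for `IterDomain*`:

```cpp
#include <stdexcept>
#include <string>
#include <unordered_map>

// find() + an explicit check lets the caller build a descriptive error,
// whereas .at() would only throw a bare std::out_of_range.
std::string lookupChecked(
    const std::unordered_map<std::string, std::string>& p2c,
    const std::string& p_id) {
  auto c_it = p2c.find(p_id);
  if (c_it == p2c.end()) {
    throw std::runtime_error(
        "Cannot find the mapped consumer logical ID for the producer "
        "logical ID " + p_id);
  }
  return c_it->second;
}
```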
@@ -492,8 +499,9 @@ Layout getCommunicationLayout(
 
   const int64_t sharded_id_pos =
       posInDomain(layout.allocation_domain(), sharded_id);
-  NVF_ERROR(
-      sharded_id_pos >= 0,
+  NVF_ERROR_GE(
+      sharded_id_pos,
+      0,
       "Sharded ID (",
       sharded_id,
       ") not found in the allocation domain of the tensor view: ",
@@ -566,7 +574,7 @@ std::vector<Expr*> convertSingleOpToCommunication(
   NVF_ERROR(
       isCommunicationLayoutCompliant(e),
       "Resharding on an inner axis is not lowerable ",
-      e->toString());
+      e);
 
   CommunicationInfo communication_info = getCommunicationInfo(e);
 