diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp
index 9b5404df2c2..3924d5658aa 100644
--- a/csrc/host_ir/lowering.cpp
+++ b/csrc/host_ir/lowering.cpp
@@ -192,11 +192,11 @@ void lowerSegment(
                 out, DomainType::kLoop, {ParallelType::Stream})) {
-          auto [i, inserted] = replacement_map.try_emplace(
-              in,
-              hir::shardByStream(in, innermost.loop->index(), communication));
-          if (inserted) {
-            innermost_scope.pushBack(i->second->definition());
+          Val*& sharded_in = replacement_map[in];
+          if (sharded_in == nullptr) {
+            sharded_in =
+                hir::shardByStream(in, innermost.loop->index(), communication);
+            innermost_scope.pushBack(sharded_in->definition());
           }
         }
@@ -210,7 +210,7 @@ void lowerSegment(
           nullptr) {
         innermost.parent_scope->insert(
             innermost.parent_insertion_point, allocate);
-        auto [i, inserted] = replacement_map.try_emplace(
+        auto [i, inserted] = replacement_map.emplace(
             out,
             hir::shardByStream(out, innermost.loop->index(), communication));
         NVF_ERROR(inserted, "The input segmented fusion should be SSA.");
@@ -314,9 +314,6 @@ void lowerSegment(
             innermost.parent_insertion_point, allocate);
         // Loop is stream parallelized but allocation is not. Therefore,
         // `out` should be allocated outside the loop.
-        //
-        // I use try_emplace here so shardByStream is called only when `out`
-        // is missing.
         TensorView* sharded_out =
             hir::shardByStream(out, innermost.loop->index(), e);
         replacement_map[out] = sharded_out;
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index c968a8545ca..e1b45cbbb52 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -114,7 +114,7 @@ std::unordered_map mapDeviceAndStreamParallelTypeToId
   }
   NVF_ERROR(
-      parallel_type_to_id.try_emplace(parallel_type, id).second,
+      parallel_type_to_id.emplace(parallel_type, id).second,
       "Found multiple loop IterDomains with the same parallel type (",
       parallel_type,
       "): ",
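
A standalone sketch (not nvFuser code; expensiveShard, replacement_map, and the int key type are made up for illustration) of the map-insertion behavior this change relies on: an expensive expression passed as an argument to try_emplace is evaluated before the call, even when the key already exists, whereas the operator[] plus null-check pattern in the first hunk defers the expensive call to the miss path.

// Illustrative only: shows why try_emplace(key, expensiveCall()) still pays
// for expensiveCall() on a key hit, while operator[] + null check does not.
// expensiveShard stands in for hir::shardByStream(...).
#include <cstdio>
#include <unordered_map>

static int call_count = 0;

int* expensiveShard(int key) {
  ++call_count;  // count evaluations of the "expensive" call
  static int storage[16];
  storage[key] = key * 10;
  return &storage[key];
}

int main() {
  std::unordered_map<int, int*> replacement_map;

  // Pattern removed by the diff: the argument is evaluated on every attempt,
  // even when the key is already present and the result is discarded.
  replacement_map.try_emplace(1, expensiveShard(1));  // call_count == 1
  replacement_map.try_emplace(1, expensiveShard(1));  // call_count == 2

  // Pattern introduced by the diff: operator[] value-initializes the mapped
  // pointer to nullptr on first access, so the expensive call runs only when
  // the entry is missing.
  int*& sharded = replacement_map[2];
  if (sharded == nullptr) {
    sharded = expensiveShard(2);  // call_count == 3
  }
  int*& again = replacement_map[2];
  if (again == nullptr) {
    again = expensiveShard(2);  // not reached; call_count stays 3
  }

  std::printf("expensiveShard called %d times\n", call_count);
  return 0;
}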