Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "device_lower/lower2device.h"
#include "device_lower/utils.h"
#include "disjoint_set.h"
#include "expr_simplifier.h"
#include "id_model/loop_promotion.h"
#include "id_model/to_string.h"
#include "id_model/transform_replay.h"
Expand Down Expand Up @@ -481,6 +482,66 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
return mapped_ids;
}

// The following is a subpattern of
// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
//
// outer, _ = split(root)
// outermost_grand, _ = split(outer)
// outer', _ = split(root)
//
// If outermost_grand and outer' have the same extent, map them.
// The splits must be divisible for this mapping to be valid.
void mapDivisibleSplits(ValGraph& graph) {
  // A split is eligible only if it evenly divides its input extent;
  // otherwise the outer extents of the two paths are not comparable.
  auto is_divisible = [](Split* s) {
    return simplifyExpr(s->isDivisible())->isTrue();
  };

  // Pairs of IterDomains to map. Collected first and applied at the
  // end so that graph.mapVals does not invalidate the groups being
  // iterated over.
  std::vector<std::pair<Val*, Val*>> ids_to_map;
  for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
    const ExprGroups& uses_of_root = graph.getUses(root);

    // Collect the outer outputs of divisible split-of-split chains
    // rooted at `root` (i.e., outermost_grand = split(split(root))).
    // Deduplicated: multiple split paths from the same root may reach
    // the same grand group, and a duplicate here would only produce
    // redundant pairs in ids_to_map.
    std::vector<ValGroup> outermost_grands;
    for (const ExprGroup& use_of_root : uses_of_root) {
      auto* split0 = dynamic_cast<Split*>(use_of_root->front());
      if (split0 == nullptr || !is_divisible(split0)) {
        continue;
      }
      // Only follow the outer output of the first split; outer and inner
      // must not be conflated.
      const ValGroup& outer = graph.toGroup(split0->outer());
      for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
        auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
        if (split1 == nullptr || !is_divisible(split1)) {
          continue;
        }
        const ValGroup& outermost_grand = graph.toGroup(split1->outer());
        bool already_collected = false;
        for (const ValGroup& collected : outermost_grands) {
          if (collected == outermost_grand) {
            already_collected = true;
            break;
          }
        }
        if (!already_collected) {
          outermost_grands.push_back(outermost_grand);
        }
      }
    }

    // Map each grand outer output to any single-split outer output of
    // the same root that has the same extent. Divisibility of both
    // splits guarantees the extents fully determine the mapping.
    for (const ValGroup& outermost_grand : outermost_grands) {
      Val* extent_of_grand =
          outermost_grand->front()->as<IterDomain>()->extent();

      for (const ExprGroup& use_of_root : uses_of_root) {
        auto* split = dynamic_cast<Split*>(use_of_root->front());
        if (split == nullptr || !is_divisible(split)) {
          continue;
        }

        const ValGroup& outer = graph.toGroup(split->outer());
        // Guard against a degenerate factor-1 inner split, where
        // `outer` is the direct parent of `outermost_grand` and shares
        // its extent; self-mapping a group through its own ancestor is
        // meaningless and must be skipped.
        if (outer != outermost_grand &&
            outer->front()->as<IterDomain>()->extent()->sameAs(
                extent_of_grand)) {
          ids_to_map.emplace_back(outermost_grand->front(), outer->front());
        }
      }
    }
  }

  for (const auto& [id1, id2] : ids_to_map) {
    graph.mapVals(id1, id2);
  }
}

} // namespace

ValGraph& IdModel::buildAlmostExactGraph() {
Expand Down Expand Up @@ -540,6 +601,8 @@ ValGraph& IdModel::buildAlmostExactGraph() {
almost_exact_graph.mapVals(id1, id2);
}

mapDivisibleSplits(almost_exact_graph);

almost_exact_graph.validateConsistency();

if (!allow_self_mapping_) {
Expand Down
13 changes: 7 additions & 6 deletions csrc/val_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ std::string ValGraph::toString() const {
ss << "IdGraph { \n";
ss << "Disjoint Ids:\n"
<< idGroupsString(*this, 1) << "\n\nDisjoint Expression groups:\n"
<< exprGroupsString(*this, 1) << std::endl;
ss << " } IdGraph\n" << std::endl;
<< exprGroupsString(*this, 1) << '\n';
ss << " } IdGraph\n";
return ss.str();
}

Expand Down Expand Up @@ -397,11 +397,12 @@ const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const {

const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
  // Returns the unique use ExprGroups of `val_group`. A group with no
  // recorded uses (e.g., a leaf domain) yields an empty set instead of
  // erroring out. NOTE: this is a deliberately relaxed contract —
  // getDefinitions still treats a missing entry as a hard error — and
  // callers such as mapDivisibleSplits rely on being able to query
  // uses of leaf groups.
  NVF_ERROR(val_group, "Nullptr not allowed");

  // Shared empty result so a reference can be returned for the
  // missing-entry case without allocating.
  static const ExprGroups empty_expr_groups;
  const auto it = unique_uses_.find(val_group);
  if (it == unique_uses_.end()) {
    return empty_expr_groups;
  }
  return it->second;
}
Comment on lines 398 to 407
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behavioral change silently relaxes a previous invariant

The old implementation treated a missing val_group entry in unique_uses_ as a hard error (via NVF_ERROR). The new implementation silently returns an empty set. While this is functionally required by mapDivisibleSplits (which calls getUses on leaf nodes that have no entries), it also removes the diagnostic for callers that previously relied on the error to detect graphs built incorrectly. By contrast, getDefinitions still throws on a missing entry.

Consider whether a comment, or a separate hasUses()/tryGetUses() accessor, would make the relaxed contract explicit without silently hiding misuse.


Expand Down
100 changes: 74 additions & 26 deletions tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
*/
// clang-format on

#include <fstream>

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

Expand All @@ -17,7 +15,6 @@
#include "id_model/loop_promotion.h"
#include "id_model/schedule.h"
#include "id_model/to_string.h"
#include "ir/graphviz.h"
#include "ops/all_ops.h"
#include "scheduler/tools/inlining.h"
#include "scheduler/tools/resize_utils.h"
Expand Down Expand Up @@ -235,8 +232,7 @@ void validateIELResolution(
auto promotion_id = iel_promotion_map_it->second;
ASSERT_TRUE(
exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id))
<< "Unexpected promotion. "
<< "Expected: " << ref_id->toString()
<< "Unexpected promotion. Expected: " << ref_id->toString()
<< ". Actual: " << promotion_id->toString();
ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id))
<< "Promotion of " << id->toString()
Expand Down Expand Up @@ -376,9 +372,9 @@ void checkStep4Results(
const auto& iel_promotion_map = tester.s4_iel_promotion_map;

EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size())
<< "Mismatched Step-4 result map. "
<< "Expected to have " << ref_promotion_map.size()
<< " mappings but found " << iel_promotion_map.size();
<< "Mismatched Step-4 result map. Expected to have "
<< ref_promotion_map.size() << " mappings but found "
<< iel_promotion_map.size();

for (const auto& ref_promotion_pair : ref_promotion_map) {
const auto& ref_promotion_group = ref_promotion_pair.first;
Expand Down Expand Up @@ -2937,9 +2933,8 @@ TEST_F(IdModelTest, LoopPromotionCyclicGraphWar) {
// Test to verify the split-aware covered group analysis. See
// also https://github.com/NVIDIA/Fuser/pull/3877.
TEST_F(IdModelTest, CoveredGroups) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({-1, 1});
fusion.addInput(tv0);
Expand Down Expand Up @@ -3000,7 +2995,7 @@ TEST_F(IdModelTest, CoveredGroups) {
TEST_F(IdModelTest, InvalidLoopPromotion) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
FusionGuard fg(&fusion);

auto T0 = makeContigConcreteTensor({1, 32, 6});
fusion.addInput(T0);
Expand Down Expand Up @@ -3086,9 +3081,8 @@ TEST_F(IdModelTest, InvalidLoopPromotion) {
// When a loop group only includes broadcast IDs, the group should not
// need to be promoted
TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({-1, 1});
fusion.addInput(tv0);
Expand Down Expand Up @@ -3130,9 +3124,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {

// Scatter output uses unique mapping schemes
TEST_F(IdModelTest, ScatterLoopMapping) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(1);
fusion.addInput(tv0);
Expand Down Expand Up @@ -3185,8 +3178,7 @@ TEST_F(IdModelTest, ScatterLoopMapping) {
// required but is a WAR for special ops like
// PreprocessGroupedMatmulInputSf. See also issue #5391.
TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2);
Expand Down Expand Up @@ -3219,8 +3211,7 @@ TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
}

TEST_F(IdModelTest, PermissiveResizeGraph) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({36});
Expand Down Expand Up @@ -3270,8 +3261,7 @@ TEST_F(IdModelTest, PermissiveResizeGraph) {
// This is the failing segment of the reproducer of
// https://github.com/NVIDIA/Fuser/issues/5803.
TEST_F(IdModelTest, ReproIssue5803) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv2 = makeContigConcreteTensor({4}, DataType::Int);
Expand Down Expand Up @@ -3312,8 +3302,7 @@ TEST_F(IdModelTest, ReproIssue5803) {
// This is a minimal fusion pattern to trigger the loop promotion
// issue as reported in https://github.com/NVIDIA/Fuser/issues/5803
TEST_F(IdModelTest, ReproIssue5803Minimal) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({4, 8});
Expand All @@ -3337,4 +3326,63 @@ TEST_F(IdModelTest, ReproIssue5803Minimal) {
IdModel id_model(&fusion, true);
}

// A [8] -> [4, 2] splitting reshape: outer-splitting both the input
// and the reshape output by the same divisible factor should make the
// resulting outermost loop IDs almost-exact mapped, while the other
// axes stay unmapped.
TEST_F(IdModelTest, SplittingReshape_Mapped) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv_in = makeContigConcreteTensor({2 * 2 * 2});
  fusion.addInput(tv_in);
  TensorView* tv_out = reshape(tv_in, {2 * 2 * 2}, {2 * 2, 2});
  fusion.addOutput(tv_out);

  // Same outer split factor on both sides.
  tv_in->outer_split(0, 2);
  tv_out->outer_split(0, 2);

  IdModel id_model(&fusion);
  const ValGraph& aeg = id_model.buildAlmostExactGraph();
  const auto& sets = aeg.disjointValSets();
  // Only the outermost IDs are expected to be mapped.
  EXPECT_TRUE(sets.strictAreMapped(tv_in->axis(0), tv_out->axis(0)));
  EXPECT_FALSE(sets.strictAreMapped(tv_in->axis(0), tv_out->axis(1)));
  EXPECT_FALSE(sets.strictAreMapped(tv_in->axis(1), tv_out->axis(2)));
}

// When the input and the reshape output are outer-split by different
// factors, the resulting outer IDs have different extents and must not
// be mapped in the almost-exact graph.
// NOTE(review): test name has a typo ("Spliting") — kept as-is to
// preserve the registered test identifier.
TEST_F(IdModelTest, SplitingReshape_DifferentExtents_NotMapped) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv_in = makeContigConcreteTensor({12});
  fusion.addInput(tv_in);
  TensorView* tv_out = reshape(tv_in, {12}, {6, 2});
  fusion.addOutput(tv_out);

  // Mismatched outer split factors: 12/2 = 6 vs 6/3 = 2 outer extents.
  tv_in->outer_split(0, 2);
  tv_out->outer_split(0, 3);

  IdModel id_model(&fusion);
  const ValGraph& aeg = id_model.buildAlmostExactGraph();
  EXPECT_FALSE(
      aeg.disjointValSets().strictAreMapped(tv_in->axis(0), tv_out->axis(0)));
}

// Divisibility is required for the split-aware mapping: here the
// output goes through a non-divisible inner split (6 by 4) before the
// outer split, so the outer IDs of input and output must not be
// almost-exact mapped even though the outer split factors match.
TEST_F(IdModelTest, NonDivisibleSplits_NotMapped) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv_in = makeContigConcreteTensor({6});
  fusion.addInput(tv_in);
  TensorView* tv_out = set(tv_in);
  fusion.addOutput(tv_out);

  tv_in->outer_split(0, 2);
  // 6 is not divisible by 4, making the chain ineligible for mapping.
  tv_out->inner_split(0, 4);
  tv_out->outer_split(0, 2);

  IdModel id_model(&fusion);
  const ValGraph& aeg = id_model.buildAlmostExactGraph();
  EXPECT_FALSE(
      aeg.disjointValSets().strictAreMapped(tv_in->axis(0), tv_out->axis(0)));
}

} // namespace nvfuser
4 changes: 2 additions & 2 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,9 @@ TEST_F(IndexingTest, Reshape) {
// to provide the extent of the group. However, since everything
// should be deterministic, string match should also work.
return std::string(
"( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 "
"( ( ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) / 25 ) * 25 "
") "
"+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )");
"+ ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) % 25 ) )");
}
default:
return std::string();
Expand Down
Loading