Skip to content

Almost-exact graph: map outer splits across merge (reshape) outputs#6039

Open
wujingyue wants to merge 2 commits intomainfrom
wjy/merge
Open

Almost-exact graph: map outer splits across merge (reshape) outputs#6039
wujingyue wants to merge 2 commits intomainfrom
wjy/merge

Conversation

@wujingyue
Copy link
Copy Markdown
Collaborator

@wujingyue wujingyue commented Mar 24, 2026

Summary

Adds mapDivisibleMergeSplits to buildAlmostExactGraph(), invoked after mapDivisibleSplits. For a Merge (e.g. merging reshape), when the merge output and the merge\u0027s outer input each have a divisible outer Split with the same factor, the two outer IterDomains are mapped in the almost-exact ValGraph.

Also factors isDivisible(Split*) out for reuse with mapDivisibleSplits.

For #3987

Test

  • IdModelTest.MergingReshapeOuterSplit_Mapped: merging reshape {2*2,2} \u2192 {2*2*2}, matching outer_split(0, 2) on input and output; asserts strictAreMapped on in->axis(0) and out->axis(0).

Made with Cursor

Add mapDivisibleMergeSplits after mapDivisibleSplits: for a Merge (e.g. merging reshape), when the merge output and merge outer input each have a divisible outer Split with the same factor, map the two outer IterDomains.

Factor out isDivisible(Split*) for shared use with mapDivisibleSplits.

Add IdModelTest.MergingReshapeOuterSplit_Mapped.

Made-with: Cursor
@wujingyue
Copy link
Copy Markdown
Collaborator Author

!test

@wujingyue wujingyue requested a review from naoyam March 24, 2026 22:45
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR extends buildAlmostExactGraph() with a new mapDivisibleMergeSplits pass that maps outer-split outputs across merge (reshape) operations when both the post-merge split and the pre-merge outer-input split share the same factor and are both provably divisible. It also refactors the previously-inline isDivisible lambda into a free function for reuse.

  • New isDivisible(Split*) helper (line 486): cleanly extracted from the mapDivisibleSplits lambda; used by both the old and new pass.
  • New mapDivisibleMergeSplits pass (lines 546–603): iterates every Merge expression group, finds downstream outer splits of the merge output and upstream outer splits of the merge's outer input with equal factors, and maps their outer outputs. The mapping is mathematically sound: if f | merge_outer_extent (guaranteed by isDivisible(split_outer)), then floor(merge_out / (merge_out_extent/f)) == floor(merge_outer / (merge_outer_extent/f)) for all valid index combinations.
  • Positive test MergingReshapeOuterSplit_Mapped: exercises the concrete case {4,2}→{8} with outer_split(0,2) on both input and output; asserts the two outer IterDomains are in the same almost-exact group.
  • test_indexing.cpp update: index variable names shifted by 4 (i126→i130) as a deterministic side effect of additional intermediate Val* objects created during the new isDivisible calls inside mapDivisibleMergeSplits.

Confidence Score: 5/5

  • Safe to merge; the new mapping pass is mathematically correct and validated by the existing consistency check.
  • The mapping logic is provably correct (f | merge_outer_extent ensures the outer split indices on merge_out and merge_outer are identical), all guard conditions (divisibility, outer-split direction, matching factors) are in place, validateConsistency() provides a post-hoc safety net, and the only finding is a minor loop-invariant recomputation style nit that has no effect on correctness.
  • No files require special attention.

Important Files Changed

Filename Overview
csrc/id_model/id_model.cpp Extracts isDivisible(Split*) helper and adds mapDivisibleMergeSplits which correctly maps split_merge->outer()split_outer->outer() when both are divisible outer splits with the same factor. Minor style nit: merge_outer_group is re-computed inside the inner loop where it only depends on the outer loop variable.
tests/cpp/test_id_model.cpp Adds MergingReshapeOuterSplit_Mapped positive test: reshape {4,2}→{8} with matching outer split factor=2 on both input axis(0) and output axis(0); correctly asserts strictAreMapped.
tests/cpp/test_indexing.cpp Updates expected string-matched index variable names from i126/i127/i128 to i130/i131/i132 (shift of 4) due to additional intermediate Val* objects created by the new isDivisible calls in mapDivisibleMergeSplits. Change is expected and mechanical.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["merge_outer (e.g. extent=4)"] --> M["[Merge]"]
    B["merge_inner (e.g. extent=2)"] --> M
    M --> C["merge_out (extent=8)"]
    C --> SM["[split_merge]\nouter split, factor=2\nisDivisible ✓"]
    SM --> SMO["split_merge->outer()\nextent=2"]
    SM --> SMI["split_merge->inner()\nextent=4"]

    A --> SO["[split_outer]\nouter split, factor=2\nisDivisible ✓"]
    SO --> SOO["split_outer->outer()\nextent=2"]
    SO --> SOI["split_outer->inner()\nextent=2"]

    SMO <-.->|"mapDivisibleMergeSplits:\nsame factor + both divisible\n→ mapped in AlmostExact graph"| SOO

    style SMO fill:#c8e6c9
    style SOO fill:#c8e6c9
Loading

Reviews (2): Last reviewed commit: "Fix test" | Re-trigger Greptile

Comment on lines +3388 to +3404
TEST_F(IdModelTest, MergingReshapeOuterSplit_Mapped) {
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeContigConcreteTensor({2LL * 2, 2});
fusion.addInput(in);
TensorView* out = reshape(in, {2LL * 2, 2}, {2LL * 2 * 2});
fusion.addOutput(out);

in->outer_split(0, 2);
out->outer_split(0, 2);

IdModel id_model(&fusion);
const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped(
in->axis(0), out->axis(0)));
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing negative tests for mapDivisibleMergeSplits

The PR adds one positive test (MergingReshapeOuterSplit_Mapped) but no negative counterparts exercising the guard conditions inside mapDivisibleMergeSplits. The following cases are silently assumed to not map, but are not explicitly verified:

  1. Inner split instead of outer splitout->inner_split(0, 2) should NOT map in->axis(0) and out->axis(0) (the !split_merge->innerSplit() guard).
  2. Mismatched factorsin->outer_split(0, 2) plus out->outer_split(0, 4) (different factors) should NOT map.
  3. Non-divisible outer split — factors that don't evenly divide the dimension size should NOT map (the isDivisible guard).

The existing NonDivisibleSplits_NotMapped test covers mapDivisibleSplits, but there is no equivalent for mapDivisibleMergeSplits. Consider adding at least one negative test to document these invariants and prevent future regressions.

@wujingyue
Copy link
Copy Markdown
Collaborator Author

!test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant