Skip to content

Commit f810652

Browse files
authored
[Triton] Fix LoopAwareCSE by removing the equivalence cache (#6894)
I made a classic blunder in trying to cache a context-sensitive DFS. The DFS cache persisted from one context to another, which caused miscompiles: operations that are equal in one context are not necessarily equal in another. I have removed the cache for now, because caching the DFS correctly here is very tricky. If the compile time of this pass becomes a problem, we can revisit it.
1 parent 0f1e09e commit f810652

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct LoopCSEDriver {
4343
bool areEqualInLoop(Value a, Value b);
4444

4545
scf::ForOp loop;
46-
ValueEquivalence equalValues;
46+
SmallVector<std::pair<int, int>> argStack;
4747
};
4848
} // namespace
4949

@@ -52,14 +52,15 @@ bool LoopCSEDriver::areIterArgsEqual(int i, int j) {
5252
return true;
5353
if (loop.getInitArgs()[i] != loop.getInitArgs()[j])
5454
return false;
55+
if (llvm::is_contained(argStack, std::make_pair(i, j)))
56+
return true;
5557
BlockArgument aArg = loop.getRegionIterArg(i);
5658
BlockArgument bArg = loop.getRegionIterArg(j);
5759
// First, assume the arguments are equal. This is how recursion is broken.
58-
equalValues.setKnownEquivalence(aArg, bArg, true);
60+
argStack.push_back({i, j});
5961
bool result =
6062
areEqualInLoop(loop.getYieldedValues()[i], loop.getYieldedValues()[j]);
61-
// Now update the equivalence based on the actual result.
62-
equalValues.setKnownEquivalence(aArg, bArg, result);
63+
argStack.pop_back();
6364
return result;
6465
}
6566

@@ -83,14 +84,10 @@ bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
8384
if (a == loop.getInductionVar() || b == loop.getInductionVar())
8485
return false;
8586

86-
if (std::optional<bool> eq = equalValues.getKnownEquivalence(a, b))
87-
return *eq;
88-
8987
if (auto aArg = dyn_cast<BlockArgument>(a)) {
9088
auto bArg = cast<BlockArgument>(b);
9189
bool result =
9290
areIterArgsEqual(aArg.getArgNumber() - 1, bArg.getArgNumber() - 1);
93-
equalValues.setKnownEquivalence(a, b, result);
9491
return result;
9592
}
9693

@@ -107,9 +104,7 @@ bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
107104
bool result = OperationEquivalence::isEquivalentTo(
108105
aDef, bDef,
109106
[&](Value a, Value b) { return success(areEqualInLoop(a, b)); },
110-
[&](Value a, Value b) { equalValues.setKnownEquivalence(a, b, true); },
111-
OperationEquivalence::IgnoreLocations);
112-
equalValues.setKnownEquivalence(a, b, result);
107+
/*markEquivalent=*/nullptr, OperationEquivalence::IgnoreLocations);
113108
return result;
114109
}
115110

test/Triton/loop_cse.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,28 @@ tt.func @loop_buffer_phase_args(%arg0: i32) {
4545
}
4646
tt.return
4747
}
48+
49+
// CHECK-LABEL: @invalid_cache_test
50+
tt.func public @invalid_cache_test(%arg0: i32, %arg1: i32) -> (i32, i32) {
51+
%c1_i32 = arith.constant 1 : i32
52+
%c3_i32 = arith.constant 3 : i32
53+
%c0_i32 = arith.constant 0 : i32
54+
// CHECK: %0:4 = scf.for
55+
%0:4 = scf.for %arg2 = %c0_i32 to %arg0 step %arg1 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32, %arg6 = %c0_i32) -> (i32, i32, i32, i32) : i32 {
56+
57+
%1 = arith.addi %arg5, %c1_i32 : i32
58+
%2 = arith.xori %arg6, %c1_i32 : i32
59+
%3 = arith.cmpi eq, %1, %c3_i32 : i32
60+
%4 = arith.select %3, %2, %arg6 : i32
61+
%5 = arith.select %3, %c1_i32, %1 : i32
62+
63+
%6 = arith.addi %arg3, %c1_i32 : i32
64+
%7 = arith.xori %arg4, %c1_i32 : i32
65+
%8 = arith.cmpi eq, %6, %c3_i32 : i32
66+
%9 = arith.select %8, %c0_i32, %6 : i32
67+
%10 = arith.select %8, %7, %arg4 : i32
68+
69+
scf.yield %9, %10, %5, %4 : i32, i32, i32, i32
70+
}
71+
tt.return %0#1, %0#3 : i32, i32
72+
}

0 commit comments

Comments
 (0)