Skip to content

Commit 407b8a3

Browse files
[ConSan] Fixes for Warp Specialization support (#8265)
This PR introduces two functional and one performance improvement. 1) Identifying if an allocation is used as a multibuffer is currently done with IR tracking in ConSan - we check if the allocation is being used by `subview_index`. There was a missing case, as we were not tracking values into the WarpSpecializeOp, interpreting multibuferred allocations as regular ones. This PR fixes this issue. 2) For tcgen5 mma with barrier (so doing commit implicitly) we were emitting `track_visible_reads/writes` only for tensor core buffers, so consan was only tracking status of TC buffers accessed by mma op. This meant for example that if partition A waited for mma issued by partition B, A still couldn't legally write to mma's operands, even though waiting for mma to finish should be enough for it to be legal. This is fixed now. 3) We were emitting `track_visible_reads/writes` after checks of every operand for mma op. This is expensive, as it consists of number of global memory accesses. This is now rewritten to emit these ops just once after all the operands are checked. With these changes ConSan shows no false positives in the test_warp_specialization.py tests.
1 parent 748a47e commit 407b8a3

File tree

3 files changed

+225
-185
lines changed

3 files changed

+225
-185
lines changed

lib/Dialect/TritonInstrument/IR/Utility.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,21 @@ bool canAllocBeInstrumented(Operation *op) {
134134
}
135135

136136
// Interpret local_allocs that are used in ttg.memdesc_index as multibuffered
137-
bool isMultiBuffered(Operation *op) {
138-
return llvm::any_of(op->getUsers(), [](Operation *user) {
139-
return isa<MemDescIndexOp>(user);
140-
});
137+
bool isMultiBuffered(Value v) {
138+
for (auto &use : v.getUses()) {
139+
if (isa<MemDescIndexOp>(use.getOwner())) {
140+
return true;
141+
}
142+
if (auto wsOp = dyn_cast<WarpSpecializeOp>(use.getOwner())) {
143+
int opNumber = use.getOperandNumber();
144+
for (Region *region : wsOp.getPartitionRegions()) {
145+
if (isMultiBuffered(region->getArguments()[opNumber])) {
146+
return true;
147+
}
148+
}
149+
}
150+
}
151+
return false;
141152
}
142153

143154
uint64_t getAllocationOffset(LocalAllocOp op) {

0 commit comments

Comments
 (0)