Skip to content

Commit de650ad

Browse files
authored
[BACKEND] Don't allocate shmem for warps with repeated data in tt.scan (#5910)
It turns out that the previous changes within reduce to support LLs had already trimmed its shmem memory use to the right size.
1 parent 464d1f1 commit de650ad

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ bool ScanLoweringHelper::isSupported() {
267267
}
268268

269269
unsigned ScanLoweringHelper::getScratchSizeInElems() {
270-
unsigned numWarps = lookupNumWarps(scanOp);
270+
unsigned numWarps = product(getEncoding().getWarpsPerCTA());
271271
unsigned numNonAxisElementsPerWarp =
272272
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
273273
unsigned numElements = numWarps * numNonAxisElementsPerWarp *

test/Analysis/test-allocation.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,4 +615,14 @@ tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
615615
// CHECK-NEXT: size = 1024
616616
}
617617

618+
// CHECK-LABEL: scan_alloc
619+
tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) {
620+
// CHECK: offset = 0, size = 128
621+
%a = "tt.scan"(%x) <{axis = 0 : i32, reverse = false}>({
622+
^bb0(%arg0: f32, %arg1: f32):
623+
%add = arith.addf %arg0, %arg1 : f32
624+
tt.scan.return %add : f32
625+
}) : (tensor<8x16xf32, #AL>) -> tensor<8x16xf32, #AL>
626+
tt.return
627+
}
618628
}

0 commit comments

Comments
 (0)