Commit 14a85e7
[BE] Fix for colliding tmem allocation boundaries (#6318)
It may happen that two tmem allocations share the same liverange end boundary (for example, when both liveranges end at a block boundary). This case was not handled properly in the tmem allocation pass, causing tmem overallocation.
1 parent 5ce3754 commit 14a85e7
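
The diff below replaces std::map with std::multimap for the container keyed by live-range end point. With std::map, two allocations whose live ranges end at the same point collide on the key: the second insertion overwrites the first chunk, which is then never returned to the free list, so its space cannot be reused. A minimal C++ sketch of that difference, using a hypothetical Chunk struct in place of the pass's TMemChunk:

#include <cassert>
#include <map>

// Hypothetical stand-in for the pass's TMemChunk bookkeeping.
struct Chunk {
  int startRow;
  int startCol;
};

int main() {
  // Two chunks whose live ranges both end at point 42 (e.g. a block boundary).
  std::map<int, Chunk> byEnd;
  byEnd[42] = Chunk{0, 0};
  byEnd[42] = Chunk{0, 64};              // overwrites the first chunk; it is never freed
  assert(byEnd.size() == 1);

  std::multimap<int, Chunk> byEndMulti;
  byEndMulti.insert({42, Chunk{0, 0}});
  byEndMulti.insert({42, Chunk{0, 64}}); // both chunks are kept and can be freed later
  assert(byEndMulti.size() == 2);
  return 0;
}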

File tree

2 files changed (+48, -4 lines)

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 3 additions & 3 deletions
@@ -142,7 +142,7 @@ static Interval<int> getLiveIntervals(Value value, Liveness &liveness,
 }
 
 static void updateMap(MemoryBitMap &memoryMap, Interval<int> liveInterval,
-                      std::map<int, TMemChunk> &intervalLiverangeEnd) {
+                      std::multimap<int, TMemChunk> &intervalLiverangeEnd) {
   int start = liveInterval.start();
   // Add any dead liverange to the list of free intervals.
   for (auto it = intervalLiverangeEnd.begin();
@@ -247,7 +247,7 @@ allocateTMem(Operation *parentOp,
   int totalMemorySize = 0;
   MemoryBitMap memoryMap;
   Liveness liveness(parentOp);
-  std::map<int, TMemChunk> intervalLiverangeEnd;
+  std::multimap<int, TMemChunk> intervalLiverangeEnd;
   DenseMap<TMEMAllocOp, TMemChunk> allocChunks;
   // Implement a linear scan first fit algorithm. We expect that fragmentation
   // won't be a problem, if it is this should be revisited.
@@ -283,7 +283,7 @@ allocateTMem(Operation *parentOp,
   allocChunks.insert({alloc, chunkAllocated});
   // currently naively constraint allocs based on the first one we find.
   rowIdConstraints.addConstraints(alloc, chunkAllocated.startRow);
-  intervalLiverangeEnd[liveInterval.end()] = chunkAllocated;
+  intervalLiverangeEnd.insert({liveInterval.end(), chunkAllocated});
   int colOffset = chunkAllocated.startCol;
   int rowOffset = chunkAllocated.startRow * 16;
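
For context on how these entries are consumed: the loop in updateMap walks intervalLiverangeEnd from the smallest end point and frees every chunk whose live range ended before the new interval starts; with a multimap, several chunks sharing the same end point are all released. A simplified, hypothetical sketch of that pattern (releaseDeadChunks and the plain free list are illustrative, not the pass's actual MemoryBitMap interface):

#include <map>
#include <vector>

// Hypothetical chunk descriptor; the real pass uses TMemChunk.
struct Chunk {
  int startRow;
  int startCol;
};

// Illustrative "free dead live ranges" step: release every chunk whose live
// range ended before `start`. Because the container is a multimap, multiple
// chunks sharing the same end point are all released, not just one.
static void releaseDeadChunks(std::multimap<int, Chunk> &intervalLiverangeEnd,
                              int start, std::vector<Chunk> &freeChunks) {
  for (auto it = intervalLiverangeEnd.begin();
       it != intervalLiverangeEnd.end();) {
    if (it->first >= start)
      break;                            // keys are ordered; later entries are still live
    freeChunks.push_back(it->second);   // return the chunk to the free pool
    it = intervalLiverangeEnd.erase(it);
  }
}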

test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir

Lines changed: 45 additions & 1 deletion
@@ -62,7 +62,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #tmem1 = #ttng.tensor_memory_encoding<blockM = 64, blockN = 128, unpacked = true>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK: ttg.tensor_memory_size = 512
-// CHECK: alloc_tensor_memory
+// CHECK: alloc_tensor_memory_re_use
 tt.func public @alloc_tensor_memory_re_use(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
   %true = arith.constant true
   %c1 = arith.constant 1 : i32
@@ -113,6 +113,50 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+// CHECK: ttg.tensor_memory_size = 128
+// CHECK: alloc_tensor_memory_re_use_liverange_end_collision
+tt.func public @alloc_tensor_memory_re_use_liverange_end_collision(
+    %arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>,
+    %lb: index, %ub: index, %step: index) {
+  %true = arith.constant true
+  %c1 = arith.constant 1 : i32
+  %c0 = arith.constant 0 : i32
+  %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+  %cst1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+  %cst2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+
+  // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
+  %a = ttng.tmem_alloc %cst0 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
+  %b = ttng.tmem_alloc %cst : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  scf.for %i = %lb to %ub step %step {
+    ttng.tmem_store %cst2, %a, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %b, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    scf.yield
+  }
+  // Liveranges of both allocations end at the same time, at the boundary of the loop. Make sure we can handle this case.
+
+  // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
+  %c = ttng.tmem_alloc %cst0 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
+  %d = ttng.tmem_alloc %cst : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  ttng.tmem_store %cst2, %c, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+  ttng.tmem_store %cst2, %d, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  tt.return
+}
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true, CTASplitM = 2>
 #tmem1 = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true, CTASplitN = 2>
