Skip to content

Commit 0d2fd2d

Browse files
[ConvertLayoutOpToLLVM] Reintroduce transferWithinBlock
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 200dc70 commit 0d2fd2d

File tree

3 files changed

+40
-19
lines changed

3 files changed

+40
-19
lines changed

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
208208
auto dstTy = cvtLayout.getType();
209209
if (!cvtNeedsSharedMemory(srcTy, dstTy))
210210
return 0;
211-
// The generic pass uses swizzling
212-
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
211+
// Pessimistically take the max. We will revisit later.
212+
auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
213+
getNumScratchElemsPaddedCvt(srcTy, dstTy));
214+
213215
return elems * getBitwidth(srcTy) / 8;
214216
}
215217
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
6363
} else if (llvm::is_contained(dims, kWarp)) {
6464
// Case 2: Transfer between values in the same CTA, in which case we move
6565
// values through shared memory.
66-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
67-
return success();
66+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
6867
} else if (llvm::is_contained(dims, kLane)) {
6968
// Case 3. Transfer between values in the same warp, in which case we try
7069
// to move values using warp shuffles, though if the pattern is
@@ -75,8 +74,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
7574
// TODO: Since data is only transferred within a warp over shared memory,
7675
// we should use `bar.warp.sync` instead of `barrier`, which will improve
7776
// latency when warps issue barriers on different cycles.
78-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
79-
return success();
77+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
8078
} else if (llvm::is_contained(dims, kRegister)) {
8179
// Case 4. Transfer between values in the same thread, in which case we
8280
// simply reorder the elements of adaptor.getSrc().
@@ -112,6 +110,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
112110
return success();
113111
}
114112

113+
LogicalResult transferWithinBlock(ConvertLayoutOp op,
114+
const LinearLayout &srcLayout,
115+
const LinearLayout &dstLayout,
116+
OpAdaptor adaptor,
117+
ConversionPatternRewriter &rewriter) const {
118+
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
119+
120+
// Try to use swizzling to implement the conversion
121+
// HACK: Remove once XPU tests pass for the swizzling path
122+
if (!targetInfo.isXpu()) {
123+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
124+
return success();
125+
}
126+
127+
Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
128+
getTypeConverter(), rewriter);
129+
130+
rewriter.replaceOp(op, result);
131+
return success();
132+
}
133+
115134
SmallVector<Value> transferWithinBlockSwizzlingImpl(
116135
Location loc, ConversionPatternRewriter &rewriter,
117136
const LinearLayout &srcLayout, const LinearLayout &dstLayout,

test/Analysis/test-allocation.mlir

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -201,22 +201,22 @@ tt.func @longlive(%A : !tt.ptr<f16>) {
201201

202202
// This example triggers graph coloring with > 1 colors.
203203
// expected-remark @below {{multi_color}}
204-
// expected-remark @below {{size = 1376}}
204+
// expected-remark @below {{size = 1504}}
205205
tt.func @multi_color(%A : !tt.ptr<f16>) {
206-
// expected-remark @below {{offset = 1024, size = 64}}
206+
// expected-remark @below {{offset = 1152, size = 64}}
207207
%cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable>
208-
// expected-remark @below {{offset = 1344, size = 32}}
208+
// expected-remark @below {{offset = 1472, size = 32}}
209209
%cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
210-
// expected-remark @below {{offset = 1088, size = 128}}
210+
// expected-remark @below {{offset = 1216, size = 128}}
211211
%cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
212212
%cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
213-
// expected-remark @below {{scratch offset = 0, size = 1024}}
213+
// expected-remark @below {{scratch offset = 0, size = 1152}}
214214
%0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
215215
%1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
216216
// expected-remark @below {{offset = 0, size = 128}}
217217
%cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
218218
%2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
219-
// expected-remark @below {{scratch offset = 0, size = 1024}}
219+
// expected-remark @below {{scratch offset = 0, size = 1152}}
220220
%3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
221221
// expected-remark @below {{offset = 512, size = 256}}
222222
%cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
@@ -226,7 +226,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
226226
%5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
227227
// expected-remark @below {{offset = 0, size = 512}}
228228
%cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
229-
// expected-remark @below {{offset = 1216, size = 128}}
229+
// expected-remark @below {{offset = 1344, size = 128}}
230230
%cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
231231
%6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
232232
// expected-remark @below {{offset = 0, size = 512}}
@@ -237,7 +237,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
237237
%cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
238238
%7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
239239
%8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL>
240-
// expected-remark @below {{scratch offset = 0, size = 1024}}
240+
// expected-remark @below {{scratch offset = 0, size = 1152}}
241241
%9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
242242
%cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL>
243243
%10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
@@ -248,16 +248,16 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
248248

249249
// This example triggers graph coloring with multiple rounds
250250
// expected-remark @below {{multi_color_multi_rounds}}
251-
// expected-remark @below {{size = 9376}}
251+
// expected-remark @below {{size = 9504}}
252252
tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
253-
// expected-remark @below {{offset = 9344, size = 32}}
253+
// expected-remark @below {{offset = 9472, size = 32}}
254254
%cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
255-
// expected-remark @below {{offset = 9216, size = 128}}
255+
// expected-remark @below {{offset = 9344, size = 128}}
256256
%cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
257257
// expected-remark @below {{offset = 0, size = 8192}}
258258
%cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
259259
%cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
260-
// expected-remark @below {{scratch offset = 8192, size = 1024}}
260+
// expected-remark @below {{scratch offset = 8192, size = 1152}}
261261
%0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
262262
%1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
263263
// expected-remark @below {{offset = 8704, size = 128}}
@@ -267,7 +267,7 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
267267
%cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
268268
%3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
269269
%4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL>
270-
// expected-remark @below {{scratch offset = 0, size = 1024}}
270+
// expected-remark @below {{scratch offset = 0, size = 1152}}
271271
%5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
272272
%6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
273273
tt.return

0 commit comments

Comments
 (0)