Skip to content

Commit 28f406e

Browse files
[Intel] Enable swizzling path (#4910)
Fixes #4887. Reverts commits 4809815 and 0d2fd2d.
2 parents d0e80f3 + b6025b4 commit 28f406e

File tree

7 files changed

+100
-114
lines changed

7 files changed

+100
-114
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ class TargetInfoBase {
8989
virtual bool supportLdMatrix() const { return false; }
9090
virtual bool supportStMatrix() const { return false; }
9191
virtual bool isCuda() const { return false; }
92-
virtual bool isXpu() const { return false; }
9392

9493
// Annotate target specific information to local load operations during
9594
// lowering to LLVM. `llLoadOp` is the generated LLVM load op.

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
208208
auto dstTy = cvtLayout.getType();
209209
if (!cvtNeedsSharedMemory(srcTy, dstTy))
210210
return 0;
211-
// Pesimistically take the max. We will revisit later
212-
auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
213-
getNumScratchElemsPaddedCvt(srcTy, dstTy));
214-
211+
// The generic pass uses swizzling
212+
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
215213
return elems * getBitwidth(srcTy) / 8;
216214
}
217215
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
6363
} else if (llvm::is_contained(dims, kWarp)) {
6464
// Case 2: Transfer between values in the same CTA, in which case we move
6565
// values through shared memory.
66-
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
66+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
67+
return success();
6768
} else if (llvm::is_contained(dims, kLane)) {
6869
// Case 3. Transfer between values in the same warp, in which case we try
6970
// to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
7475
// TODO: Since data is only transferred within a warp over shared memory,
7576
// we should use `bar.warp.sync` instead of `barrier`, which will improve
7677
// latency when warps issue barriers on different cycles.
77-
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
78+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
79+
return success();
7880
} else if (llvm::is_contained(dims, kRegister)) {
7981
// Case 4. Transfer between values in the same thread, in which case we
8082
// simply reorder the elements of adaptor.getSrc().
@@ -110,27 +112,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110112
return success();
111113
}
112114

113-
LogicalResult transferWithinBlock(ConvertLayoutOp op,
114-
const LinearLayout &srcLayout,
115-
const LinearLayout &dstLayout,
116-
OpAdaptor adaptor,
117-
ConversionPatternRewriter &rewriter) const {
118-
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
119-
120-
// Try to use swizzling to implement the conversion
121-
// HACK Remove once XPU tests pass for the swizzling path
122-
if (!targetInfo.isXpu()) {
123-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
124-
return success();
125-
}
126-
127-
Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
128-
getTypeConverter(), rewriter);
129-
130-
rewriter.replaceOp(op, result);
131-
return success();
132-
}
133-
134115
SmallVector<Value> transferWithinBlockSwizzlingImpl(
135116
Location loc, ConversionPatternRewriter &rewriter,
136117
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
@@ -163,6 +144,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
163144
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
164145
auto outVals = transferWithinBlockSwizzlingImpl(
165146
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
147+
if (llvmElemTy.getIntOrFloatBitWidth() == 1) {
148+
auto zero = b.int_val(8, 0);
149+
for (auto &v : outVals)
150+
v = b.icmp_ne(v, zero);
151+
return outVals;
152+
}
166153
for (auto &v : outVals) {
167154
v = b.trunc(llvmElemTy, v);
168155
}

test/Analysis/test-allocation.mlir

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -201,22 +201,22 @@ tt.func @longlive(%A : !tt.ptr<f16>) {
201201

202202
// This example triggers graph coloring with > 1 colors.
203203
// expected-remark @below {{multi_color}}
204-
// expected-remark @below {{size = 1504}}
204+
// expected-remark @below {{size = 1376}}
205205
tt.func @multi_color(%A : !tt.ptr<f16>) {
206-
// expected-remark @below {{offset = 1152, size = 64}}
206+
// expected-remark @below {{offset = 1024, size = 64}}
207207
%cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable>
208-
// expected-remark @below {{offset = 1472, size = 32}}
208+
// expected-remark @below {{offset = 1344, size = 32}}
209209
%cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
210-
// expected-remark @below {{offset = 1216, size = 128}}
210+
// expected-remark @below {{offset = 1088, size = 128}}
211211
%cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
212212
%cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
213-
// expected-remark @below {{scratch offset = 0, size = 1152}}
213+
// expected-remark @below {{scratch offset = 0, size = 1024}}
214214
%0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
215215
%1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
216216
// expected-remark @below {{offset = 0, size = 128}}
217217
%cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
218218
%2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
219-
// expected-remark @below {{scratch offset = 0, size = 1152}}
219+
// expected-remark @below {{scratch offset = 0, size = 1024}}
220220
%3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
221221
// expected-remark @below {{offset = 512, size = 256}}
222222
%cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
@@ -226,7 +226,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
226226
%5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
227227
// expected-remark @below {{offset = 0, size = 512}}
228228
%cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
229-
// expected-remark @below {{offset = 1344, size = 128}}
229+
// expected-remark @below {{offset = 1216, size = 128}}
230230
%cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
231231
%6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
232232
// expected-remark @below {{offset = 0, size = 512}}
@@ -237,7 +237,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
237237
%cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
238238
%7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
239239
%8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL>
240-
// expected-remark @below {{scratch offset = 0, size = 1152}}
240+
// expected-remark @below {{scratch offset = 0, size = 1024}}
241241
%9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
242242
%cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL>
243243
%10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
@@ -248,16 +248,16 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
248248

249249
// This example triggers graph coloring with multiple rounds
250250
// expected-remark @below {{multi_color_multi_rounds}}
251-
// expected-remark @below {{size = 9504}}
251+
// expected-remark @below {{size = 9376}}
252252
tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
253-
// expected-remark @below {{offset = 9472, size = 32}}
253+
// expected-remark @below {{offset = 9344, size = 32}}
254254
%cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
255-
// expected-remark @below {{offset = 9344, size = 128}}
255+
// expected-remark @below {{offset = 9216, size = 128}}
256256
%cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
257257
// expected-remark @below {{offset = 0, size = 8192}}
258258
%cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
259259
%cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
260-
// expected-remark @below {{scratch offset = 8192, size = 1152}}
260+
// expected-remark @below {{scratch offset = 8192, size = 1024}}
261261
%0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
262262
%1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
263263
// expected-remark @below {{offset = 8704, size = 128}}
@@ -267,7 +267,7 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
267267
%cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
268268
%3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
269269
%4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL>
270-
// expected-remark @below {{scratch offset = 0, size = 1152}}
270+
// expected-remark @below {{scratch offset = 0, size = 1024}}
271271
%5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
272272
%6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
273273
tt.return

0 commit comments

Comments (0)