1 change: 0 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -89,7 +89,6 @@ class TargetInfoBase {
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
   virtual bool isCuda() const { return false; }
-  virtual bool isXpu() const { return false; }
 
   // Annotate target specific information to local load operations during
   // lowering to LLVM. `llLoadOp` is the generated LLVM load op.
6 changes: 2 additions & 4 deletions lib/Analysis/Allocation.cpp
@@ -208,10 +208,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     auto dstTy = cvtLayout.getType();
     if (!cvtNeedsSharedMemory(srcTy, dstTy))
       return 0;
-    // Pesimistically take the max. We will revisit later
-    auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
-                          getNumScratchElemsPaddedCvt(srcTy, dstTy));
-
+    // The generic pass uses swizzling
+    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
     return elems * getBitwidth(srcTy) / 8;
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
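The scratch size above is computed as `elems * getBitwidth(srcTy) / 8`. As a hedged illustration only, assuming the swizzled path needs one scratch element per tensor element for the `tensor<16x32xf16>` conversions exercised in `test-allocation.mlir` (an assumption for this sketch, not taken from the pass itself), the arithmetic works out as follows:

```cpp
#include <cstdio>

int main() {
  // Hypothetical example values: a 16x32 f16 tensor, assuming one scratch
  // element per tensor element on the swizzled path.
  unsigned elems = 16 * 32;                   // 512 scratch elements (assumed)
  unsigned bitwidthBits = 16;                 // f16 element width
  unsigned bytes = elems * bitwidthBits / 8;  // bytes of shared-memory scratch
  std::printf("scratch bytes = %u\n", bytes); // prints 1024
  return 0;
}
```

This lines up with the updated `scratch ... size = 1024` remarks in `test-allocation.mlir` below; the previous padded/max estimate reported 1152 bytes for the same conversion, which is why the expected sizes in that test drop by 128 bytes.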
33 changes: 10 additions & 23 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else if (llvm::is_contained(dims, kWarp)) {
       // Case 2: Transfer between values in the same CTA, in which case we move
       // values through shared memory.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try
       // to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       // TODO: Since data is only transferred within a warp over shared memory,
       // we should use `bar.warp.sync` instead of `barrier`, which will improve
       // latency when warps issue barriers on different cycles.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
       // simply reorder the elements of adaptor.getSrc().
@@ -110,27 +112,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }
 
-  LogicalResult transferWithinBlock(ConvertLayoutOp op,
-                                    const LinearLayout &srcLayout,
-                                    const LinearLayout &dstLayout,
-                                    OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
-
-    // Try to use swizzling to implement the conversion
-    // HACK Remove once XPU tests pass for the swizzling path
-    if (!targetInfo.isXpu()) {
-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
-      return success();
-    }
-
-    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
-                                              getTypeConverter(), rewriter);
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
   SmallVector<Value> transferWithinBlockSwizzlingImpl(
       Location loc, ConversionPatternRewriter &rewriter,
       const LinearLayout &srcLayout, const LinearLayout &dstLayout,
@@ -163,6 +144,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
     auto outVals = transferWithinBlockSwizzlingImpl(
         loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
+    if (llvmElemTy.getIntOrFloatBitWidth() == 1) {
+      auto zero = b.int_val(8, 0);
+      for (auto &v : outVals)
+        v = b.icmp_ne(v, zero);
+      return outVals;
+    }
     for (auto &v : outVals) {
       v = b.trunc(llvmElemTy, v);
     }
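The new branch above widens `i1` values to `i8` before they are staged through shared memory and reconstructs them with `icmp_ne` against zero afterwards, rather than truncating. A standalone sketch of that round trip in plain C++ (the helper names are illustrative stand-ins, not Triton APIs):

```cpp
#include <cassert>
#include <cstdint>

// Stand-ins for the zext / icmp_ne pair used in the lowering above.
static uint8_t widenI1ToI8(bool v) { return v ? 1 : 0; } // models b.zext(i8Ty, v)
static bool narrowViaCmpNe(uint8_t v) { return v != 0; } // models b.icmp_ne(v, 0)

int main() {
  for (bool v : {false, true})
    assert(narrowViaCmpNe(widenI1ToI8(v)) == v); // i1 values round-trip exactly
  return 0;
}
```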
26 changes: 13 additions & 13 deletions test/Analysis/test-allocation.mlir
@@ -201,22 +201,22 @@ tt.func @longlive(%A : !tt.ptr<f16>) {
 
 // This example triggers graph coloring with > 1 colors.
 // expected-remark @below {{multi_color}}
-// expected-remark @below {{size = 1504}}
+// expected-remark @below {{size = 1376}}
 tt.func @multi_color(%A : !tt.ptr<f16>) {
-  // expected-remark @below {{offset = 1152, size = 64}}
+  // expected-remark @below {{offset = 1024, size = 64}}
   %cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1472, size = 32}}
+  // expected-remark @below {{offset = 1344, size = 32}}
   %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1216, size = 128}}
+  // expected-remark @below {{offset = 1088, size = 128}}
   %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
   // expected-remark @below {{offset = 0, size = 128}}
   %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   // expected-remark @below {{offset = 512, size = 256}}
   %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
@@ -226,7 +226,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
   %5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
   // expected-remark @below {{offset = 0, size = 512}}
   %cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1344, size = 128}}
+  // expected-remark @below {{offset = 1216, size = 128}}
   %cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
   // expected-remark @below {{offset = 0, size = 512}}
@@ -237,7 +237,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
   %cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
   %8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL>
   %10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
@@ -248,16 +248,16 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
 
 // This example triggers graph coloring with multiple rounds
 // expected-remark @below {{multi_color_multi_rounds}}
-// expected-remark @below {{size = 9504}}
+// expected-remark @below {{size = 9376}}
 tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
-  // expected-remark @below {{offset = 9472, size = 32}}
+  // expected-remark @below {{offset = 9344, size = 32}}
   %cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 9344, size = 128}}
+  // expected-remark @below {{offset = 9216, size = 128}}
   %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   // expected-remark @below {{offset = 0, size = 8192}}
   %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 8192, size = 1152}}
+  // expected-remark @below {{scratch offset = 8192, size = 1024}}
   %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
   // expected-remark @below {{offset = 8704, size = 128}}
@@ -267,7 +267,7 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
   %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
   %4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
   tt.return