Commit 7e5b52b

do not apply subgroup 2d block encoding to A transpose
1 parent f06f90b commit 7e5b52b

File tree

3 files changed: +98 additions, -4 deletions

- test/TritonIntelGPU/optimize-block-io-encoding.mlir
- third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
- third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp

test/TritonIntelGPU/optimize-block-io-encoding.mlir

Lines changed: 64 additions & 0 deletions
@@ -63,3 +63,67 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar
     tt.return
   }
 }
+
+// -----
+
+// COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+// CHECK-NOT: #mma2
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    %2 = arith.muli %1, %c4_i32 : i32
+    %3 = arith.subi %c4_i32, %2 : i32
+    %4 = arith.minsi %3, %c4_i32 : i32
+    %5 = arith.remsi %0, %4 : i32
+    %6 = arith.addi %2, %5 : i32
+    %7 = arith.remsi %0, %c64_i32 : i32
+    %8 = arith.divsi %7, %4 : i32
+    %9 = arith.muli %6, %c256_i32 : i32
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c1_i64, %c1024_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
+    %11 = arith.muli %8, %c256_i32 : i32
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
+      // CHECK: {{.*}} = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
+      %17 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
+      // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #mma>>
+      // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked2>
+      %18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
+      %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+      %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+      %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+      %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<256x256xf32, #mma1>
+      %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+      // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #blocked1>>
+      %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
+      // CHECK: tt.advance {{.*}} : <tensor<32x256xf16, #mma>>
+      %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
+      scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
+    tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 5 additions & 0 deletions
@@ -2247,6 +2247,11 @@ struct LoadOpConversion
     Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType));
     unsigned numElems = getTotalElemsPerThread(opType);
     unsigned vec = getVectorSize(ptr);
+    LLVM_DEBUG({
+      llvm::dbgs() << "Vectorization for gather load:\n";
+      llvm::dbgs() << "\t" << valueElemTy << " [" << numElems << "]\n";
+      llvm::dbgs() << "\tvector size = " << vec << " for " << ptr << "\n";
+    });
     if (llMask)
      vec = std::min<size_t>(vec, getMaskAlignment(mask));
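For orientation, the format strings in this new LLVM_DEBUG block imply output along the following lines when LLVM debug logging is enabled for this file. The angle-bracket placeholders stand for the printed element type, element count, vector size, and pointer value; this is an illustrative sketch of the format only, not captured output:

    Vectorization for gather load:
        <valueElemTy> [<numElems>]
        vector size = <vec> for <ptr>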

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp

Lines changed: 29 additions & 4 deletions
@@ -222,6 +222,15 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass
       return;

     auto dotOperandType = cast<RankedTensorType>(operand.getType());
+    auto layout = ttg::toLinearEncoding(dotOperandType);
+    auto order = layout.getThreadOrder();
+    auto rank = order.size();
+    if (rank != 2) {
+      loadOp.emitWarning(
+          "Subgroup 2D Block Encoding layouts only support rank 2 operands.");
+      return;
+    }
+
     auto dotOperandEncoding =
         cast<DotOperandEncodingAttr>(dotOperandType.getEncoding());
     // layout width is determined by the DPAS operand encoding width
@@ -232,6 +241,23 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass
     if (!blockIOAttr)
       return;

+    const bool valueRowMajor =
+        getOrderForDotOperand(0, rank, /*kContig=*/true) == order;
+    const bool memoryRowMajor =
+        blockIOAttr == StringAttr::get(&getContext(), "row_major");
+    const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
+    LLVM_DEBUG({
+      DBGS() << "Original layout: " << dotOperandEncoding << "\n";
+      DBGS() << "\tvalueRowMajor = " << valueRowMajor << "\n";
+      DBGS() << "\tmemoryRowMajor = " << memoryRowMajor << "\n";
+      DBGS() << "\tisTransposeRequired = " << isTransposeRequired << "\n";
+    });
+    if (dotOperandEncoding.getOpIdx() == 0 && isTransposeRequired) {
+      LLVM_DEBUG(DBGS() << "Transposed 'A' operand does not yet support "
+                           "Subgroup 2D Block Encoding layout.\n");
+      return;
+    }
+
     // get the MakeTensorPtr Op for the load
     Value ptr = loadOp.getPtr();
     if (!isTensorPointerType(ptr.getType())) {
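The guard added in the hunk above reduces to an XOR of two booleans: whether the dot operand's values are held K-contiguous (row major) in registers, and whether the ttig.block_io attribute says the memory is row major. A transpose is needed exactly when the two disagree. The standalone C++ sketch below shows only that predicate; the helper name needsTranspose and the boolean-only signature are illustrative assumptions, not code from the pass:

    #include <cstdio>

    // Hedged sketch: in the pass, valueRowMajor comes from comparing the
    // operand's thread order against getOrderForDotOperand(...), and
    // memoryRowMajor from the ttig.block_io attribute. Here both are reduced
    // to booleans so only the XOR logic is visible.
    static bool needsTranspose(bool valueRowMajor, bool memoryRowMajor) {
      // A transpose is required whenever the register order and the memory
      // order disagree.
      return valueRowMajor ^ memoryRowMajor;
    }

    int main() {
      const bool flags[2] = {false, true};
      // Enumerate the four combinations; only the mixed cases need a transpose.
      for (bool valueRowMajor : flags)
        for (bool memoryRowMajor : flags)
          std::printf("valueRowMajor=%d memoryRowMajor=%d -> transpose=%d\n",
                      (int)valueRowMajor, (int)memoryRowMajor,
                      (int)needsTranspose(valueRowMajor, memoryRowMajor));
      return 0;
    }

With this commit, when the operand is A (opIdx == 0) and the predicate is true, the pass returns early and leaves the load on its original blocked layout instead of assigning a subgroup 2D block encoding, which is what the new lit test checks.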
@@ -261,16 +287,15 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass

     auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
         cast<DistributedEncodingTrait>(dotOperandEncoding),
-        oldTensorType.getShape(),
-        blockIOAttr == StringAttr::get(&getContext(), "row_major"),
-        elemSizeInBits / 8, &getContext());
+        oldTensorType.getShape(), memoryRowMajor, elemSizeInBits / 8,
+        &getContext());
     SmallVector<unsigned> instrShape{tileParams[0], tileParams[1]};
     const unsigned vBlocks = tileParams[2];

     auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get(
         &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape,
         tileParams[2],
-        getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ 2,
+        getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank,
                               /*kContig*/ true),
         kWidth, dpasLayout.getThreadsPerWarp());

0 commit comments