Skip to content

Commit e1e48a8

Browse files
authored
Lower the regular pointers to block io for DPAS layout (#4115)
Implement lowering of regular pointers to block IO for DPAS layout, enhancing the support for 2D block IO in tensor loads. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 0b43117 commit e1e48a8

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,57 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
135135
tt.return
136136
}
137137
}
138+
139+
// -----
140+
141+
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
142+
#mma_1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}>
143+
#mma_2 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [4, 2]}>
144+
module attributes {triton_intel_gpu.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
145+
// CHECK-LABEL: @regular_pointer_block_io
146+
tt.func public @regular_pointer_block_io(%arg0: tensor<256x64x!tt.ptr<f16>, #mma>,
147+
%arg1: tensor<256x64x!tt.ptr<f16>, #mma_1>,
148+
%arg2: tensor<128x64x!tt.ptr<f16>, #mma_2>,
149+
%arg3: tensor<256x64x!tt.ptr<f16>, #mma_2>) {
150+
151+
// CHECK-COUNT-4: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32f16
152+
%0 = tt.load %arg0 {triton_intel_gpu.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
153+
154+
// CHECK-COUNT-16: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPDh
155+
%1 = tt.load %arg1 {triton_intel_gpu.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_1>
156+
157+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPDh
158+
%2 = tt.load %arg3 {triton_intel_gpu.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_2>
159+
160+
// COM: The data is duplicated in the warps because the warp shape 32*8=256 is larger than the tensor shape 128
161+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPDh
162+
%3 = tt.load %arg2 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma_2>
163+
tt.return
164+
}
165+
}
166+
167+
// -----
168+
169+
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
170+
#mma_1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 8, 1], repCluster = [1, 2, 2]}>
171+
#mma_32 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 32, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
172+
module attributes {triton_intel_gpu.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
173+
// CHECK-LABEL: @regular_pointer_gather_io
174+
tt.func public @regular_pointer_gather_io(%arg0: tensor<128x64x!tt.ptr<f16>, #mma>,
175+
%arg1: tensor<128x64x!tt.ptr<f16>, #mma_32>,
176+
%arg2: tensor<2x128x64x!tt.ptr<f16>, #mma_1>) {
177+
// COM: The pitch is not available in the current implementation.
178+
// COM: It is available neither from axis info nor from ptrs[{0, 0}] and ptrs[{1, 0}] in the same work item.
179+
// CHECK-COUNT-32: llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<1> -> i16
180+
%0 = tt.load %arg1 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma_32>
181+
182+
// COM: Column-major block io is not supported.
183+
// CHECK-COUNT-32: llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<1> -> i16
184+
%1 = tt.load %arg0 {triton_intel_gpu.block_io = "column_major"} : tensor<128x64x!tt.ptr<f16>, #mma>
185+
186+
// COM: Ranks greater than 2 are not supported.
187+
// CHECK-COUNT-128: llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<1> -> i16
188+
%2 = tt.load %arg2 {triton_intel_gpu.block_io = "column_major"} : tensor<2x128x64x!tt.ptr<f16>, #mma_1>
189+
tt.return
190+
}
191+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,10 @@ struct LoadOpToBlockIOConversion
830830
auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
831831
SmallVector<unsigned> threadOrder(llAttr.getThreadOrder());
832832
size_t rank = threadOrder.size();
833-
assert(rank == 2 && "only support rank of 2 for now");
833+
if (rank != 2) {
834+
// only support rank of 2 for now.
835+
return failure();
836+
}
834837
const bool valueRowMajor =
835838
(threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
836839
assert((valueRowMajor ||
@@ -936,6 +939,12 @@ struct LoadOpToBlockIOConversion
936939
}
937940
} break;
938941
case DpasEncodingAttr::OpIdx::OperandC:
942+
warpShape = std::move(dpasLayout.getShapeC());
943+
dpasInstShape = std::move(dpasLayout.getDPASInstShapeC());
944+
dimOuter = rank - 2;
945+
dimInner = rank - 1;
946+
usePackedType = false;
947+
break;
939948
default:
940949
llvm_unreachable("unknown DPAS operands index type.");
941950
break;
@@ -1056,6 +1065,9 @@ struct LoadOpToBlockIOConversion
10561065
numOperandsPer2DLoadN = repCluster[dimOuter];
10571066
break;
10581067
case DpasEncodingAttr::OpIdx::OperandC:
1068+
numOperandsPer2DLoadM = repCluster[dimOuter];
1069+
numOperandsPer2DLoadN = repCluster[dimInner];
1070+
break;
10591071
default:
10601072
llvm_unreachable("unknown DPAS operands index type.");
10611073
break;
@@ -1137,6 +1149,10 @@ struct LoadOpToBlockIOConversion
11371149
repInnerStride = warpShape[dimInner] * numOperandsInnerDimPerLoad;
11381150
break;
11391151
case DpasEncodingAttr::OpIdx::OperandC:
1152+
numRepOuter = numReps[dimOuter];
1153+
numRepInner = numReps[dimInner];
1154+
repInnerStride = warpShape[dimInner] * innerDimWarpNum;
1155+
break;
11401156
default:
11411157
llvm_unreachable("unknown DPAS operands index type.");
11421158
break;
@@ -1320,6 +1336,7 @@ struct LoadOpToBlockIOConversion
13201336

13211337
// Save the decomposed vals to the map.
13221338
switch (opIdx) {
1339+
case DpasEncodingAttr::OpIdx::OperandC:
13231340
case DpasEncodingAttr::OpIdx::OperandA: {
13241341
unsigned o = outer * numLoadPerOutRepCluster *
13251342
numOperandsOuterDimPerLoad +
@@ -1343,7 +1360,6 @@ struct LoadOpToBlockIOConversion
13431360
loadVals[{o, i}] =
13441361
b.bitcast(loadVal, unpackedDPASOperandType);
13451362
} break;
1346-
case DpasEncodingAttr::OpIdx::OperandC:
13471363
default: {
13481364
llvm_unreachable("unknown DPAS operands index type.");
13491365
} break;

0 commit comments

Comments
 (0)