Skip to content

Commit 21c232a

Browse files
[LoadStoreOpToLLVM] Fix load with base height == 1 (#4602)
When `strides[0]` is 0, we only want to load the first row, so we set the base height to be 1. (<= done in another PR) When base height is less than tile height and base height is 1, only the first row contain valid data. To ensure the entire tile is filled with valid data, we must replicate the first row throughout the tile. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent da87f29 commit 21c232a

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,48 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
246246
tt.return
247247
}
248248
}
249+
250+
// -----
251+
252+
// COM: Check codegen when base height is 1 and tile height is > 1.
253+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
254+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
255+
// CHECK-LABEL: @baseheight1
256+
tt.func public @baseheight1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
257+
%18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>>
258+
%19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<1x32xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
259+
%20 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
260+
%21 = tt.addptr %20, %19 : tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<1x32xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
261+
%22 = tt.broadcast %21 : tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
262+
%50 = tt.load %22 {ttig.block_io = "row_major"} : tensor<64x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
263+
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
264+
// CHECK: [[LOAD:%.*]] = triton_gen.2Dblockload %{{.*}}, %{{.*}}, [[C1]], %{{.*}}, %{{.*}}, %{{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 2
265+
266+
// CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<2xi32>
267+
268+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
269+
// CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C0]] : i32] : vector<16xi32>
270+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
271+
// CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
272+
// CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
273+
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
274+
// CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
275+
// CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
276+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
277+
// CHECK: [[VEC1:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC]][[[C0]] : i32] : vector<2xi32>
278+
279+
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
280+
// CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C8]] : i32] : vector<16xi32>
281+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
282+
// CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
283+
// CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
284+
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
285+
// CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
286+
// CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
287+
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
288+
// CHECK: [[VEC2:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC1]][[[C1]] : i32] : vector<2xi32>
289+
290+
// CHECK: llvm.shufflevector [[VEC2]], [[VEC2]] [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
291+
tt.return
292+
}
293+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,8 @@ struct LoadOpToBlockIOConversion
12831283

12841284
// If the stride is 0, we want to load only the first row.
12851285
int stride = getStride(ptr, 0);
1286-
Value baseHeight = b.i32_val(stride == 0 ? 1 : tileHeight);
1286+
unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
1287+
Value baseHeight = b.i32_val(baseHeightInt);
12871288

12881289
StringAttr kRegister = str_attr("register");
12891290
StringAttr kLane = str_attr("lane");
@@ -1380,6 +1381,48 @@ struct LoadOpToBlockIOConversion
13801381
(usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB &&
13811382
!isTransposeRequired && originalElemBits != 32));
13821383

1384+
// When strides[0] is 0, we only want to load the first row, so we
1385+
// set the base height to be 1. If tile height is bigger than 1,
1386+
// then only the first row contains valid data. To ensure the entire
1387+
// tile is filled with valid data, we must replicate the first row
1388+
// throughout the tile.
1389+
if (baseHeightInt < tileHeight && baseHeightInt == 1) {
1390+
unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
1391+
SmallVector<int32_t> shuffleIndices(numValuesPerLoad);
1392+
1393+
// Create a vector to store the data of the first index of each
1394+
// matrix.
1395+
VectorType vecTy = vec_ty(loadResultElemType, vBlocks);
1396+
Value firstIndexVec = b.undef(vecTy);
1397+
1398+
for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
1399+
++valueIndex) {
1400+
unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
1401+
// Handle case where an index spans two rows.
1402+
if (valueIndex % numIndicesPerMatrix == 0) {
1403+
Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
1404+
Value newVal = oldVal;
1405+
if (tileWidth < threadsPerWarp) {
1406+
assert(tileWidth * 2 == threadsPerWarp &&
1407+
"Expecting threadsPerWarp to be 2x tileWidth");
1408+
Value threadId = getThreadId(rewriter, loc);
1409+
newVal = targetInfo.shuffleIdx(
1410+
rewriter, loc, oldVal,
1411+
b.urem(threadId, b.i32_val(tileWidth)));
1412+
}
1413+
firstIndexVec =
1414+
b.insert_element(firstIndexVec.getType(), firstIndexVec,
1415+
newVal, b.i32_val(firstIndexVecIdx));
1416+
}
1417+
1418+
shuffleIndices[valueIndex] = firstIndexVecIdx;
1419+
}
1420+
DenseI32ArrayAttr attr =
1421+
rewriter.getDenseI32ArrayAttr(shuffleIndices);
1422+
ret = rewriter.create<LLVM::ShuffleVectorOp>(
1423+
loc, load2DGenXType, firstIndexVec, firstIndexVec, attr);
1424+
}
1425+
13831426
if (others.size()) {
13841427
assert(masks.size() == others.size() &&
13851428
"The mask value has to be provided when "

0 commit comments

Comments
 (0)