Skip to content

Commit 0169c00

Browse files
[LoadStoreOpToLLVM] Improve rewriteTensorPointerStore (#4667)
This PR improves the `rewriteTensorPointerStore`: - enhanced support for 2D block store operations with proper size constraints - removed redundant early failure condition - refactored variable names and reorganizes code structure for better readability --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 41f27f4 commit 0169c00

File tree

2 files changed

+45
-34
lines changed

2 files changed

+45
-34
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
9898
// CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %[[BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
9999
%13 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dpas>>
100100

101+
// CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
102+
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
103+
// CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32
104+
// CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
105+
// CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32
106+
101107
// COM: The decomposed values of the tensor with DPAS layout.
102108
// CHECK: %[[VAL_97:.*]] = llvm.extractvalue %[[VAL_71]][0]
103109
// CHECK: %[[VAL_98:.*]] = llvm.extractvalue %[[VAL_71]][1]
@@ -164,11 +170,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
164170
// CHECK: %[[VAL_159:.*]] = llvm.extractvalue %[[VAL_71]][62]
165171
// CHECK: %[[VAL_160:.*]] = llvm.extractvalue %[[VAL_71]][63]
166172

167-
// CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
168-
// CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
169-
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
170-
// CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32
171-
// CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32
172173
// CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(1 : i32) : i32
173174
// CHECK: %[[outerDimWarpId:.*]] = llvm.urem %[[SUB_GROUP_ID_M]], %[[VAL_166]] : i32
174175
// CHECK: %[[VAL_168:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -181,7 +182,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
181182
// CHECK: %[[warpId1Offset:.*]] = llvm.add %[[dimWarpId1]], %[[OFFSET_1]] : i32
182183
// CHECK: %[[VAL_176:.*]] = llvm.mlir.constant(0 : i32) : i32
183184

184-
185185
// COM: The shape of DPAS layout replica is [4, 2]
186186
// COM: The replica order are [0, 1]
187187
// COM: [2, 3]

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,11 +2456,6 @@ struct StoreOpToBlockIOConversion
24562456
auto b = TritonLLVMOpBuilder(loc, rewriter);
24572457
Type resultType = op.getValue().getType();
24582458
auto tensorType = cast<RankedTensorType>(resultType);
2459-
2460-
// Only lower StoreOp with dpas layout encoding.
2461-
if (!hasDpasEncoding(tensorType))
2462-
return failure();
2463-
24642459
auto dpasLayout = cast<DpasEncodingAttr>(tensorType.getEncoding());
24652460
LLVMTypeConverter *typeConverter = getTypeConverter();
24662461
MLIRContext *ctx = rewriter.getContext();
@@ -2471,14 +2466,21 @@ struct StoreOpToBlockIOConversion
24712466
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
24722467
size_t rank = tensorShape.size();
24732468
unsigned numElems = getTotalElemsPerThread(tensorType);
2469+
24742470
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
2471+
// 2D block store supports 8 rows at most.
2472+
unsigned tileHeight = std::min(8u, elemsPerInstr[0]);
2473+
// 2D block store supports 64 bytes per row at most.
2474+
unsigned tileWidth = elemsPerInstr[1];
2475+
unsigned totalBytesPerRowPerMatrix = tileWidth * elemSizeInBits / 8;
2476+
if (totalBytesPerRowPerMatrix > 64)
2477+
return failure();
2478+
24752479
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
24762480
SmallVector<int64_t> numReps =
24772481
dpasLayout.getDPASRepetitions(tensorShape, 2);
24782482
SmallVector<unsigned> dpasWarpsOrder =
24792483
getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true);
2480-
unsigned threadsPerWarp =
2481-
product<unsigned>(getThreadsPerWarp(dpasLayout, tensorShape));
24822484

24832485
Value warpId = rewriter.create<arith::IndexCastOp>(
24842486
loc, i32_ty,
@@ -2487,25 +2489,34 @@ struct StoreOpToBlockIOConversion
24872489
SmallVector<Value> multiDimWarpId = mlir::LLVM::delinearize(
24882490
rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
24892491

2490-
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
2491-
Type store2DGenXType =
2492-
LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits),
2493-
elemsPerLane); // make it opaque type.
2494-
24952492
Value blockPtr = adaptor.getPtr();
24962493
auto [base, width, height, rowStride, colStride, offsetBaseX, offsetBaseY] =
24972494
getValuesFromBlockPointerStruct(blockPtr, rewriter);
24982495

2499-
auto vals = unpackLLElements(loc, adaptor.getValue(), rewriter);
2500-
assert(vals.size() == numElems);
2501-
25022496
width = b.trunc(i32_ty, width);
2503-
height = b.trunc(i32_ty, height);
25042497
rowStride = b.trunc(i32_ty, rowStride);
25052498
// encoded as bytes.
25062499
Value baseWidth = b.mul(width, elemSizeInBytes);
2500+
Value baseHeight = b.trunc(i32_ty, height);
25072501
// encoded as bytes.
2508-
Value basePitch = b.mul(rowStride, elemSizeInBytes);
2502+
Value pitch = b.mul(rowStride, elemSizeInBytes);
2503+
// 2D block store only supports vBlocks = 1.
2504+
unsigned vBlocks = 1;
2505+
2506+
// Get the LLVM values for store values
2507+
SmallVector<Value> valElems =
2508+
unpackLLElements(loc, adaptor.getValue(), rewriter);
2509+
assert(valElems.size() == numElems &&
2510+
"the number of store values does not match the number of elements");
2511+
2512+
unsigned threadsPerWarp =
2513+
TritonGPUDialect::getThreadsPerWarp(op->getParentOfType<ModuleOp>());
2514+
2515+
int64_t elemsPerLane = tileHeight * tileWidth / threadsPerWarp;
2516+
Type opaqueType = IntegerType::get(ctx, elemSizeInBits);
2517+
Type store2DGenXType =
2518+
LLVM::getVectorType(opaqueType,
2519+
elemsPerLane); // make it opaque type.
25092520

25102521
// A warp stride for the replicates.
25112522
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
@@ -2538,34 +2549,34 @@ struct StoreOpToBlockIOConversion
25382549
for (int m = 0; m < numRepOuter; ++m) {
25392550
for (int n = 0; n < numRepInner; ++n) {
25402551
for (int repM = 0; repM < repCluster[0]; ++repM) {
2541-
Value offsetY =
2542-
b.add(warpId0Offset,
2543-
b.i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
2552+
Value offsetY = b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
2553+
repM * tileHeight));
25442554
for (int repN = 0; repN < repCluster[1]; ++repN) {
25452555
Value offsetX =
2546-
b.add(warpId1Offset, b.i32_val(n * replicaStride[1] +
2547-
repN * elemsPerInstr[1]));
2556+
b.add(warpId1Offset,
2557+
b.i32_val(n * replicaStride[1] + repN * tileWidth));
2558+
25482559
Value storeVal = rewriter.create<LLVM::UndefOp>(
25492560
loc, LLVM::getVectorType(typeConverter->convertType(eltTy),
25502561
elemsPerLane));
25512562
for (size_t i = 0; i < elemsPerLane; ++i) {
25522563
storeVal =
2553-
b.insert_element(storeVal, vals[valOffset], b.i32_val(i));
2564+
b.insert_element(storeVal, valElems[valOffset], b.i32_val(i));
25542565
++valOffset;
25552566
}
25562567

25572568
auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
25582569
loc,
25592570
/*ptr*/ base,
25602571
/*base_width*/ baseWidth,
2561-
/*base_height*/ height,
2562-
/*base_pitch*/ basePitch,
2572+
/*base_height*/ baseHeight,
2573+
/*base_pitch*/ pitch,
25632574
/*x*/ offsetX,
25642575
/*y*/ offsetY,
25652576
/*elem_size_in_bits*/ elemSizeInBits,
2566-
/*tile_width*/ elemsPerInstr[1],
2567-
/*tile_height*/ elemsPerInstr[0],
2568-
/*v_blocks*/ 1,
2577+
/*tile_width*/ tileWidth,
2578+
/*tile_height*/ tileHeight,
2579+
/*v_blocks*/ vBlocks,
25692580
/*stored_val*/ b.bitcast(storeVal, store2DGenXType));
25702581

25712582
if (failed(newOp.verify())) {

0 commit comments

Comments
 (0)