Commit 5bbe0fd

[Prefetch] Prefetch lowering code cleanup (#5535)
Use the linear layout for the offset evaluation. Also fixes a bug where the base offsetX and offsetY were not swapped for column-major memory. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 20be069 commit 5bbe0fd
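A minimal standalone sketch of the cooperative-prefetch tiling this change computes, specialized to the row-major case in the updated test (tensor<16x32xf16>, 8 warps): it mirrors the arithmetic of the new get2DPrefetchWarpsPerCTA helper and the per-warp offsets the linear layout evaluates to, but it is plain C++ with illustrative names, not the repository code.

// Sketch only: re-derives the per-warp prefetch tiles for tensor<16x32xf16>
// prefetched cooperatively by 8 warps (f16, at most 64 bytes per tile row).
#include <algorithm>
#include <cstdio>

int main() {
  const unsigned shapeM = 16, shapeN = 32; // row-major tensor shape
  const unsigned elemBytes = 2;            // f16
  const unsigned numWarps = 8;

  // Columns per prefetch op are capped at 64 bytes of a row.
  const unsigned tileN = std::min(shapeN, 64u / elemBytes);                 // 32
  const unsigned warpsN = std::min(numWarps, (shapeN + tileN - 1) / tileN); // 1
  const unsigned warpsM = (numWarps + warpsN - 1) / warpsN;                 // 8
  const unsigned tileM = std::min((shapeM + warpsM - 1) / warpsM, 32u);     // 2

  // Each warp issues (shapeM * shapeN) / (tileM * tileN * numWarps) prefetches;
  // here exactly one, matching the single triton_gen.2Dblockprefetch per warp
  // in the lit test (a 2x32 tile emitted as tile_width = 16 with v_blocks = 2).
  const unsigned tilesPerWarp = (shapeM * shapeN) / (tileM * tileN * numWarps);
  std::printf("tile %ux%u, warps {%u, %u}, tiles per warp %u\n", tileM, tileN,
              warpsM, warpsN, tilesPerWarp);

  // With warps laid out as {warpsM, warpsN} = {8, 1} and the N dimension the
  // fastest-varying one, warp w starts at row w * tileM, column 0; these are
  // the values the linear layout produces before the base offsets are added.
  for (unsigned w = 0; w < numWarps; ++w)
    std::printf("warp %u -> offsetX = 0, offsetY = %u\n", w, w * tileM);
  return 0;
}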

Showing 2 changed files with 111 additions and 118 deletions.

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 14 additions & 50 deletions
@@ -21,43 +21,25 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 // CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[ROW_STRIDE]], %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_12]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[BASE]], %[[VAL_13]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
-// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
-// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
-// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[VAL_19:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
-// CHECK: %[[VAL_20:.*]] = llvm.udiv %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
-// CHECK: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: %[[VAL_22:.*]] = llvm.urem %[[VAL_20]], %[[CST_8]] : i32
-// CHECK: %[[VAL_23:.*]] = llvm.udiv %[[VAL_20]], %[[CST_8]] : i32
 // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[BASE_:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i64) : i64
-// CHECK: %[[VAL_21:.*]] = llvm.mul %[[HEIGHT_i64]], %[[CST_2]] : i64
+// CHECK: %[[VAL_21:.*]] = llvm.mul %[[WIDTH_i64]], %[[CST_2]] : i64
 // CHECK: %[[ROW_MAJOR_BASE_WIDTH:.*]] = llvm.trunc %[[VAL_21]] : i64 to i32
-// CHECK: %[[ROW_MAJOR_BASE_HEIGHT:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
+// CHECK: %[[ROW_MAJOR_BASE_HEIGHT:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
 // CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i64) : i64
 // CHECK: %[[VAL_24:.*]] = llvm.mul %[[ROW_STRIDE_i64]], %[[CST_2]] : i64
 // CHECK: %[[ROW_MAJOR_PITCH:.*]] = llvm.trunc %[[VAL_24]] : i64 to i32
-// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
-// CHECK: %[[VAL_26:.*]] = llvm.mul %[[VAL_19]], %[[CST_32]] : i32
-// CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32
-// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
-// CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32
-// CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
-// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_22]], %[[CST_2]] : i32
-// CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32
-// CHECK: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
-// CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32
-// CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
+// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
+// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
+// CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.add {{.*}}, %[[OFFSET_1]] : i32
+// CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.add {{.*}}, %[[OFFSET_0]] : i32
 // CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[ROW_MAJOR_BASE_WIDTH]], %[[ROW_MAJOR_BASE_HEIGHT]], %[[ROW_MAJOR_PITCH]], %[[ROW_MAJOR_OFFSET_X]], %[[ROW_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 2, v_blocks = 2, cache_control = L1C_L3C}
 %rowMajorPtr = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16>>
 ttig.prefetch %rowMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, ttig.block_io = "row_major"} : !tt.ptr<tensor<16x32xf16>>
@@ -70,15 +52,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 // CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[ROW_STRIDE]], %[[VAL_12]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[BASE]], %[[VAL_13]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
-// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
-// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
-// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[VAL_19:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
-// CHECK: %[[VAL_20:.*]] = llvm.udiv %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
-// CHECK: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: %[[VAL_22:.*]] = llvm.urem %[[VAL_20]], %[[CST_8]] : i32
-// CHECK: %[[VAL_23:.*]] = llvm.udiv %[[VAL_20]], %[[CST_8]] : i32
 // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -93,20 +66,11 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 // CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i64) : i64
 // CHECK: %[[VAL_24:.*]] = llvm.mul %[[COL_STRIDE_i64]], %[[CST_2]] : i64
 // CHECK: %[[COL_MAJOR_PITCH:.*]] = llvm.trunc %[[VAL_24]] : i64 to i32
-// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
-// CHECK: %[[VAL_26:.*]] = llvm.mul %[[VAL_19]], %[[CST_32]] : i32
-// CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32
-// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
-// CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32
-// CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
-// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_22]], %[[CST_2]] : i32
-// CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32
-// CHECK: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
-// CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32
-// CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
+// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
+// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
+// CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.add {{.*}}, %[[OFFSET_0]] : i32
+// CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.add {{.*}}, %[[OFFSET_1]] : i32
 // CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[COL_MAJOR_BASE_WIDTH]], %[[COL_MAJOR_BASE_HEIGHT]], %[[COL_MAJOR_PITCH]], %[[COL_MAJOR_OFFSET_X]], %[[COL_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 2, v_blocks = 2, cache_control = L1C_L3C}
 %columnMajorPtr = tt.make_tensor_ptr %arg0, [%arg4, %arg2], [%c1_i64, %arg5], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x16xf16>>
 ttig.prefetch %columnMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, ttig.block_io = "column_major"} : !tt.ptr<tensor<32x16xf16>>

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 97 additions & 68 deletions
@@ -706,8 +706,6 @@ struct PrefetchOpConversion
 Attribute blockIOAttr =
 op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
 if (!blockIOAttr) {
-// TODO: Fallback to gather semantic prefetching. Simply erase the
-// prefetching op which is not supported for now.
 rewriter.eraseOp(op);
 return success();
 }
@@ -727,33 +725,20 @@ struct PrefetchOpConversion
 // Swap the shape to make it row major and then get the tiling
 // size base on row major shape.
 std::swap(tensorShape[0], tensorShape[1]);
-
-// Create the new tensor type with swapped row and col.
-tensorType = RankedTensorType::get(
-tensorShape, tensorType.getElementType(), tensorType.getEncoding());
 }
-
 unsigned numWarps = triton::gpu::lookupNumWarps(op);
 
-SmallVector<unsigned, 2> shapePerWarp =
-get2DPrefetchShapePerWarp(tensorType);
+auto [tileHeightInElem, tileWidthInElem, warpsM, warpsN] =
+get2DPrefetchWarpsPerCTA(tensorShape, eltTy, numWarps);
 
-SmallVector<unsigned, 2> warpsPerCTA =
-getWarpsPerCTA(tensorShape, shapePerWarp, numWarps);
+auto llEncoding = getLinearLayout(
+tensorShape, {tileHeightInElem, tileWidthInElem}, {warpsM, warpsN});
 
-// To adjust the row shape per warp to fit the tensor shape and avoid
-// duplication in prefetching.
-unsigned factor =
-mlir::ceil(shapePerWarp[0] * warpsPerCTA[0], (unsigned)tensorShape[0]);
-shapePerWarp[0] = mlir::ceil(shapePerWarp[0], factor);
-
-SmallVector<int64_t> numReps = {
-mlir::ceil<int64_t>(tensorShape[0], shapePerWarp[0] * warpsPerCTA[0]),
-mlir::ceil<int64_t>(tensorShape[1], shapePerWarp[1] * warpsPerCTA[1])};
+unsigned tileSizeInElem = tileHeightInElem * tileWidthInElem;
+unsigned numTilesPerWarp =
+(tensorShape[0] * tensorShape[1]) / (tileSizeInElem * numWarps);
 
 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-unsigned tileWidthInElem = shapePerWarp[1];
-unsigned tileHeightInElem = shapePerWarp[0];
 unsigned vBlocks = 1;
 switch (elemSizeInBits) {
 case 8:
@@ -774,12 +759,6 @@ struct PrefetchOpConversion
 break;
 }
 
-Value warpId = rewriter.create<arith::IndexCastOp>(
-loc, i32_ty,
-rewriter.create<mlir::gpu::SubgroupIdOp>(loc, /*upperBound=*/nullptr));
-SmallVector<Value> multiDimWarpId =
-mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, {1, 0});
-
 auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
 offsetBaseY] =
 getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
@@ -788,6 +767,7 @@ struct PrefetchOpConversion
 // Swap the width/height and strides to the row major.
 std::swap(baseWidth, baseHeight);
 std::swap(colStride, rowStride);
+std::swap(offsetBaseX, offsetBaseY);
 }
 
 baseWidth = b.mul(baseWidth, b.i64_val(eltTy.getIntOrFloatBitWidth() / 8));
@@ -799,46 +779,43 @@ struct PrefetchOpConversion
 b.mul(rowStride, b.i64_val(eltTy.getIntOrFloatBitWidth() / 8));
 rowStrideInBytes = b.trunc(i32_ty, rowStrideInBytes);
 
-for (int row = 0; row < numReps[0]; ++row) {
-for (int col = 0; col < numReps[1]; ++col) {
-Value offsetX, offsetY;
-offsetX = b.add(
-// the offset of this warp.
-b.mul(multiDimWarpId[1], b.i32_val(shapePerWarp[1])),
-// add the replica offset with a warp stride.
-b.i32_val(col * warpsPerCTA[1] * shapePerWarp[1]));
-// Round the offset into to the tensor shape
-offsetX = b.urem(offsetX, b.i32_val(tensorShape[1]));
-offsetX = b.add(offsetX, offsetBaseX);
-offsetY = b.add(
-// the offset of this warp.
-b.mul(multiDimWarpId[0], b.i32_val(shapePerWarp[0])),
-// add the replica offset with a warp stride.
-b.i32_val(row * warpsPerCTA[0] * shapePerWarp[0]));
-// Round the offset into to the tensor shape
-offsetY = b.urem(offsetY, b.i32_val(tensorShape[0]));
-offsetY = b.add(offsetY, offsetBaseY);
-
-auto newOp = rewriter.create<TritonGEN::Matrix2DBlockPrefetchOp>(
-loc,
-/*ptr*/ base,
-/*base_width*/ baseWidth,
-/*base_height*/ baseHeight,
-/*base_pitch*/ rowStrideInBytes,
-/*x*/ offsetX,
-/*y*/ offsetY,
-/*elem_size_in_bits*/ elemSizeInBits,
-/*tile_width*/ tileWidthInElem,
-/*tile_height*/ tileHeightInElem,
-/*v_blocks*/ vBlocks,
-/*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
-if (failed(newOp.verify())) {
-// delete the op so that the verifier will not abort the pass
-// pipeline later, as we can fail this path and try a different
-// approach.
-rewriter.eraseOp(newOp);
-return failure();
-}
+MLIRContext *ctx = getContext();
+StringAttr kOffset = S("offset");
+StringAttr kWarp = S("warp");
+StringAttr kBlock = S("block");
+
+Value warpId = rewriter.create<arith::IndexCastOp>(
+loc, i32_ty,
+rewriter.create<mlir::gpu::SubgroupIdOp>(loc,
+/*upperBound=*/nullptr));
+
+for (unsigned tile = 0; tile < numTilesPerWarp; ++tile) {
+unsigned off = tile * tileSizeInElem;
+auto offsets = applyLinearLayout(
+loc, rewriter, llEncoding,
+{{kOffset, b.i32_val(off)}, {kWarp, warpId}, {kBlock, b.i32_val(0)}});
+Value offsetX = b.add(offsets[1].second, offsetBaseX);
+Value offsetY = b.add(offsets[0].second, offsetBaseY);
+
+auto newOp = rewriter.create<TritonGEN::Matrix2DBlockPrefetchOp>(
+loc,
+/*ptr*/ base,
+/*base_width*/ baseWidth,
+/*base_height*/ baseHeight,
+/*base_pitch*/ rowStrideInBytes,
+/*x*/ offsetX,
+/*y*/ offsetY,
+/*elem_size_in_bits*/ elemSizeInBits,
+/*tile_width*/ tileWidthInElem,
+/*tile_height*/ tileHeightInElem,
+/*v_blocks*/ vBlocks,
+/*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
+if (failed(newOp.verify())) {
+// delete the op so that the verifier will not abort the pass
+// pipeline later, as we can fail this path and try a different
+// approach.
+rewriter.eraseOp(newOp);
+return failure();
 }
 }
 
@@ -1050,6 +1027,58 @@ struct PrefetchOpConversion
 rewriter.eraseOp(op);
 return success();
 }
+
+private:
+// tensor shape has to be in row major.
+// Returns:
+// Prefetch Op Shape in {M, N}
+// Warps per CTA in {M, N}
+std::tuple<unsigned, unsigned, unsigned, unsigned>
+get2DPrefetchWarpsPerCTA(const ArrayRef<int64_t> tensorShape, Type eltTy,
+unsigned numWarps) const {
+unsigned rank = tensorShape.size();
+assert(rank >= 2 && "Only rank >= 2 tensor is supported for now");
+unsigned dimM = rank - 2, dimN = rank - 1;
+unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
+unsigned elemSizeInBytes = elemSizeInBits / 8;
+constexpr unsigned maxBytesPerRow = 64;
+unsigned numColsPerPrefOps =
+std::min<unsigned>(tensorShape[dimN], maxBytesPerRow / elemSizeInBytes);
+
+unsigned repNumN =
+mlir::ceil((unsigned)tensorShape[dimN], numColsPerPrefOps);
+unsigned warpsNumN = std::min(numWarps, repNumN);
+unsigned warpsNumM = mlir::ceil(numWarps, warpsNumN);
+
+// Get the number of rows per warp to fit the shape to the tensor shape to
+// avoid duplication in prefetching.
+unsigned rowNumPerWarp = mlir::ceil<unsigned>(tensorShape[dimM], warpsNumM);
+unsigned numRowsPerPrefOps = std::min<unsigned>(rowNumPerWarp, 32);
+SmallVector<unsigned, 2> tilePerPrefOps{numRowsPerPrefOps,
+numColsPerPrefOps};
+
+return {numRowsPerPrefOps, numColsPerPrefOps, warpsNumM, warpsNumN};
+}
+
+// Get the linear layout for the cooperative prefetching.
+LinearLayout getLinearLayout(const ArrayRef<int64_t> tensorShape,
+const ArrayRef<unsigned> tileShape,
+const ArrayRef<unsigned> warpsPerCTA) const {
+MLIRContext *ctx = getContext();
+unsigned rank = warpsPerCTA.size();
+assert(rank >= 2 && "Only rank >= 2 tensor is supported for now");
+SmallVector<unsigned> order(rank);
+for (size_t i = 0; i < warpsPerCTA.size(); ++i) {
+// The fastest change dim is the first.
+order[i] = rank - i - 1;
+}
+LinearLayout ctaLayout = identityStandardND(S("offset"), tileShape, order) *
+identityStandardND(S("warp"), warpsPerCTA, order);
+
+return combineCtaCgaWithShape(std::move(ctaLayout),
+CTALayoutAttr::getDefault(ctx, rank),
+tensorShape);
+}
 };
 
 struct LoadOpToBlockIOConversion
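For the column-major half of the fix, a tiny self-contained sketch (hypothetical struct and function names, not the repository's API) of what the swap now covers: every per-dimension field of the block pointer is transposed before the row-major 2D block prefetch is built, including the base offsets that were previously left un-swapped.

// Illustrative only: models the swap performed in PrefetchOpConversion for a
// column-major block pointer; names and values are placeholders.
#include <cstdio>
#include <utility>

struct BlockPtr2D {
  long height, width;           // logical sizes of the two dimensions
  long rowStride, colStride;    // strides of the two dimensions
  int offsetBaseY, offsetBaseX; // base offsets into the parent tensor
};

void swapToRowMajor(BlockPtr2D &p) {
  std::swap(p.width, p.height);
  std::swap(p.colStride, p.rowStride);
  std::swap(p.offsetBaseX, p.offsetBaseY); // the swap this commit adds
}

int main() {
  BlockPtr2D p{/*height=*/32, /*width=*/16, /*rowStride=*/1, /*colStride=*/64,
               /*offsetBaseY=*/4, /*offsetBaseX=*/8};
  swapToRowMajor(p);
  std::printf("height=%ld width=%ld rowStride=%ld colStride=%ld y=%d x=%d\n",
              p.height, p.width, p.rowStride, p.colStride, p.offsetBaseY,
              p.offsetBaseX);
  return 0;
}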
