Skip to content

Commit a285ae2

Browse files
[TritonGEN] Lower to GenISA for 2d_block_prefetch_16b_2r8x1c (#4649)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 26af8c1 commit a285ae2

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

test/TritonGEN/tritongen-2Dblockprefetch-to-llvm.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,34 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %b
5959

6060
// -----
6161

62+
// Test case: a 2D block prefetch with elem_size_in_bits=16, tile_width=2,
// tile_height=8 and v_blocks=1 is expected to lower to a call to
// @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid.
// The directives ahead of the call match the base-address adjustment: the
// pointer is masked with -64, its low bits (mask 63) are added to the base
// width, and those low bits divided by 2 are added to the x offset.
// NOTE(review): the divisor 2 is presumably the element size in bytes for
// 16-bit elements - confirm against the lowering.
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
63+
// CHECK: [[ONE0:%.*]] = llvm.mlir.constant(1 : i32) : i32
64+
// CHECK: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
65+
// CHECK: [[VAL_63:%.*]] = llvm.mlir.constant(-64 : i64) : i64
66+
// CHECK: [[VAL_64:%.*]] = llvm.and [[PTR]], [[VAL_63]] : i64
67+
// CHECK: [[VAL_65:%.*]] = llvm.inttoptr [[VAL_64]] : i64 to !llvm.ptr<1>
68+
// CHECK: [[CL:%.*]] = llvm.mlir.constant(63 : i64) : i64
69+
// CHECK: [[AND:%.*]] = llvm.and [[PTR]], [[CL]] : i64
70+
// CHECK: [[TRUNC:%.*]] = llvm.trunc [[AND]] : i64 to i32
71+
// CHECK: [[ADD:%.*]] = llvm.add %arg1, [[TRUNC]] : i32
72+
// CHECK: [[TWO:%.*]] = llvm.mlir.constant(2 : i32) : i32
73+
// CHECK: [[SHR:%.*]] = llvm.udiv [[TRUNC]], [[TWO]] : i32
74+
// CHECK: [[X:%.*]] = llvm.add %arg4, [[SHR]] : i32
75+
// CHECK: [[BASE_ALIGNED:%.*]] = llvm.ptrtoint [[VAL_65]] : !llvm.ptr<1> to i64
76+
// CHECK: [[BASEWIDTH:%.*]] = llvm.sub [[ADD]], [[ONE0]] : i32
77+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
78+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(2 : i32) : i32
79+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
80+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
81+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
82+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
83+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid([[BASE_ALIGNED]], [[BASEWIDTH]], {{.*}}, [[X]], {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
84+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=2, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
85+
llvm.return
86+
}
87+
88+
// -----
89+
6290
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
6391
// CHECK: llvm.mlir.constant(2 : i32) : i32
6492
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,16 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
144144
return true;
145145
}
146146

147+
/// Reports whether the given 2D block prefetch can be lowered through the SPV
/// builtin interface.
///
/// Returns false only for the shape with 16-bit elements, tile height 8, tile
/// width 2 and a single v-block; every other shape returns true. The prefetch
/// lowering pattern treats a false result as the signal to emit the GenISA
/// prefetch instead.
static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockPrefetchOp op) {
148+
// FIXME: The following signatures are not valid in SPV interface.
149+
// intel_sub_group_2d_block_prefetch_16b_2r8x1c
// NOTE(review): if the builtin name follows a <rows>r<cols>x<count>c pattern,
// "2r8x1c" reads as 2 rows by 8 columns, whereas the condition below matches
// tile height 8 and tile width 2 - confirm which of the two is intended.
150+
if (op.getElemSizeInBits() == 16 && op.getTileHeight() == 8 &&
151+
op.getTileWidth() == 2 && op.getVBlocks() == 1)
152+
return false;
153+
154+
// Any shape not listed above is assumed to have a usable SPV builtin.
return true;
155+
}
156+
147157
// HW requires the base address to be 64-byte aligned. Compensate for a
148158
// misaligned base address by adjusting the base width and x-coordinate offset.
149159
template <
@@ -651,6 +661,12 @@ struct TritonMatrix2DBlockPrefetchLowering
651661
LogicalResult
652662
matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
653663
ConversionPatternRewriter &rewriter) const override {
664+
if (!isSPVBuiltinAvailable(op)) {
665+
// Fall back to the GenISA interface.
666+
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));
667+
return success();
668+
}
669+
654670
MLIRContext *ctx = rewriter.getContext();
655671
Location loc = op->getLoc();
656672
auto b = TritonLLVMOpBuilder(loc, rewriter);

0 commit comments

Comments
 (0)