
Commit c3579f0

[MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (llvm#169099)
Some use cases and shapes of 2D block ops with sub-byte element types can be emulated with 2D block operations on non-sub-byte types. Add the sub-byte type i4 as a valid XeGPU type, and lower certain 2D block operations by emulating them with larger element types.
1 parent f88d060 commit c3579f0
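The lowering boils down to two shape-gated remappings of the i4 tile. Below is a minimal C++ sketch of that decision, not the patch code itself: the names `Emulation` and `emulateSubByte` are illustrative only, and `executionSize = 16` / `systolicDepth = 8` are the subgroup/DPAS parameters the new tests imply rather than values taken from this commit.

// Sketch of how a 2D block op on an i4 tile is remapped to a wider element
// type. Assumptions: executionSize = 16, systolicDepth = 8.
#include <cstdint>
#include <optional>

struct Emulation {
  unsigned elemBits;     // element width used by the emulated block op
  int64_t scaledTileW;   // tile width after emulation
  uint64_t wScaleFactor; // width-direction quantities are divided by this
};

std::optional<Emulation> emulateSubByte(unsigned elemBits, int64_t tileH,
                                        int64_t tileW, bool packed) {
  const int64_t executionSize = 16, systolicDepth = 8; // assumed HW params
  if (elemBits != 4)
    return std::nullopt;                      // only i4 is supported
  const int64_t subByteFactor = 8 / elemBits; // 2 for i4
  // Matrix B with pack request: emulate with packed 8-bit loads.
  if (packed && tileH == systolicDepth * 4 &&
      tileW == executionSize * subByteFactor)
    return Emulation{8, executionSize, (uint64_t)subByteFactor};
  // Matrix A operand: emulate with 16-bit loads/stores/prefetches.
  if (tileW == executionSize * subByteFactor * 2)
    return Emulation{16, executionSize, (uint64_t)(subByteFactor * 2)};
  return std::nullopt; // any other tile shape is rejected
}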

File tree: 4 files changed (+184, -14 lines)


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 3 additions & 2 deletions
@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/IR/BuiltinTypes.td"
 
-def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
+                               SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 80 additions & 12 deletions
@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4-bit elements are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub byte types of 4bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto tileH = tdescTy.getDimSize(0);
+      // Handle the special case of a packed load.
+      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (op.getPacked().value_or(false)) {
+          // Packed load is emulated with packed loads of 8-bit elements.
+          if (tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor) {
+            // Use case: loading as matrix B with a pack request.
+            // The source is assumed to be pre-packed into 8-bit elements.
+            // Emulate with 8-bit loads with pack request.
+            // scaled_tileW = executionSize
+            elemType = rewriter.getIntegerType(8);
+            tileW = executionSize;
+            wScaleFactor = subByteFactor;
+          }
+        }
+      }
+      // If not handled by the packed load case above, handle other cases.
+      if (wScaleFactor == 1) {
+        auto sub16BitFactor = subByteFactor * 2;
+        if (tileW == executionSize * sub16BitFactor) {
+          // Use case: loading as the matrix A operand.
+          // Emulate with 16-bit loads/stores.
+          // scaled_tileW = executionSize
+          elemType = rewriter.getIntegerType(16);
+          tileW = executionSize;
+          wScaleFactor = sub16BitFactor;
+        } else {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types.");
+        }
+      }
+      // Recompute the element bit size for emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    // FIXME: Width or pitch is not the same as baseShapeW; it should be the
+    // stride of the second-to-last dimension in row-major layout.
     // Compute width in bytes.
-    Value baseWidthByte =
+    Value baseShapeWInBytes =
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
     // Compute pitch in bytes.
-    Value basePitchByte =
+    Value basePitchBytes =
        arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
-    // Get tile width from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (wScaleFactor > 1) {
+      // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
+      // Note: tileW is already scaled above.
+      Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+      baseShapeWInBytes = arith::ShRSIOp::create(
+          rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+      basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+                                              wScaleFactorValLog2);
+      offsetW =
+          arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+    }
     // Get tile height from the tensor descriptor type.
     auto tileH = tdescTy.getDimSize(0);
     // Get vblocks from the tensor descriptor type.
@@ -330,17 +398,17 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-          basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+          rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+          basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
       auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
         xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
            vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                : rewriter.getIntegerType(elemBitSize));
 
         Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
-            vblocks, transpose, vnni,
+            rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+            baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+            tileH, vblocks, transpose, vnni,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         resultFlatVec = vector::BitCastOp::create(
            rewriter, loc,
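To make the width scaling concrete, here is a small worked example. It is only a sketch of the arithmetic: the numbers mirror the 8x64xi4 tile on a 16x128xi4 memref in the first test below, where the tile is emulated as i16 with wScaleFactor = 4 (the constant names are illustrative, not from the patch).

// Worked example of the width scaling for an 8x64xi4 tile (matrix A case).
#include <cstdint>

constexpr int64_t baseShapeW = 128;    // memref row length in i4 elements
constexpr int64_t emulElemBytes = 2;   // i4 tile emulated with i16
constexpr int64_t wScaleLog2 = 2;      // wScaleFactor = subByteFactor * 2 = 4
constexpr int64_t widthBytes = (baseShapeW * emulElemBytes) >> wScaleLog2;
constexpr int64_t offsetW = 64 >> wScaleLog2; // i4 column offset 64 -> i16 16

static_assert(widthBytes == 64, "128 i4 elements occupy 64 bytes per row");
static_assert(offsetW == 16, "matches offsetW in the xevm.blockload2d CHECKs");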
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store_matrix_a
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
+  gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+    // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+    // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+    // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+    // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+    %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
+
+    // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+    // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+    // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+    // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+    // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+    %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
+
+    // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+
+    // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+    // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+    // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+    // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+    // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+    xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
+  gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
+    %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
+
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
+
+    // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+    // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
+    // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
+
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
+    gpu.return
+  }
+}
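The packed case in @load_matrix_b_request_pack above follows the same arithmetic with 8-bit emulation. A short worked sketch of those numbers (systolicDepth = 8 and executionSize = 16 are the values the tile-shape checks imply; the constant names are illustrative):

// Worked example of the packed (matrix B) case: 32x32xi4 on a 64x128xi4 memref.
#include <cstdint>

constexpr int64_t tileH = 32, tileW = 32;          // original i4 tile
constexpr int64_t subByteFactor = 8 / 4;           // = 2
constexpr int64_t emulTileW = tileW / subByteFactor;   // 16 i8 columns
constexpr int64_t widthBytes = 128 >> 1;               // 128 i4 -> 64 bytes
constexpr int64_t offsetW = 32 >> 1;                   // i4 offset 32 -> i8 16

static_assert(tileH == 8 * 4 && tileW == 16 * subByteFactor,
              "shape qualifies for the packed 8-bit emulation");
static_assert(emulTileW == 16 && widthBytes == 64 && offsetW == 16,
              "matches the xevm.blockload2d operands in the CHECK lines");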
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
+  // CHECK-LABEL: gpu.func @prefetch_matrix_a
+  gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+    // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+    // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+    xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<8x64xi4>
+
+    gpu.return
+  }
+}
