Skip to content

Commit 187b4af

Browse files
authored
Handle sub byte type in XeGPU lowering to 2D LSC. (#1061)
1 parent 9728975 commit 187b4af

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

lib/Conversion/XeGPUToVC/LSCPatterns.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,31 @@ class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
840840
op, "Only global access supported for block load.");
841841
auto payload = adaptor.getTensorDesc();
842842
auto retTy = op.getType();
843+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
844+
if (bitWidth < 8) {
845+
if (8 % bitWidth != 0)
846+
return rewriter.notifyMatchFailure(
847+
op, "Only sub byte type with bit-width 1, 2, 4, or 8 are "
848+
"supported for block load.");
849+
auto subByteFactor = 8 / bitWidth;
850+
// For supported sub byte type,
851+
// fake element type to i8 and update elemTy, retTy and tdescTy
852+
// accordingly. Add cast before and after intrinsic call to ensure the
853+
// type matches the original type.
854+
elemTy = rewriter.getI8Type();
855+
auto shape = tdescTy.getShape().vec();
856+
auto lastDim = shape.size() - 1;
857+
if (shape[lastDim] % subByteFactor != 0) {
858+
return rewriter.notifyMatchFailure(
859+
op, "The last dimension but be a multiple of (8 / bitWidth) for "
860+
"sub byte types.");
861+
}
862+
shape[lastDim] = shape[lastDim] / subByteFactor;
863+
tdescTy = TensorDescType::get(tdescTy.getContext(), shape, elemTy,
864+
tdescTy.getEncoding(),
865+
/*sg_map*/ nullptr);
866+
retTy = VectorType::get(tdescTy.getShape(), elemTy);
867+
}
843868

844869
// TODO: remove this after moving transposeBitWidth into a standalone
845870
// pass. update the width and pictch of the payload when transposeBitWidth
@@ -908,6 +933,8 @@ class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
908933

909934
// TODO: remove this after moving transposeBitWidth into a standalone
910935
// pass.
936+
// NOTE: sub byte type handling also needs the bitcast to the original
937+
// type after the intrinsic call.
911938
if (retTy != op.getType()) {
912939
auto targetTy = convertVectorType(op.getType()).second;
913940
callOp = rewriter.create<vector::BitCastOp>(loc, targetTy, callOp);
@@ -959,6 +986,31 @@ class PrefetchNdPattern : public OpConversionPattern<PrefetchNdOp> {
959986
if (scope != xegpu::MemorySpace::Global)
960987
return rewriter.notifyMatchFailure(
961988
op, "Only global access supported for block prefetch.");
989+
auto elemTy = tdescTy.getElementType();
990+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
991+
if (bitWidth < 8) {
992+
if (8 % bitWidth != 0)
993+
return rewriter.notifyMatchFailure(
994+
op, "Only sub byte type with bit-width 1, 2, 4, or 8 are "
995+
"supported for block prefetch.");
996+
auto subByteFactor = 8 / bitWidth;
997+
// For supported sub byte type,
998+
// fake element type to i8 and update elemTy, retTy and tdescTy
999+
// accordingly. Add cast before and after intrinsic call to ensure the
1000+
// type matches the original type.
1001+
elemTy = rewriter.getI8Type();
1002+
auto shape = tdescTy.getShape().vec();
1003+
auto lastDim = shape.size() - 1;
1004+
if (shape[lastDim] % subByteFactor != 0) {
1005+
return rewriter.notifyMatchFailure(
1006+
op, "The last dimension but be a multiple of (8 / bitWidth) for "
1007+
"sub byte types.");
1008+
}
1009+
shape[lastDim] = shape[lastDim] / subByteFactor;
1010+
tdescTy = TensorDescType::get(tdescTy.getContext(), shape, elemTy,
1011+
tdescTy.getEncoding(),
1012+
/*sg_map*/ nullptr);
1013+
}
9621014
auto callOp = gen2DPrefetchIntrinsicCall(
9631015
rewriter, loc, l1hint, l3hint, tdescTy, adaptor.getTensorDesc());
9641016
rewriter.replaceOp(op, callOp);
@@ -1010,7 +1062,33 @@ class StoreNdPattern : public OpConversionPattern<StoreNdOp> {
10101062
if (scope != xegpu::MemorySpace::Global)
10111063
return rewriter.notifyMatchFailure(
10121064
op, "Only global access supported for block store.");
1013-
1065+
auto elemTy = tdescTy.getElementType();
1066+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
1067+
if (bitWidth < 8) {
1068+
if (8 % bitWidth != 0)
1069+
return rewriter.notifyMatchFailure(
1070+
op, "Only sub byte type with bit-width 1, 2, 4, or 8 are "
1071+
"supported for block store.");
1072+
auto subByteFactor = 8 / bitWidth;
1073+
// For supported sub byte type,
1074+
// fake element type to i8 and update elemTy, retTy and tdescTy
1075+
// accordingly. Add cast before and after intrinsic call to ensure the
1076+
// type matches the original type.
1077+
elemTy = rewriter.getI8Type();
1078+
auto shape = tdescTy.getShape().vec();
1079+
auto lastDim = shape.size() - 1;
1080+
if (shape[lastDim] % subByteFactor != 0) {
1081+
return rewriter.notifyMatchFailure(
1082+
op, "The last dimension but be a multiple of (8 / bitWidth) for "
1083+
"sub byte types.");
1084+
}
1085+
shape[lastDim] = shape[lastDim] / subByteFactor;
1086+
tdescTy = TensorDescType::get(tdescTy.getContext(), shape, elemTy,
1087+
tdescTy.getEncoding(),
1088+
/*sg_map*/ nullptr);
1089+
auto dataTy = VectorType::get({tdescTy.getNumElements()}, elemTy);
1090+
data = rewriter.create<vector::BitCastOp>(loc, dataTy, data);
1091+
}
10141092
auto callOp =
10151093
gen2DStoreIntrinsicCall(rewriter, loc, l1hint, l3hint, tdescTy,
10161094
adaptor.getTensorDesc(), data);

test/Conversion/XeGPUToVC/load_nd.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ module @gemm attributes {gpu.container_module} {
3535
gpu.return
3636
}
3737

38+
// CHECK: gpu.func @test_load_nd_subbyte(%[[arg0:.*]]: memref<8x256xi1>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
39+
gpu.func @test_load_nd_subbyte(%arg0: memref<8x256xi1>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
40+
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x256xi1> -> !xegpu.tensor_desc<8x256xi1>
41+
// CHECK: %[[V10:.*]] = func.call @llvm.genx.lsc.load.2d.ugm.desc.v256i8.v2i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, vector<256xi8>) -> vector<256xi8>
42+
// CHECK: %[[V11:.*]] = vector.bitcast %[[V10]] : vector<256xi8> to vector<2048xi1>
43+
// CHECK: %[[V12:.*]] = vector.shape_cast %[[V11]] : vector<2048xi1> to vector<8x256xi1>
44+
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x256xi1> -> vector<8x256xi1>
45+
%cst0 = arith.constant 0 : index
46+
vector.store %1, %arg0[%cst0, %cst0] : memref<8x256xi1>, vector<8x256xi1>
47+
gpu.return
48+
}
49+
3850
// CHECK: gpu.func @test_load_nd_1(%[[arg0:.*]]: memref<8x16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
3951
gpu.func @test_load_nd_1(%arg0: memref<8x16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
4052
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index

test/Conversion/XeGPUToVC/prefetchnd.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,15 @@ module @two_type attributes {gpu.container_module} {
116116
}
117117
}
118118
}
119+
120+
// -----
121+
module @subbyte attributes {gpu.container_module} {
122+
gpu.module @test_kernel {
123+
gpu.func @test_prefetch(%arg0: memref<8x256xi1>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
124+
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x256xi1> -> !xegpu.tensor_desc<8x256xi1>
125+
// CHECK: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i8) -> ()
126+
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x256xi1>
127+
gpu.return
128+
}
129+
}
130+
}

test/Conversion/XeGPUToVC/store_nd.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ module @gemm attributes {gpu.container_module} {
3838
gpu.return
3939
}
4040

41+
gpu.func @test_store_nd_subbyte(%arg0: memref<8x256xi1>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
42+
%c = arith.constant dense<1> : vector<8x256xi1>
43+
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x256xi1> -> !xegpu.tensor_desc<8x256xi1>
44+
// CHECK: %[[V10:.*]] = vector.bitcast {{.*}} : vector<2048xi1> to vector<256xi8>
45+
// CHECK: func.call @llvm.genx.lsc.store.2d.ugm.desc.v2i8.v256i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[V10]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, vector<256xi8>) -> ()
46+
xegpu.store_nd %c, %0 : vector<8x256xi1>, !xegpu.tensor_desc<8x256xi1>
47+
gpu.return
48+
}
49+
4150
// CHECK: gpu.func @test_store_nd_1d_strided_memref(%[[arg0:.*]]: memref<32x32xf32, strided<[64, 1]>>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
4251
gpu.func @test_store_nd_1d_strided_memref(%arg0: memref<32x32xf32, strided<[64,1], offset: 0>>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
4352

0 commit comments

Comments
 (0)