
Commit 255c28e

Merge remote-tracking branch 'upstream/main'

2 parents: 7be3da8 + faaa6a1

File tree: 5 files changed (+260, -20 lines)


lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 0 additions & 9 deletions

@@ -338,15 +338,6 @@ void GPUXToSPIRVPass::runOnOperation() {
     if (rank < 1 || type.getNumElements() == 1)
       return elemType;
 
-    // load2d/store2d is 3-d with vnni format, and 4d with array_length
-    // TODO: what if load without any vnni? are we going to transform all
-    // fp16/bf16
-    auto factor = 32 / elemType.getIntOrFloatBitWidth();
-    if ((rank == 3 || rank == 4) && type.getShape()[rank - 1] == factor) {
-      elemType = ::mlir::IntegerType::get(context, 32);
-      rank--;
-    }
-
     unsigned sum = 1;
     for (unsigned i = 0; i < rank; i++) {
       sum *= type.getShape()[i];

lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 58 additions & 9 deletions

@@ -102,6 +102,20 @@ encodeVectorType(ConversionPatternRewriter &rewriter, VectorType type,
   auto newType = VectorType::get(size, elemType);
   return std::make_pair(str, newType);
 }
+
+/// @brief
+/// We have to use i32 for intrinsic calls like llvm_genx_raw_send2_*; if we
+/// want to get the original element type (e.g., f16) as the result of a load,
+/// we have to encode the resulting i32 vector back to it.
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentbitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitwidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentbitWidth / newBitwidth;
+  return VectorType::get(size, toElemType);
+}
+
 unsigned encodeDataum(Type type) {
   switch (type.getIntOrFloatBitWidth()) {
   case 8:
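
For illustration only (this sketch is not part of the commit): re-encoding the v64i32 result of an 8x16xf16 2-D load recovers the f16 view of the same 2048 bits, since 64 * 32 / 16 = 128. A minimal usage sketch, assuming a rewriter is in scope:

    // Hypothetical caller of encodeVectorTypeTo; shapes match the tests below.
    VectorType loaded = VectorType::get(64, rewriter.getI32Type()); // v64i32
    VectorType origView = encodeVectorTypeTo(loaded, rewriter.getF16Type());
    // origView is vector<128xf16>: 64 elements * 32 bits / 16 bits = 128.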
@@ -555,7 +569,17 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
     auto funcOp =
         rewriter.create<spirv::FunctionCallOp>(loc, retType, funcName, args);
     if (rank == 2) {
-      rewriter.replaceOp(op, funcOp);
+      // Intrinsic accepts and returns i32 type, but we want to return a
+      // vector of the original element type
+      auto loadResultInOrigType =
+          encodeVectorTypeTo(retType, tileType.getElementType());
+      if (loadResultInOrigType != funcOp->getResult(0).getType()) {
+        auto cast = rewriter.create<spirv::BitcastOp>(
+            loc, loadResultInOrigType, funcOp->getResult(0));
+        rewriter.replaceOp(op, cast);
+      } else {
+        rewriter.replaceOp(op, funcOp);
+      }
     } else {
       auto cast = rewriter.create<spirv::BitcastOp>(loc, op.getType(),
                                                     funcOp->getResult(0));
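
As a sketch of the effect (not in this hunk; shapes taken from the new integration test below): an 8x16xf16 tile load lowered to LSC calls the v64i32 intrinsic, and the new branch inserts the cast back to f16:

    // funcOp yields vector<64xi32>; tileType's element type is f16, so
    // encodeVectorTypeTo gives vector<128xf16> and the Bitcast path is taken:
    //   %res  = spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64(...)
    //   %cast = spirv.Bitcast %res : vector<64xi32> to vector<128xf16>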
@@ -745,7 +769,16 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
     auto funcOp =
         rewriter.create<spirv::FunctionCallOp>(loc, retType, funcName, args);
     if (rank == 2) {
-      rewriter.replaceOp(op, funcOp);
+      // Intrinsic accepts and returns i32 type, but we want to return a
+      // vector of the original element type
+      auto loadResultInOrigType = encodeVectorTypeTo(newType, elmType);
+      if (loadResultInOrigType != funcOp->getResult(0).getType()) {
+        auto cast = rewriter.create<spirv::BitcastOp>(
+            loc, loadResultInOrigType, funcOp->getResult(0));
+        rewriter.replaceOp(op, cast);
+      } else {
+        rewriter.replaceOp(op, funcOp);
+      }
     } else {
       auto cast = rewriter.create<spirv::BitcastOp>(loc, op.getType(),
                                                     funcOp->getResult(0));
@@ -804,8 +837,24 @@ class DpasToVCPattern : public OpConversionPattern<DpasOp> {
     auto infoAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), infoVal);
     auto info = rewriter.create<spirv::ConstantOp>(loc, rewriter.getI32Type(),
                                                    infoAttr);
-    auto newResultType = encodeVectorType(rewriter, resultType).second;
-    SmallVector<Value, 4> args{adaptor.getRhs(), adaptor.getLhs(), info};
+
+    auto lhs = adaptor.getLhs();
+    auto rhs = adaptor.getRhs();
+    // Intrinsic accepts i32 type, therefore the element type should be casted
+    // to i32
+    auto [lhsName, lhsNewType] = encodeVectorType(rewriter, lhsType);
+    auto [rhsName, rhsNewType] = encodeVectorType(rewriter, rhsType);
+    auto [resultName, newResultType] = encodeVectorType(rewriter, resultType);
+
+    if (lhsNewType != adaptor.getLhs().getType()) {
+      lhs =
+          rewriter.create<spirv::BitcastOp>(loc, lhsNewType, adaptor.getLhs());
+    }
+    if (rhsNewType != adaptor.getRhs().getType()) {
+      rhs =
+          rewriter.create<spirv::BitcastOp>(loc, rhsNewType, adaptor.getRhs());
+    }
+    SmallVector<Value, 4> args{rhs, lhs, info};
     std::string funcName = "llvm_genx_dpas_nosrc0_";
     if (op.getAcc()) {
       funcName = "llvm_genx_dpas2_";
@@ -819,14 +868,14 @@ class DpasToVCPattern : public OpConversionPattern<DpasOp> {
       auto sdArg = createIntConstant(i32Type, sd);
       auto rcArg = createIntConstant(i32Type, rc);
       auto signless = createIntConstant(i32Type, 0);
-      args.assign({adaptor.getAcc(), adaptor.getRhs(), adaptor.getLhs(),
-                   prec1Arg, prec2Arg, sdArg, rcArg, signless, signless});
+      args.assign({adaptor.getAcc(), rhs, lhs, prec1Arg, prec2Arg, sdArg, rcArg,
+                   signless, signless});
     }
-    funcName += encodeVectorType(rewriter, resultType).first;
+    funcName += resultName;
     funcName += "_";
-    funcName += encodeVectorType(rewriter, rhsType).first;
+    funcName += rhsName;
     funcName += "_";
-    funcName += encodeVectorType(rewriter, lhsType).first;
+    funcName += lhsName;
     auto funcType =
         rewriter.getFunctionType(ValueRange(args).getTypes(), newResultType);
     Operation *opPtr = op;
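
Assembling the pieces for the no-accumulator case (names taken from the CHECK lines in the tests below):

    // funcName = "llvm_genx_dpas_nosrc0_" + resultName + "_" + rhsName + "_" + lhsName
    //          = "llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32"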
Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+// RUN: imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=true' %s | FileCheck %s
+// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=true' %s | FileCheck %s --check-prefix=LSC
+module @gemm attributes {gpu.container_module} {
+  memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01>
+  memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00>
+  func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %memref = gpu.alloc host_shared () : memref<8x16xf16>
+    memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16>
+    %memref_0 = gpu.alloc host_shared () : memref<16x16xf16>
+    memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16>
+    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>)
+    gpu.dealloc %memref : memref<8x16xf16>
+    gpu.dealloc %memref_0 : memref<16x16xf16>
+    return %memref_1 : memref<8x16xf32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %C: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64
+      // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64
+      // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64
+      // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64
+      // LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32
+      // LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32
+
+      // CHECK: %[[A_tile_desc_base:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr<!spirv.array<128 x f16>, CrossWorkgroup> to i64
+      // CHECK: %[[A_tile_payload_idx0:.*]] = spirv.VectorInsertDynamic %[[A_tile_desc_base]]
+      // CHECK: %[[A_tile_payload_idx0_i32:.*]] = spirv.Bitcast %[[A_tile_payload_idx0]] : vector<4xi64> to vector<8xi32>
+      // CHECK: %[[A_tile_payload_idx2:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[A_tile_payload_idx3:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[A_tile_payload_idx4:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[A_tile_payload_idx5:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[A_tile_payload_idx6:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[A_tile_payload_idx7:.*]] = spirv.VectorInsertDynamic
+
+      // CHECK: %[[B_tile_desc_base:.*]] = spirv.ConvertPtrToU %arg1 : !spirv.ptr<!spirv.array<256 x f16>, CrossWorkgroup> to i64
+      // CHECK: %[[B_tile_payload_idx0:.*]] = spirv.VectorInsertDynamic %[[B_tile_desc_base]]
+      // CHECK: %[[B_tile_payload_idx0_i32:.*]] = spirv.Bitcast %[[B_tile_payload_idx0]] : vector<4xi64> to vector<8xi32>
+      // CHECK: %[[B_tile_payload_idx2:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[B_tile_payload_idx3:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[B_tile_payload_idx4:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[B_tile_payload_idx5:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[B_tile_payload_idx6:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[B_tile_payload_idx7:.*]] = spirv.VectorInsertDynamic
+
+      // CHECK: %[[C_tile_desc_base:.*]] = spirv.ConvertPtrToU %arg2 : !spirv.ptr<!spirv.array<128 x f32>, CrossWorkgroup> to i64
+      // CHECK: %[[C_tile_payload_idx0:.*]] = spirv.VectorInsertDynamic %[[C_tile_desc_base]]
+      // CHECK: %[[C_tile_payload_idx0_i32:.*]] = spirv.Bitcast %[[C_tile_payload_idx0]] : vector<4xi64> to vector<8xi32>
+      // CHECK: %[[C_tile_payload_idx2:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[C_tile_payload_idx3:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[C_tile_payload_idx4:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[C_tile_payload_idx5:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[C_tile_payload_idx6:.*]] = spirv.VectorInsertDynamic
+      // CHECK: %[[C_tile_payload_idx7:.*]] = spirv.VectorInsertDynamic
+
+      // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A_tile_payload_idx7]])
+
+      // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[B_tile_payload_idx7]])
+
+      // CHECK: %[[A_increment:.*]] = spirv.Constant dense<1.000000e+00> : vector<128xf16>
+
+      // CHECK: %[[A_i32:.*]] = spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A_tile_payload_idx7]], %{{.*}})
+      // CHECK: %[[A_f16:.*]] = spirv.Bitcast %[[A_i32]] : vector<64xi32> to vector<128xf16>
+      // CHECK: %[[A_f16_inc:.*]] = spirv.FAdd %[[A_f16]], %[[A_increment]] : vector<128xf16>
+
+      // CHECK: %[[B_i32:.*]] = spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[B_tile_payload_idx7]], %{{.*}})
+      // CHECK: %[[B_f16:.*]] = spirv.Bitcast %[[B_i32]] : vector<128xi32> to vector<256xf16>
+
+      // CHECK: %[[A_back_i32:.*]] = spirv.Bitcast %[[A_f16_inc]] : vector<128xf16> to vector<64xi32>
+      // CHECK: %[[B_back_i32:.*]] = spirv.Bitcast %[[B_f16]] : vector<256xf16> to vector<128xi32>
+      // CHECK: %[[DPAS_RES:.*]] = spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32(%[[B_back_i32]], %[[A_back_i32]], %{{.*}})
+
+      // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C_tile_payload_idx7]], %[[DPAS_RES]])
+      %A_tdesc = xegpu.create_nd_tdesc %A[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+      %B_tdesc = xegpu.create_nd_tdesc %B[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+      %C_tdesc = xegpu.create_nd_tdesc %C[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+      xegpu.prefetch_nd %A_tdesc {mode = vc} : !xegpu.tensor_desc<8x16xf16>
+      xegpu.prefetch_nd %B_tdesc {mode = vc} : !xegpu.tensor_desc<16x16xf16>
+      %A_increment = arith.constant dense<1.0> : vector<128xf16>
+      %A_increment_ = vector.shape_cast %A_increment : vector<128xf16> to vector<8x8x2xf16>
+
+      %A_tensor = xegpu.load_nd %A_tdesc {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
+      %A_tensor_incremented = arith.addf %A_tensor, %A_increment_ : vector<8x8x2xf16>
+      %B_tensor = xegpu.load_nd %B_tdesc {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+      %dpas_result = xegpu.dpas %A_tensor_incremented, %B_tensor {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+      xegpu.store_nd %dpas_result, %C_tdesc {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16>
+    %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
+    %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
+    %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32>
+    //call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}

test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir

Lines changed: 4 additions & 2 deletions

@@ -30,9 +30,11 @@ gpu.module @test attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.
 // CHECK: -> vector<128xf32> "None" attributes {VectorComputeFunctionINTEL, linkage_attributes =
 // CHECK: #spirv.linkage_attributes<linkage_name = "llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32", linkage_type = <Import>>}
 // CHECK-LABEL: spirv.func @dpas
-// CHECK: (%[[A:.*]]: vector<64xi32>, %[[B:.*]]: vector<128xi32>)
+// CHECK: (%[[A:.*]]: vector<128xf16>, %[[B:.*]]: vector<256xf16>)
 // CHECK-NEXT: %[[cst134744586_i32:.*]] = spirv.Constant 134744586 : i32
-// CHECK-NEXT: %{{.*}} = spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32(%[[B]], %[[A]], %[[cst134744586_i32]])
+// CHECK-NEXT: %[[A_cast:.*]] = spirv.Bitcast %[[A]] : vector<128xf16> to vector<64xi32>
+// CHECK-NEXT: %[[B_cast:.*]] = spirv.Bitcast %[[B]] : vector<256xf16> to vector<128xi32>
+// CHECK-NEXT: %{{.*}} = spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32(%[[B_cast]], %[[A_cast]], %[[cst134744586_i32]])
 // CHECK: (vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32>
 gpu.func @dpas(%A : vector<8x8x2xf16>, %B : vector<8x16x2xf16>)
     kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
