Commit ca477af

XeGPUToVC: improvement to exp lowering. (#831)
The previous lowering was failing for f16 when creating the log2e vector constant:

error: 'arith.constant' op failed to verify that all of {value, result} have same type
    %2 = math.exp %v1 : vector<16xf16>
         ^
note: see current operation: %2 = "arith.constant"() <{value = dense<1.44269502> : vector<16xf32>}> : () -> vector<16xf16>
1 parent 14bfdba commit ca477af
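
For orientation, this is roughly the IR shape the updated lowering should produce for the f16 case. It is a sketch inferred from the new FileCheck expectations below, with made-up SSA names (%log2e, %log2e_vec, %scaled), not output copied from the compiler:

    %log2e = arith.constant 1.442695e+00 : f16
    %log2e_vec = vector.broadcast %log2e : f16 to vector<16xf16>
    %scaled = arith.mulf %v1, %log2e_vec : vector<16xf16>
    // ... %scaled then feeds the func.call to the llvm.genx.exp intrinsic

The scalar constant is created with the vector's element type and then broadcast, so the same code path works for both f32 and f16 element types.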

2 files changed: +29 -8 lines

lib/Conversion/XeGPUToVC/XeGPUToVC.cpp

Lines changed: 5 additions & 4 deletions
@@ -1378,16 +1378,17 @@ struct ElementwiseToVCPattern : public OpConversionPattern<MOp> {
     auto loc = op.getLoc();
     // This lowering pattern is needed only for spirv ops with large vector
     // lengths.
-    auto vecSize = vecTy.getNumElements();
     // for larger vector lengths, "llvm.genx.exp" returns the base 2
     // exponentiation of the input. To get the base e exponentiation, we need to
     // scale the input by log2(e)
     auto operands = adaptor.getOperands();
     SmallVector<Value> args{operands};
     if (isExpOp) {
-      SmallVector<float> log2e(vecSize, 1.442695040888963);
-      auto log2eConstVec = rewriter.create<arith::ConstantOp>(
-          op.getLoc(), vecTy, rewriter.getF32VectorAttr(log2e));
+      auto log2e = rewriter.create<arith::ConstantOp>(
+          loc,
+          rewriter.getFloatAttr(vecTy.getElementType(), 1.442695040888963));
+      auto log2eConstVec =
+          rewriter.create<vector::BroadcastOp>(loc, vecTy, log2e);
       auto input = operands[0];
       auto scaledInput =
           rewriter.create<arith::MulFOp>(op.getLoc(), input, log2eConstVec);
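
For context on the scaling mentioned in the comment above: e^x = 2^(x * log2(e)), and log2(e) ≈ 1.442695040888963, so multiplying the input by log2(e) before the base-2 exponentiation performed by llvm.genx.exp yields the base-e result.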
Lines changed: 24 additions & 4 deletions
@@ -1,4 +1,4 @@
-// RUN: imex-opt -convert-xegpu-to-vc='enable-vc-intrinsic=true useRawSend=true' -cse %s | FileCheck %s --check-prefixes=CHECK
+// RUN: imex-opt -convert-xegpu-to-vc='enable-vc-intrinsic=true useRawSend=true' -cse --split-input-file %s | FileCheck %s --check-prefixes=CHECK
 module @gemm attributes {gpu.container_module} {
   gpu.module @test_kernel {

@@ -7,10 +7,12 @@ module @gemm attributes {gpu.container_module} {
       %c0 = arith.constant 0 : index
       %cv1 = arith.constant dense<1.0> : vector<16xf32>
       %v1 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<16xf32>
-      // CHECK: arith.mulf
-      // CHECK-NEXT: func.call @llvm.genx.exp.v16f32
+      // CHECK: %[[LOG2E:.*]] = arith.constant 1.44{{.*}} f32
+      // CHECK-NEXT: %[[LOG2E_VEC:.*]] = vector.broadcast %[[LOG2E]] : f32 to vector<16xf32>
+      // CHECK-NEXT: %[[MULF:.*]] = arith.mulf {{.*}} %[[LOG2E_VEC]]
+      // CHECK-NEXT: func.call @llvm.genx.exp.v16f32(%[[MULF]])
       %1 = math.exp %v1 fastmath<nnan> : vector<16xf32>
-      // CHECK-NEXT: func.call @llvm.genx.exp.v16f32
+      // CHECK-NEXT: func.call @llvm.genx.exp.v16f32(%[[MULF]])
       %2 = math.exp %v1 : vector<16xf32>
       // CHECK-NEXT: func.call @llvm.genx.fmax.v16f32
       %4 = arith.maximumf %v1, %cv1 fastmath<nnan> : vector<16xf32>

@@ -20,3 +22,21 @@ module @gemm attributes {gpu.container_module} {
     }
   }
 }
+
+// -----
+
+module @gemm attributes {gpu.container_module} {
+  gpu.module @exp_f16 {
+    // CHECK-LABEL: gpu.func @exp_f16
+    gpu.func @exp_f16(%arg0: memref<8x16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
+      %c0 = arith.constant 0 : index
+      %v1 = vector.load %arg0[%c0, %c0] : memref<8x16xf16>, vector<16xf16>
+      // CHECK: %[[LOG2E:.*]] = arith.constant 1.44{{.*}} f16
+      // CHECK-NEXT: %[[LOG2E_VEC:.*]] = vector.broadcast %[[LOG2E]] : f16 to vector<16xf16>
+      // CHECK-NEXT: %[[MULF:.*]] = arith.mulf {{.*}} %[[LOG2E_VEC]]
+      // CHECK-NEXT: func.call @llvm.genx.exp.v8i32(%[[MULF]])
+      %2 = math.exp %v1 : vector<16xf16>
+      gpu.return
+    }
+  }
+}
