Skip to content

Commit c4f91b9

Browse files
authored
Generate fma from mad when allowed by compile options (#1438)
1 parent 9aad224 commit c4f91b9

File tree

5 files changed

+62
-0
lines changed

5 files changed

+62
-0
lines changed

lib/Builtins.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,9 @@ Builtins::getExtInstEnum(const Builtins::FunctionInfo &func_info) {
686686
return glsl::ExtInst::ExtInstPackHalf2x16;
687687
case Builtins::kSpirvUnpack:
688688
return glsl::ExtInst::ExtInstUnpackHalf2x16;
689+
case Builtins::kMad:
690+
// Only floating-point kMad should be able to get here
691+
return glsl::ExtInst::ExtInstFma;
689692
default:
690693
break;
691694
}

lib/LongVectorLoweringPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ Function *getBIFScalarVersion(Function &Builtin) {
235235
case clspv::Builtins::kLdexp:
236236
case clspv::Builtins::kLog:
237237
case clspv::Builtins::kLog2:
238+
case clspv::Builtins::kMad:
238239
case clspv::Builtins::kMax:
239240
case clspv::Builtins::kMin:
240241
case clspv::Builtins::kMix:

lib/ReplaceOpenCLBuiltinPass.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,6 +2116,13 @@ bool ReplaceOpenCLBuiltinPass::replaceSignbit(Function &F, bool is_vec) {
21162116

21172117
bool ReplaceOpenCLBuiltinPass::replaceMul(Function &F, bool is_float,
21182118
bool is_mad) {
2119+
// floating-point fma can be handle later in the flow if they are allowed
2120+
if (is_float && is_mad &&
2121+
(clspv::Option::UseNativeBuiltins().count(
2122+
clspv::Builtins::BuiltinType::kFma) > 0 ||
2123+
clspv::Option::ClMadEnable() || clspv::Option::UnsafeMath())) {
2124+
return false;
2125+
}
21192126
return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
21202127
// The multiply instruction to use.
21212128
auto MulInst = is_float ? Instruction::FMul : Instruction::Mul;
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: clspv-opt --passes=long-vector-lowering %s -o %t.ll
2+
; RUN: FileCheck %s < %t.ll
3+
4+
; CHECK-COUNT-8: call spir_func float @_Z3madfff(
5+
6+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
7+
target triple = "spir-unknown-unknown"
8+
9+
define spir_kernel void @foo(ptr addrspace(1) align 32 %a) {
10+
entry:
11+
%arrayidx = getelementptr inbounds <8 x float>, ptr addrspace(1) %a, i32 1
12+
%0 = load <8 x float>, ptr addrspace(1) %arrayidx, align 32
13+
%arrayidx1 = getelementptr inbounds <8 x float>, ptr addrspace(1) %a, i32 2
14+
%1 = load <8 x float>, ptr addrspace(1) %arrayidx1, align 32
15+
%arrayidx2 = getelementptr inbounds <8 x float>, ptr addrspace(1) %a, i32 3
16+
%2 = load <8 x float>, ptr addrspace(1) %arrayidx2, align 32
17+
%call = call spir_func <8 x float> @_Z3madDv8_fS_S_(<8 x float> %0, <8 x float> %1, <8 x float> %2)
18+
%arrayidx3 = getelementptr inbounds <8 x float>, ptr addrspace(1) %a, i32 0
19+
store <8 x float> %call, ptr addrspace(1) %arrayidx3, align 32
20+
ret void
21+
}
22+
23+
declare spir_func <8 x float> @_Z3madDv8_fS_S_(<8 x float>, <8 x float>, <8 x float>)
24+

test/mad-float-optimization.cl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: clspv %target %s -o %t.unsafe.spv --cl-unsafe-math-optimizations
2+
// RUN: spirv-dis -o %t.unsafe.spvasm %t.unsafe.spv
3+
// RUN: FileCheck %s < %t.unsafe.spvasm
4+
// RUN: spirv-val --target-env spv1.0 %t.unsafe.spv
5+
6+
// RUN: clspv %target %s -o %t.mad.spv --cl-mad-enable
7+
// RUN: spirv-dis -o %t.mad.spvasm %t.mad.spv
8+
// RUN: FileCheck %s < %t.mad.spvasm
9+
// RUN: spirv-val --target-env spv1.0 %t.mad.spv
10+
11+
// RUN: clspv %target %s -o %t.native.spv --use-native-builtins=fma
12+
// RUN: spirv-dis -o %t.native.spvasm %t.native.spv
13+
// RUN: FileCheck %s < %t.native.spvasm
14+
// RUN: spirv-val --target-env spv1.0 %t.native.spv
15+
16+
// CHECK: OpExtInst {{.*}} Fma
17+
18+
// RUN: clspv %target %s -o %t.spv
19+
// RUN: spirv-dis -o %t.spvasm %t.spv
20+
// RUN: FileCheck %s --check-prefix=NOOPT < %t.spvasm
21+
// RUN: spirv-val --target-env spv1.0 %t.spv
22+
23+
// NOOPT-NOT: OpExtInst {{.*}} Fma
24+
25+
void kernel foo(global float* a) {
26+
a[0] = mad(a[1], a[2], a[3]);
27+
}

0 commit comments

Comments
 (0)