Skip to content

Commit 391e344

Browse files
committed
try
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 0efc825 commit 391e344

File tree

17 files changed

+542
-1
lines changed

17 files changed

+542
-1
lines changed

clang/lib/Driver/ToolChains/SYCL.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,9 +1710,18 @@ void SYCLToolChain::AddImpliedTargetArgs(const llvm::Triple &Triple,
17101710
if (Args.hasFlag(options::OPT_ftarget_export_symbols,
17111711
options::OPT_fno_target_export_symbols, false))
17121712
BeArgs.push_back("-library-compilation");
1713-
} else if (IsJIT)
1713+
// -foffload-fp32-prec-[sqrt/div]
1714+
if (Args.hasArg(options::OPT_foffload_fp32_prec_div) ||
1715+
Args.hasArg(options::OPT_foffload_fp32_prec_sqrt))
1716+
BeArgs.push_back("-ze-fp32-correctly-rounded-divide-sqrt");
1717+
} else if (IsJIT) {
17141718
// -ftarget-compile-fast JIT
17151719
Args.AddLastArg(BeArgs, options::OPT_ftarget_compile_fast);
1720+
// -foffload-fp32-prec-div JIT
1721+
Args.AddLastArg(BeArgs, options::OPT_foffload_fp32_prec_div);
1722+
// -foffload-fp32-prec-sqrt JIT
1723+
Args.AddLastArg(BeArgs, options::OPT_OPT_foffload_fp32_prec_sqrt);
1724+
}
17161725
if (IsGen) {
17171726
for (auto [DeviceName, BackendArgStr] : PerDeviceArgs) {
17181727
CmdArgs.push_back("-device_options");
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===-- SYCLSqrtFDivMaxErrorCleanUp.h - SYCLSqrtFDivMaxErrorCleanUp Pass --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Remove llvm.fpbuiltin.[sqrt/fdiv] intrinsics to ensure compatibility with the
9+
// old drivers (that don't support SPV_INTEL_fp_max_error extension).
10+
// The intrinsic functions are removed in case if they are used with standard
11+
// for OpenCL max-error (e.g [3.0/2.5] ULP) and there are no:
12+
// - other llvm.fpbuiltin.* intrinsic functions;
13+
// - fdiv instructions
14+
// - @sqrt builtins (both C and C++-styles)/llvm intrinsic in the module.
15+
//===----------------------------------------------------------------------===//
16+
#ifndef LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H
17+
#define LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H
18+
19+
#include "llvm/IR/PassManager.h"
20+
21+
namespace llvm {
22+
23+
// FIXME: remove this pass, it's not really needed.
24+
class SYCLSqrtFDivMaxErrorCleanUpPass
25+
: public PassInfoMixin<SYCLSqrtFDivMaxErrorCleanUpPass> {
26+
public:
27+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
28+
29+
static bool isRequired() { return true; }
30+
};
31+
32+
} // namespace llvm
33+
34+
#endif // LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h"
143143
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
144144
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
145+
#include "llvm/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.h"
145146
#include "llvm/SYCLLowerIR/SYCLVirtualFunctionsAnalysis.h"
146147
#include "llvm/SYCLLowerIR/SpecConstants.h"
147148
#include "llvm/Support/CommandLine.h"

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ MODULE_PASS("esimd-remove-host-code", ESIMDRemoveHostCodePass());
157157
MODULE_PASS("esimd-remove-optnone-noinline", ESIMDRemoveOptnoneNoinlinePass());
158158
MODULE_PASS("sycl-conditional-call-on-device", SYCLConditionalCallOnDevicePass())
159159
MODULE_PASS("sycl-joint-matrix-transform", SYCLJointMatrixTransformPass())
160+
MODULE_PASS("sycl-sqrt-fdiv-max-error-clean-up", SYCLSqrtFDivMaxErrorCleanUpPass())
160161
MODULE_PASS("sycl-propagate-aspects-usage", SYCLPropagateAspectsUsagePass())
161162
MODULE_PASS("sycl-propagate-joint-matrix-usage", SYCLPropagateJointMatrixUsagePass())
162163
MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass())

llvm/lib/SYCLLowerIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
6767
SYCLJointMatrixTransform.cpp
6868
SYCLPropagateAspectsUsage.cpp
6969
SYCLPropagateJointMatrixUsage.cpp
70+
SYCLSqrtFDivMaxErrorCleanUp.cpp
7071
SYCLVirtualFunctionsAnalysis.cpp
7172
SYCLUtils.cpp
7273
SanitizeDeviceGlobal.cpp
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===- SYCLSqrtFDivMaxErrorCleanUp.cpp - SYCLSqrtFDivMaxErrorCleanUp Pass -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Remove llvm.fpbuiltin.[sqrt/fdiv] intrinsics to ensure compatibility with the
9+
// old drivers (that don't support SPV_INTEL_fp_max_error extension).
10+
// The intrinsic functions are removed in case if they are used with standard
11+
// for OpenCL max-error (e.g [3.0/2.5] ULP) and there are no:
12+
// - other llvm.fpbuiltin.* intrinsic functions;
13+
// - fdiv instructions
14+
// - @sqrt builtins (both C and C++-styles)/llvm intrinsic in the module.
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "llvm/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.h"
18+
19+
#include "llvm/ADT/SmallSet.h"
20+
#include "llvm/IR/Module.h"
21+
#include "llvm/IR/IntrinsicInst.h"
22+
#include "llvm/IR/IRBuilder.h"
23+
24+
using namespace llvm;
25+
26+
namespace {
27+
static constexpr char SQRT_ERROR[] = "3.0";
28+
static constexpr char FDIV_ERROR[] = "2.5";
29+
} // namespace
30+
31+
PreservedAnalyses
32+
SYCLSqrtFDivMaxErrorCleanUpPass::run(Module &M,
33+
ModuleAnalysisManager &MAM) {
34+
SmallVector<IntrinsicInst *, 16> WorkListSqrt;
35+
SmallVector<IntrinsicInst *, 16> WorkListFDiv;
36+
37+
// Add all llvm.fpbuiltin.sqrt with 3.0 error and llvm.fpbuiltin.fdiv with
38+
// 2.5 error to the work list to remove them later. If attributes with other
39+
// values or other llvm.fpbuiltin.* intrinsic functions found - abort the
40+
// pass.
41+
for (auto &F : M) {
42+
if (!F.isDeclaration())
43+
continue;
44+
const auto ID = F.getIntrinsicID();
45+
if (ID != llvm::Intrinsic::fpbuiltin_sqrt &&
46+
ID != llvm::Intrinsic::fpbuiltin_fdiv)
47+
continue;
48+
49+
for (auto *Use : F.users()) {
50+
auto *II = cast<IntrinsicInst>(Use);
51+
if (II && II->getCalledFunction()->getName().
52+
starts_with("llvm.fpbuiltin")) {
53+
// llvm.fpbuiltin.* intrinsics should always have fpbuiltin-max-error
54+
// attribute, but it's not a concern of the pass, so just do an early
55+
// exit here if the attribute is not attached.
56+
if (!II->getAttributes().hasFnAttr("fpbuiltin-max-error"))
57+
return PreservedAnalyses::none();
58+
StringRef MaxError = II->getAttributes().getFnAttr(
59+
"fpbuiltin-max-error").getValueAsString();
60+
61+
if (ID == llvm::Intrinsic::fpbuiltin_sqrt) {
62+
if (MaxError != SQRT_ERROR)
63+
return PreservedAnalyses::none();
64+
WorkListSqrt.push_back(II);
65+
}
66+
else if (ID == llvm::Intrinsic::fpbuiltin_fdiv) {
67+
if (MaxError != FDIV_ERROR)
68+
return PreservedAnalyses::none();
69+
WorkListFDiv.push_back(II);
70+
} else {
71+
// Another llvm.fpbuiltin.* intrinsic was found - the module is
72+
// already not backward compatible.
73+
return PreservedAnalyses::none();
74+
}
75+
}
76+
}
77+
}
78+
79+
// No intrinsics at all - do an early exist.
80+
if (WorkListSqrt.empty() && WorkListFDiv.empty())
81+
return PreservedAnalyses::none();
82+
83+
// If @sqrt, @_Z4sqrt*, @llvm.sqrt. or fdiv present in the module - do
84+
// nothing.
85+
for (auto &F : M) {
86+
if (F.isDeclaration())
87+
continue;
88+
for (auto &BB : F) {
89+
for (auto &II : BB) {
90+
if (auto *CI = dyn_cast<CallInst>(&II)) {
91+
auto *SqrtF = CI->getCalledFunction();
92+
if (SqrtF->getName() == "sqrt" ||
93+
SqrtF->getName().starts_with("_Z4sqrt") ||
94+
SqrtF->getIntrinsicID() == llvm::Intrinsic::sqrt)
95+
return PreservedAnalyses::none();
96+
}
97+
if (auto *FPI = dyn_cast<FPMathOperator>(&II)) {
98+
auto Opcode = FPI->getOpcode();
99+
if (Opcode == Instruction::FDiv)
100+
return PreservedAnalyses::none();
101+
}
102+
}
103+
}
104+
}
105+
106+
// Replace @llvm.fpbuiltin.sqrt call with @llvm.sqrt. llvm-spirv will handle
107+
// it later.
108+
SmallSet<Function *, 2> DeclToRemove;
109+
for (auto *Sqrt : WorkListSqrt) {
110+
DeclToRemove.insert(Sqrt->getCalledFunction());
111+
IRBuilder Builder(Sqrt);
112+
Builder.SetInsertPoint(Sqrt);
113+
Type *Ty = Sqrt->getType();
114+
AttributeList Attrs = Sqrt->getAttributes();
115+
Function *NewSqrtF =
116+
Intrinsic::getDeclaration(&M, llvm::Intrinsic::sqrt, Ty);
117+
auto *NewSqrt = Builder.CreateCall(NewSqrtF, { Sqrt->getOperand(0) },
118+
Sqrt->getName());
119+
120+
// Copy FP flags, metadata and attributes. Replace old call with a new call.
121+
Attrs = Attrs.removeFnAttribute(Sqrt->getContext(), "fpbuiltin-max-error");
122+
NewSqrt->setAttributes(Attrs);
123+
NewSqrt->copyMetadata(*Sqrt);
124+
FPMathOperator *FPOp = cast<FPMathOperator>(Sqrt);
125+
FastMathFlags FMF = FPOp->getFastMathFlags();
126+
NewSqrt->setFastMathFlags(FMF);
127+
Sqrt->replaceAllUsesWith(NewSqrt);
128+
Sqrt->dropAllReferences();
129+
Sqrt->eraseFromParent();
130+
}
131+
132+
// Replace @llvm.fpbuiltin.fdiv call with fdiv.
133+
for (auto *FDiv : WorkListFDiv) {
134+
DeclToRemove.insert(FDiv->getCalledFunction());
135+
IRBuilder Builder(FDiv);
136+
Builder.SetInsertPoint(FDiv);
137+
Instruction *NewFDiv =
138+
cast<Instruction>(Builder.CreateFDiv(
139+
FDiv->getOperand(0), FDiv->getOperand(1), FDiv->getName()));
140+
141+
// Copy FP flags and metadata. Replace old call with a new instruction.
142+
cast<Instruction>(NewFDiv)->copyMetadata(*FDiv);
143+
FPMathOperator *FPOp = cast<FPMathOperator>(FDiv);
144+
FastMathFlags FMF = FPOp->getFastMathFlags();
145+
NewFDiv->setFastMathFlags(FMF);
146+
FDiv->replaceAllUsesWith(NewFDiv);
147+
FDiv->dropAllReferences();
148+
FDiv->eraseFromParent();
149+
}
150+
151+
// Clear old declarations.
152+
for (auto *Decl : DeclToRemove) {
153+
assert(Decl->isDeclaration() &&
154+
"attempting to remove a function definition");
155+
Decl->dropAllReferences();
156+
Decl->eraseFromParent();
157+
}
158+
159+
return PreservedAnalyses::all();
160+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt are removed from
2+
; the module.
3+
4+
; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s
5+
6+
; CHECK-NOT: llvm.fpbuiltin.fdiv.f32
7+
; CHECK-NOT: llvm.fpbuiltin.sqrt.f32
8+
; CHECK-NOT: fpbuiltin-max-error
9+
10+
; CHECK: test_fp_max_error_decoration(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
11+
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv float [[F1]], [[F2]]
12+
; CHECK: call float @llvm.sqrt.f32(float [[V1]])
13+
14+
; CHECK: test_fp_max_error_decoration_fast(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
15+
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv fast float [[F1]], [[F2]]
16+
; CHECK: call fast float @llvm.sqrt.f32(float [[V1]])
17+
18+
; CHECK: test_fp_max_error_decoration_debug(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
19+
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv float [[F1]], [[F2]], !dbg ![[#Loc1:]]
20+
; CHECK: call float @llvm.sqrt.f32(float [[V1]]), !dbg ![[#Loc2:]]
21+
22+
; CHECK: [[#Loc1]] = !DILocation(line: 1, column: 1, scope: ![[#]])
23+
; CHECK: [[#Loc2]] = !DILocation(line: 2, column: 1, scope: ![[#]])
24+
25+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
26+
target triple = "spir64-unknown-unknown"
27+
28+
define void @test_fp_max_error_decoration(float %f1, float %f2) {
29+
entry:
30+
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
31+
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
32+
ret void
33+
}
34+
35+
define void @test_fp_max_error_decoration_fast(float %f1, float %f2) {
36+
entry:
37+
%v1 = call fast float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
38+
%v2 = call fast float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
39+
ret void
40+
}
41+
42+
define void @test_fp_max_error_decoration_debug(float %f1, float %f2) {
43+
entry:
44+
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0, !dbg !7
45+
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1, !dbg !8
46+
ret void
47+
}
48+
49+
declare float @llvm.fpbuiltin.fdiv.f32(float, float)
50+
51+
declare float @llvm.fpbuiltin.sqrt.f32(float)
52+
53+
attributes #0 = { "fpbuiltin-max-error"="2.5" }
54+
attributes #1 = { "fpbuiltin-max-error"="3.0" }
55+
56+
!llvm.dbg.cu = !{!0}
57+
!llvm.module.flags = !{!9}
58+
59+
!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, nameTableKind: None)
60+
!1 = !DIFile(filename: "test.c", directory: "/tmp", checksumkind: CSK_MD5, checksum: "2a034da6937f5b9cf6dd2d89127f57fd")
61+
!2 = distinct !DISubprogram(name: "test_fp_max_error_decoration_debug", scope: !1, file: !1, line: 1, type: !3, scopeLine: 2, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition, unit: !0)
62+
!3 = !DISubroutineType(types: !4)
63+
!4 = !{!5, !6, !6}
64+
!5 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed)
65+
!6 = !DIBasicType(name: "float", size: 32, encoding: DW_ATE_float)
66+
!7 = !DILocation(line: 1, column: 1, scope: !2)
67+
!8 = !DILocation(line: 2, column: 1, scope: !2)
68+
!9 = !{i32 2, !"Debug Info Version", i32 3}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
2+
; non-standart for OpenCL max-error is used.
3+
4+
; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s
5+
6+
; CHECK: llvm.fpbuiltin.fdiv.f32
7+
; CHECK: llvm.fpbuiltin.sqrt.f32
8+
; CHECK: fpbuiltin-max-error
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define void @test_fp_max_error_decoration(float %f1, float %f2) {
14+
entry:
15+
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
16+
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
17+
ret void
18+
}
19+
20+
declare float @llvm.fpbuiltin.fdiv.f32(float, float)
21+
22+
declare float @llvm.fpbuiltin.sqrt.f32(float)
23+
24+
attributes #0 = { "fpbuiltin-max-error"="2.0" }
25+
attributes #1 = { "fpbuiltin-max-error"="3.0" }
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
2+
; fdiv instruction was in the module.
3+
4+
; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s
5+
6+
; CHECK: llvm.fpbuiltin.fdiv.f32
7+
; CHECK: llvm.fpbuiltin.sqrt.f32
8+
; CHECK: fpbuiltin-max-error
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define void @test_fp_max_error_decoration(float %f1, float %f2) {
14+
entry:
15+
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
16+
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
17+
%v3 = fdiv float %v2, %f2
18+
ret void
19+
}
20+
21+
declare float @llvm.fpbuiltin.fdiv.f32(float, float)
22+
23+
declare float @llvm.fpbuiltin.sqrt.f32(float)
24+
25+
attributes #0 = { "fpbuiltin-max-error"="2.0" }
26+
attributes #1 = { "fpbuiltin-max-error"="3.0" }
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
2+
; other fpbuiltin intrinsic is used in the module.
3+
4+
; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s
5+
6+
; CHECK: llvm.fpbuiltin.fdiv.f32
7+
; CHECK: llvm.fpbuiltin.sqrt.f32
8+
; CHECK: fpbuiltin-max-error
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define void @test_fp_max_error_decoration(float %f1, float %f2) {
14+
entry:
15+
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
16+
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
17+
%v3 = call float @llvm.fpbuiltin.exp.f32(float %v2)
18+
ret void
19+
}
20+
21+
declare float @llvm.fpbuiltin.fdiv.f32(float, float)
22+
23+
declare float @llvm.fpbuiltin.sqrt.f32(float)
24+
25+
declare float @llvm.fpbuiltin.exp.f32(float)
26+
27+
attributes #0 = { "fpbuiltin-max-error"="2.0" }
28+
attributes #1 = { "fpbuiltin-max-error"="3.0" }

0 commit comments

Comments
 (0)