Skip to content

Commit fc58546

Browse files
igorban-intel authored and igcbot committed
Apply rsqrt pattern for double
.
1 parent 945236a commit fc58546

File tree

3 files changed

+71
-47
lines changed

3 files changed

+71
-47
lines changed

IGC/VectorCompiler/lib/BiF/Library/Math/F64/rsqrt.cpp

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -17,34 +17,34 @@ namespace {
1717

1818
template <int N>
1919
CM_NODEBUG CM_INLINE mask<N> check_is_nan_or_inf(vector<double, N> q) {
20-
vector<uint32_t, 2 * N> q_split = q.template format<uint32_t>();
20+
vector<uint32_t, 2 *N> q_split = q.template format<uint32_t>();
2121
vector<uint32_t, N> q_hi = q_split.template select<N, 2>(1);
2222
return (q_hi >= exp_32bitmask);
2323
}
2424

2525
template <int N>
2626
CM_NODEBUG CM_INLINE vector<uint32_t, N> get_exp(vector<double, N> x) {
27-
vector<uint32_t, 2 * N> x_split = x.template format<uint32_t>();
27+
vector<uint32_t, 2 *N> x_split = x.template format<uint32_t>();
2828
vector<uint32_t, N> x_hi = x_split.template select<N, 2>(1);
2929
return (x_hi >> exp_shift) & exp_mask;
3030
}
3131

3232
template <int N>
3333
CM_NODEBUG CM_INLINE vector<uint32_t, N> get_sign(vector<double, N> x) {
34-
vector<uint32_t, 2 * N> x_split = x.template format<uint32_t>();
34+
vector<uint32_t, 2 *N> x_split = x.template format<uint32_t>();
3535
vector<uint32_t, N> x_hi = x_split.template select<N, 2>(1);
3636
return x_hi & sign_32bit;
3737
}
3838

3939
template <int N> CM_NODEBUG CM_INLINE mask<N> is_denormal(vector<double, N> x) {
40-
vector<uint32_t, 2 * N> x_int = x.template format<uint32_t>();
40+
vector<uint32_t, 2 *N> x_int = x.template format<uint32_t>();
4141
vector<uint32_t, N> x_hi = x_int.template select<N, 2>(1);
4242
return x_hi < min_sign_exp;
4343
}
4444

4545
template <int N>
4646
CM_NODEBUG CM_INLINE vector<uint32_t, N> sep_exp(vector<double, N> x) {
47-
vector<uint32_t, 2 * N> x_int = x.template format<uint32_t>();
47+
vector<uint32_t, 2 *N> x_int = x.template format<uint32_t>();
4848
vector<uint32_t, N> x_hi = x_int.template select<N, 2>(1);
4949
vector<uint32_t, N> res = (x_hi >> exp_shift) - exp_bias;
5050
return res >> 1;
@@ -84,8 +84,9 @@ CM_NODEBUG CM_INLINE vector<double, N> rsqrt_float(vector<double, N> x) {
8484
}
8585

8686
template <int N>
87-
CM_NODEBUG CM_INLINE vector<double, N> uint64_sub_hi(vector<double, N> x, vector<uint32_t, N> hi) {
88-
vector<uint32_t, 2 * N> ex_mx_int = 0;
87+
CM_NODEBUG CM_INLINE vector<double, N> uint64_sub_hi(vector<double, N> x,
88+
vector<uint32_t, N> hi) {
89+
vector<uint32_t, 2 *N> ex_mx_int = 0;
8990
ex_mx_int.template select<N, 2>(1) = hi;
9091
vector<uint64_t, N> ex_u64 = ex_mx_int.template format<uint64_t>();
9192
vector<uint64_t, N> mx_u64 = x.template format<uint64_t>();
@@ -163,9 +164,10 @@ CM_NODEBUG CM_INLINE vector<double, N> sqrt_special(vector<double, N> a) {
163164
}
164165

165166
template <int N>
166-
CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x, mask<N> special) {
167+
CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x,
168+
mask<N> special) {
167169
// Now start the SQRT computation
168-
// Use math.rsqtm (emulated here)
170+
// Use math.rsqtm (emulated here)
169171
vector<double, N> y0 = math_rsqt_dp(x);
170172
// predicate is set for 0, neg a, Inf, NaN inputs
171173
y0.merge(sqrt_special(x), special);
@@ -174,7 +176,8 @@ CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x, mask<N> sp
174176
}
175177

176178
template <int N>
177-
CM_NODEBUG CM_INLINE vector<double, N> invert_calc(vector<double, N> a, vector<double, N> y0) {
179+
CM_NODEBUG CM_INLINE vector<double, N> invert_calc(vector<double, N> a,
180+
vector<double, N> y0) {
178181
// IEEE SQRT computes H0 = 0.5*y0 (can be skipped)
179182
// Step 3: S0 = a*y0
180183
vector<double, N> S0 = a * y0;
@@ -235,15 +238,17 @@ __vc_builtin_rsqrt_f64__rte_(double a) {
235238
return __impl_rsqrt_f64(va)[0];
236239
}
237240

238-
#define FREM(WIDTH) \
241+
#define RSQRT(WIDTH) \
239242
CM_NODEBUG CM_NOINLINE extern "C" cl_vector<double, WIDTH> \
240-
__vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double, WIDTH> a) { \
243+
__vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double, WIDTH> a) { \
241244
vector<double, WIDTH> va{a}; \
242-
auto r = __impl_rsqrt_f64(va); \
245+
auto r = __impl_rsqrt_f64(va); \
243246
return r.cl_vector(); \
244247
}
245248

246-
FREM(1)
247-
FREM(2)
248-
FREM(4)
249-
FREM(8)
249+
RSQRT(1)
250+
RSQRT(2)
251+
RSQRT(4)
252+
RSQRT(8)
253+
RSQRT(16)
254+
RSQRT(32)

IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp

Lines changed: 27 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1255,8 +1255,9 @@ bool GenXPatternMatch::flipBoolNot(Instruction *Inst) {
12551255
bool GenXPatternMatch::matchInverseSqrt(CallInst *I) {
12561256
IGC_ASSERT(I && I->arg_size() == 1);
12571257

1258-
// Leave as it is for double types
1259-
if (I->getType()->getScalarType()->isDoubleTy())
1258+
// Double rsqrt may be generated only before legalization
1259+
if (I->getType()->getScalarType()->isDoubleTy() &&
1260+
(!ST->hasFP64() || Kind == PatternMatchKind::PostLegalization))
12601261
return false;
12611262

12621263
bool IsFast = true;
@@ -2293,9 +2294,9 @@ bool MinMaxMatcher::emit() {
22932294

22942295
// For a given instruction, find the insertion position which is the closest
22952296
// to all the similar users to the specified reference user.
2296-
static Instruction *findOptimalInsertionPos(
2297-
Instruction *I, Instruction *Ref, DominatorTree *DT,
2298-
std::function<bool(Instruction *, Instruction *)> IsDivisor) {
2297+
static Instruction *
2298+
findOptimalInsertionPos(Value *I, Instruction *Ref, DominatorTree *DT,
2299+
std::function<bool(Instruction *, Value *)> IsDivisor) {
22992300
IGC_ASSERT_MESSAGE(!isa<PHINode>(Ref), "PHINode is not expected!");
23002301

23012302
// Shortcut case. If it's single-used, insert just before that user.
@@ -2402,48 +2403,48 @@ void GenXPatternMatch::visitFDiv(BinaryOperator &I) {
24022403
return;
24032404
}
24042405

2405-
// Skip if FP64 emulation is required for this platform
2406-
if (ST->emulateFDivFSqrt64() && I.getType()->getScalarType()->isDoubleTy())
2407-
return;
2408-
24092406
Instruction *Divisor = dyn_cast<Instruction>(Op1);
2410-
if (!Divisor)
2411-
return;
24122407

2413-
auto IsDivisor = [](Instruction *I, Instruction *MaybeDivisor) {
2408+
auto IsDivisor = [](Instruction *I, Value *MaybeDivisor) {
24142409
return I->getOpcode() == Instruction::FDiv &&
24152410
I->getOperand(1) == MaybeDivisor;
24162411
};
24172412

2418-
Instruction *Pos = findOptimalInsertionPos(Divisor, &I, DT, IsDivisor);
2413+
Instruction *Pos = findOptimalInsertionPos(Op1, &I, DT, IsDivisor);
24192414
IRB.SetInsertPoint(Pos);
24202415

24212416
// (fdiv 1., (sqrt x)) -> (rsqrt x)
24222417
// Allow the pattern even if fdiv has no fast-math flags.
2423-
auto IID = vc::getAnyIntrinsicID(Divisor);
2424-
if ((IID == GenXIntrinsic::genx_sqrt ||
2425-
(IID == Intrinsic::sqrt && Divisor->hasApproxFunc())) &&
2426-
match(Op0, m_FPOne()) && Divisor->hasOneUse()) {
2427-
auto *Rsqrt = createInverseSqrt(Divisor->getOperand(0), Pos);
2428-
I.replaceAllUsesWith(Rsqrt);
2429-
I.eraseFromParent();
2430-
Divisor->eraseFromParent();
2418+
if (Divisor) {
2419+
auto IID = vc::getAnyIntrinsicID(Divisor);
2420+
if ((IID == GenXIntrinsic::genx_sqrt ||
2421+
(IID == Intrinsic::sqrt && Divisor->hasApproxFunc())) &&
2422+
match(Op0, m_FPOne()) && Divisor->hasOneUse()) {
2423+
auto *Rsqrt = createInverseSqrt(Divisor->getOperand(0), Pos);
2424+
I.replaceAllUsesWith(Rsqrt);
2425+
I.eraseFromParent();
2426+
Divisor->eraseFromParent();
2427+
2428+
Changed |= true;
2429+
return;
2430+
}
2431+
}
24312432

2432-
Changed |= true;
2433+
// Skip if FP64 emulation is required for this platform
2434+
if (ST->emulateFDivFSqrt64() && I.getType()->getScalarType()->isDoubleTy())
24332435
return;
2434-
}
24352436

24362437
// Skip if reciprocal optimization is not allowed.
24372438
if (!I.hasAllowReciprocal())
24382439
return;
24392440

2440-
auto Rcp = getReciprocal(IRB, Divisor);
2441+
auto *Rcp = getReciprocal(IRB, Op1);
24412442
cast<Instruction>(Rcp)->setDebugLoc(I.getDebugLoc());
24422443

2443-
for (auto UI = Divisor->user_begin(); UI != Divisor->user_end();) {
2444+
for (auto UI = Op1->user_begin(); UI != Op1->user_end();) {
24442445
auto *U = *UI++;
24452446
Instruction *UserInst = dyn_cast<Instruction>(U);
2446-
if (!UserInst || UserInst == Rcp || !IsDivisor(UserInst, Divisor))
2447+
if (!UserInst || UserInst == Rcp || !IsDivisor(UserInst, Op1))
24472448
continue;
24482449
Op0 = UserInst->getOperand(0);
24492450
Value *NewVal = Rcp;

IGC/VectorCompiler/test/PatternMatch/inverse_sqrt.ll

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@
66
;
77
;============================ end_copyright_notice =============================
88

9-
; RUN: %opt %use_old_pass_manager% -GenXPatternMatch -march=genx64 -mcpu=Gen9 -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
9+
; RUN: %opt %use_old_pass_manager% -GenXPatternMatch -march=genx64 -mcpu=Gen9 \
10+
; RUN: -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
1011

1112
; CHECK-LABEL: @test_inverse
1213
define <16 x float> @test_inverse(<16 x float> %val) {
@@ -56,10 +57,10 @@ define <16 x float> @test_inverse_not_fast(<16 x float> %src) {
5657
ret <16 x float> %inv
5758
}
5859

59-
; CHECK-LABEL: @test_not_inverse_double
60-
define <16 x double> @test_not_inverse_double(<16 x double> %val_double) {
60+
; CHECK-LABEL: @test_inverse_double
61+
define <16 x double> @test_inverse_double(<16 x double> %val_double) {
6162
%sqrt = call <16 x double> @llvm.sqrt.v16f64(<16 x double> %val_double)
62-
; CHECK: @llvm.genx.inv.v16f64(<16 x double> %sqrt)
63+
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
6364
%inv = call <16 x double> @llvm.genx.inv.v16f64(<16 x double> %sqrt)
6465
ret <16 x double> %inv
6566
}
@@ -240,6 +241,22 @@ define float @test_inv_sqrt_6(float %val) {
240241
ret float %sqrt
241242
}
242243

244+
; CHECK-LABEL: @test_inverse_double_2
245+
define <16 x double> @test_inverse_double_2(<16 x double> %val_double) {
246+
%sqrt = call <16 x double> @llvm.genx.sqrt.v16f64(<16 x double> %val_double)
247+
%div = fdiv <16 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %sqrt
248+
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
249+
ret <16 x double> %div
250+
}
251+
252+
; CHECK-LABEL: @test_inverse_double_3
253+
define <16 x double> @test_inverse_double_3(<16 x double> %val_double) {
254+
%div = fdiv arcp <16 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %val_double
255+
%sqrt = call <16 x double> @llvm.genx.sqrt.v16f64(<16 x double> %div)
256+
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
257+
ret <16 x double> %sqrt
258+
}
259+
243260
declare float @llvm.sqrt.f32(float)
244261
declare float @llvm.genx.sqrt.f32(float)
245262
declare float @llvm.genx.inv.f32(float)
@@ -248,5 +265,6 @@ declare <2 x float> @llvm.genx.inv.v2f32(<2 x float>)
248265
declare <16 x float> @llvm.sqrt.v16f32(<16 x float>)
249266
declare <16 x double> @llvm.sqrt.v16f64(<16 x double>)
250267
declare <16 x float> @llvm.genx.sqrt.v16f32(<16 x float>)
268+
declare <16 x double> @llvm.genx.sqrt.v16f64(<16 x double>)
251269
declare <16 x float> @llvm.genx.inv.v16f32(<16 x float>)
252270
declare <16 x double> @llvm.genx.inv.v16f64(<16 x double>)

0 commit comments

Comments
 (0)