Skip to content

Commit cc0aa07

Browse files
vmustyaigcbot
authored and committed
Support copysign intrinsic in VC
.
1 parent 0a90ec2 commit cc0aa07

File tree

6 files changed

+155
-1
lines changed

6 files changed

+155
-1
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ class GenXLowering : public FunctionPass {
276276
bool lowerReduction(CallInst *CI, Instruction::BinaryOps Opcode);
277277
bool lowerReduction(CallInst *CI, Intrinsic::ID);
278278

279+
bool lowerCopySign(CallInst *CI);
280+
279281
bool generatePredicatedWrrForNewLoad(CallInst *CI);
280282
};
281283

@@ -2157,6 +2159,8 @@ bool GenXLowering::processInst(Instruction *Inst) {
21572159
return lowerReduction(CI, Intrinsic::maxnum);
21582160
case Intrinsic::vector_reduce_fmin:
21592161
return lowerReduction(CI, Intrinsic::minnum);
2162+
case Intrinsic::copysign:
2163+
return lowerCopySign(CI);
21602164
case GenXIntrinsic::genx_get_hwid:
21612165
return lowerHardwareThreadID(CI);
21622166
case vc::InternalIntrinsic::logical_thread_id:
@@ -5026,6 +5030,66 @@ bool GenXLowering::lowerReduction(CallInst *CI, Intrinsic::ID IID) {
50265030
});
50275031
}
50285032

5033+
/***********************************************************************
 * lowerCopySign : lower llvm.copysign into integer bitwise operations
 *
 * Both operands are bitcast to vectors of i16 words. For every element,
 * the word at byte offset (Stride - 1) * WordBytes is masked with 0x7FFF
 * (magnitude operand) and with 0x8000 (sign operand), and the two masked
 * values are OR-ed. When an element spans more than one word, that word
 * is read with rdregion and the combined value is written back over the
 * magnitude words with wrregion; otherwise the words are used directly.
 *
 * The call is replaced by the result and queued in ToErase.
 * Always returns true.
 */
bool GenXLowering::lowerCopySign(CallInst *CI) {
  auto *ResTy = CI->getType();

  // Stride is the number of i16 words that make up one element.
  auto ElementSize = ResTy->getScalarType()->getPrimitiveSizeInBits();
  auto Stride = ElementSize / genx::WordBits;
  IGC_ASSERT(ElementSize % genx::WordBits == 0);
  IGC_ASSERT(Stride == 1 || Stride == 2 || Stride == 4);

  // A scalar call is handled as a single-element vector.
  int NumElements = 1;
  if (auto *VecTy = dyn_cast<IGCLLVM::FixedVectorType>(ResTy))
    NumElements = VecTy->getNumElements();

  IRBuilder<> Builder(CI);
  auto *WordTy = Builder.getInt16Ty();
  auto *AllWordsTy =
      IGCLLVM::FixedVectorType::get(WordTy, NumElements * Stride);
  auto *PickedWordsTy = IGCLLVM::FixedVectorType::get(WordTy, NumElements);

  // Reinterpret both operands as raw words. The magnitude words are also
  // the base value for the write-back below.
  Value *MagWords = Builder.CreateBitCast(CI->getOperand(0), AllWordsTy);
  Value *SignWords = Builder.CreateBitCast(CI->getOperand(1), AllWordsTy);

  vc::CMRegion R(PickedWordsTy, DL);
  auto &Loc = CI->getDebugLoc();
  const bool IsMultiWord = Stride > 1;

  Value *MagPicked = MagWords;
  Value *SignPicked = SignWords;
  if (IsMultiWord) {
    // One word per element, starting at the last word of the first element.
    R.VStride = Stride;
    R.Width = 1;
    R.Stride = 0;
    R.Offset = (Stride - 1) * genx::WordBytes;

    MagPicked = R.createRdRegion(MagWords, "", CI, Loc);
    SignPicked = R.createRdRegion(SignWords, "", CI, Loc);
  }

  // AND a word vector with a splat of the given 16-bit mask.
  auto MaskWords = [&](Value *Words, ConstantInt *Mask) {
    auto *Splat = Builder.CreateVectorSplat(NumElements, Mask);
    return Builder.CreateAnd(Words, Splat);
  };

  auto *KeepMagnitude = ConstantInt::get(WordTy, 0x7FFF);
  auto *KeepSign = ConstantInt::get(WordTy, 0x8000);

  auto *MagBits = MaskWords(MagPicked, KeepMagnitude);
  auto *SignBits = MaskWords(SignPicked, KeepSign);

  Value *Lowered = Builder.CreateOr(MagBits, SignBits);

  // Put the patched words back among the untouched magnitude words.
  if (IsMultiWord)
    Lowered = R.createWrRegion(MagWords, Lowered, "", CI, Loc);

  Lowered = Builder.CreateBitCast(Lowered, ResTy);

  Lowered->takeName(CI);
  CI->replaceAllUsesWith(Lowered);
  ToErase.push_back(CI);

  return true;
}
5092+
50295093
/***********************************************************************
50305094
* widenByteOp : widen a vector byte operation to short if that might
50315095
* improve code

IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ class BfnMatcher {
505505
return std::nullopt;
506506
}
507507

508-
auto *CV = dyn_cast<ConstantVector>(C);
508+
auto *CV = dyn_cast<ConstantDataVector>(C);
509509
if (!CV)
510510
return std::nullopt;
511511

IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXTranslateSPIRVBuiltins.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ Value *SPIRVExpander::visitCallInst(CallInst &CI) {
261261
.StartsWith("popcount", Intrinsic::ctpop)
262262
.StartsWith("s_abs", GenXIntrinsic::genx_absi)
263263
// Floating-point intrinsics
264+
.StartsWith("copysign", Intrinsic::copysign)
264265
.StartsWith("fabs", Intrinsic::fabs)
265266
.StartsWith("fmax", Intrinsic::maxnum)
266267
.StartsWith("fma", Intrinsic::fma)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2024 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
; RUN: %opt %use_old_pass_manager% -GenXLowering -march=genx64 -mcpu=XeHPC -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
10+
11+
declare <4 x half> @llvm.copysign.v4f16(<4 x half>, <4 x half>)
12+
declare <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat>, <4 x bfloat>)
13+
declare <4 x float> @llvm.copysign.v4f32(<4 x float>, <4 x float>)
14+
declare <4 x double> @llvm.copysign.v4f64(<4 x double>, <4 x double>)
15+
16+
; CHECK-LABEL: @test_v4f16
17+
; copysign on <4 x half>: the operands are bitcast to <4 x i16>, so the
; checked sequence is bitcast, and, and, or, bitcast back to <4 x half>.
; No region intrinsics are listed for this element width.
define <4 x half> @test_v4f16(<4 x half> %src, <4 x half> %sign) {
18+
; CHECK: [[MAG:%.*]] = bitcast <4 x half> %src to <4 x i16>
19+
; CHECK: [[SGN:%.*]] = bitcast <4 x half> %sign to <4 x i16>
20+
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], <i16 32767, i16 32767, i16 32767, i16 32767>
21+
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
22+
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
23+
; CHECK: [[RES_HALF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x half>
24+
; CHECK: ret <4 x half> [[RES_HALF]]
25+
%res = call <4 x half> @llvm.copysign.v4f16(<4 x half> %src, <4 x half> %sign)
26+
ret <4 x half> %res
27+
}
28+
29+
; CHECK-LABEL: @test_v4bf16
30+
; copysign on <4 x bfloat>: same expected shape as the half case, with the
; operands bitcast to <4 x i16>, masked with 32767 / -32768, OR-ed and
; bitcast back to <4 x bfloat>.
define <4 x bfloat> @test_v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign) {
31+
; CHECK: [[MAG:%.*]] = bitcast <4 x bfloat> %src to <4 x i16>
32+
; CHECK: [[SGN:%.*]] = bitcast <4 x bfloat> %sign to <4 x i16>
33+
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], <i16 32767, i16 32767, i16 32767, i16 32767>
34+
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
35+
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
36+
; CHECK: [[RES_BF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x bfloat>
37+
; CHECK: ret <4 x bfloat> [[RES_BF]]
38+
%res = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign)
39+
ret <4 x bfloat> %res
40+
}
41+
42+
; CHECK-LABEL: @test_v4f32
43+
; copysign on <4 x float>: the operands are bitcast to <8 x i16> (two words
; per element). One word per element is read with rdregioni (vstride 2,
; width 1, stride 0, byte offset 2), masked and OR-ed as <4 x i16>, then
; written back over the magnitude words with wrregioni using the same
; region parameters before the final bitcast to <4 x float>.
define <4 x float> @test_v4f32(<4 x float> %src, <4 x float> %sign) {
44+
; CHECK: [[MAG:%.*]] = bitcast <4 x float> %src to <8 x i16>
45+
; CHECK: [[SGN:%.*]] = bitcast <4 x float> %sign to <8 x i16>
46+
; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[MAG]], i32 2, i32 1, i32 0, i16 2, i32 undef)
47+
; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[SGN]], i32 2, i32 1, i32 0, i16 2, i32 undef)
48+
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], <i16 32767, i16 32767, i16 32767, i16 32767>
49+
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
50+
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
51+
; CHECK: [[RES_INSERT:%.*]] = call <8 x i16> @llvm.genx.wrregioni.v8i16.v4i16.i16.i1(<8 x i16> [[MAG]], <4 x i16> [[RES]], i32 2, i32 1, i32 0, i16 2, i32 undef, i1 true)
52+
; CHECK: [[RES_FLOAT:%.*]] = bitcast <8 x i16> [[RES_INSERT]] to <4 x float>
53+
; CHECK: ret <4 x float> [[RES_FLOAT]]
54+
%res = call <4 x float> @llvm.copysign.v4f32(<4 x float> %src, <4 x float> %sign)
55+
ret <4 x float> %res
56+
}
57+
58+
; CHECK-LABEL: @test_v4f64
59+
; copysign on <4 x double>: the operands are bitcast to <16 x i16> (four
; words per element). One word per element is read with rdregioni
; (vstride 4, width 1, stride 0, byte offset 6), masked and OR-ed as
; <4 x i16>, then written back over the magnitude words with wrregioni
; using the same region parameters before the final bitcast to <4 x double>.
define <4 x double> @test_v4f64(<4 x double> %src, <4 x double> %sign) {
60+
; CHECK: [[MAG:%.*]] = bitcast <4 x double> %src to <16 x i16>
61+
; CHECK: [[SGN:%.*]] = bitcast <4 x double> %sign to <16 x i16>
62+
; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[MAG]], i32 4, i32 1, i32 0, i16 6, i32 undef)
63+
; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[SGN]], i32 4, i32 1, i32 0, i16 6, i32 undef)
64+
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], <i16 32767, i16 32767, i16 32767, i16 32767>
65+
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
66+
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
67+
; CHECK: [[RES_INSERT:%.*]] = call <16 x i16> @llvm.genx.wrregioni.v16i16.v4i16.i16.i1(<16 x i16> [[MAG]], <4 x i16> [[RES]], i32 4, i32 1, i32 0, i16 6, i32 undef, i1 true)
68+
; CHECK: [[RES_DOUBLE:%.*]] = bitcast <16 x i16> [[RES_INSERT]] to <4 x double>
69+
; CHECK: ret <4 x double> [[RES_DOUBLE]]
70+
%res = call <4 x double> @llvm.copysign.v4f64(<4 x double> %src, <4 x double> %sign)
71+
ret <4 x double> %res
72+
}

IGC/VectorCompiler/test/PatternMatch/bfn_match.ll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,12 @@ define i32 @test_unmatch_flag(<32 x i1> %a, <32 x i1> %b, <32 x i1> %c) {
132132

133133
ret i32 %2
134134
}
135+
136+
; CHECK-LABEL: @test_match_combine_by_mask_inv_vector(
137+
; (%mag & 32767) | (%sgn & -32768) with vector constant masks: the three
; instructions are expected to be replaced by one genx.bfn call that takes
; %mag, %sgn and the 32767 splat as its sources.
define <4 x i16> @test_match_combine_by_mask_inv_vector(<4 x i16> %mag, <4 x i16> %sgn) {
138+
; CHECK: %res = call <4 x i16> @llvm.genx.bfn.v4i16.v4i16(<4 x i16> %mag, <4 x i16> %sgn, <4 x i16> <i16 32767, i16 32767, i16 32767, i16 32767>, i8 -84)
139+
%abs = and <4 x i16> %mag, <i16 32767, i16 32767, i16 32767, i16 32767>
140+
%sign = and <4 x i16> %sgn, <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
141+
%res = or <4 x i16> %abs, %sign
142+
ret <4 x i16> %res
143+
}

IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ declare spir_func <16 x float> @_Z15__spirv_ocl_madDv16_fS_S_(<16 x float>, <16
3838
declare spir_func <7 x double> @_Z15__spirv_ocl_fmaDv7_dS_S_(<7 x double>, <7 x double>, <7 x double>)
3939
declare spir_func <7 x double> @_Z15__spirv_ocl_fmaxDv7_dS_S_(<7 x double>, <7 x double>)
4040
declare spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double>)
41+
declare spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double>, <7 x double>)
4142

4243
define spir_func i32 @popcount(i32 %arg) {
4344
; CHECK-LABEL: @popcount
@@ -229,3 +230,10 @@ define spir_func <16 x double> @abs_vector(<16 x double> %arg) {
229230
%res = call spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double> %arg)
230231
ret <16 x double> %res
231232
}
233+
234+
; The call to the mangled __spirv_ocl_copysign builtin on <7 x double> is
; expected to become a call to llvm.copysign.v7f64 with the same operands.
define spir_func <7 x double> @copysign_vector(<7 x double> %arg1, <7 x double> %arg2) {
235+
; CHECK-LABEL: @copysign_vector
236+
; CHECK: %res = call <7 x double> @llvm.copysign.v7f64(<7 x double> %arg1, <7 x double> %arg2)
237+
%res = call spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double> %arg1, <7 x double> %arg2)
238+
ret <7 x double> %res
239+
}

0 commit comments

Comments
 (0)