Skip to content

Commit a8b8325

Browse files
committed
X86: make VBMI2 funnel shifts use VSHLD/VSHRD for const splats
Move constant splat handling for vector funnel shifts into a DAG combiner so that VBMI2 legal widths emit VSHLD/VSHRD directly (fixes #166949). Signed-off-by: Arnav Mehta <[email protected]>
1 parent 79c56e8 commit a8b8325

File tree

3 files changed

+215
-4
lines changed

3 files changed

+215
-4
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,8 +2073,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20732073

20742074
if (Subtarget.hasVBMI2()) {
20752075
for (auto VT : {MVT::v32i16, MVT::v16i32, MVT::v8i64}) {
2076-
setOperationAction(ISD::FSHL, VT, Custom);
2077-
setOperationAction(ISD::FSHR, VT, Custom);
2076+
setOperationAction(ISD::FSHL, VT, Legal);
2077+
setOperationAction(ISD::FSHR, VT, Legal);
20782078
}
20792079

20802080
setOperationAction(ISD::ROTL, MVT::v32i16, Custom);
@@ -2089,8 +2089,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20892089
if (!Subtarget.useSoftFloat() && Subtarget.hasVBMI2()) {
20902090
for (auto VT : {MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v16i16, MVT::v8i32,
20912091
MVT::v4i64}) {
2092-
setOperationAction(ISD::FSHL, VT, Custom);
2093-
setOperationAction(ISD::FSHR, VT, Custom);
2092+
setOperationAction(ISD::FSHL, VT, Subtarget.hasVLX() ? Legal : Custom);
2093+
setOperationAction(ISD::FSHR, VT, Subtarget.hasVLX() ? Legal : Custom);
20942094
}
20952095
}
20962096

@@ -2703,6 +2703,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
27032703
ISD::STRICT_FP_EXTEND,
27042704
ISD::FP_ROUND,
27052705
ISD::STRICT_FP_ROUND,
2706+
ISD::FSHL,
2707+
ISD::FSHR,
27062708
ISD::INTRINSIC_VOID,
27072709
ISD::INTRINSIC_WO_CHAIN,
27082710
ISD::INTRINSIC_W_CHAIN});
@@ -57624,6 +57626,49 @@ static SDValue combineFP_TO_xINT_SAT(SDNode *N, SelectionDAG &DAG,
5762457626
return SDValue();
5762557627
}
5762657628

57629+
// Combiner: turn uniform-constant splat funnel shifts into VSHLD/VSHRD
57630+
static SDValue combineFunnelShift(SDNode *N, SelectionDAG &DAG,
57631+
TargetLowering::DAGCombinerInfo &DCI,
57632+
const X86Subtarget &Subtarget) {
57633+
57634+
SDLoc DL(N);
57635+
SDValue Op0 = N->getOperand(0);
57636+
SDValue Op1 = N->getOperand(1);
57637+
SDValue Amt = N->getOperand(2);
57638+
EVT VT = Op0.getValueType();
57639+
57640+
if (!VT.isVector() || !Subtarget.hasVBMI2())
57641+
return SDValue();
57642+
57643+
// Only combine if the operation is legal for this type.
57644+
// This ensures we don't try to convert types that need to be
57645+
// widened/promoted.
57646+
if (!DAG.getTargetLoweringInfo().isOperationLegal(N->getOpcode(), VT))
57647+
return SDValue();
57648+
57649+
unsigned EltSize = VT.getScalarSizeInBits();
57650+
57651+
if (EltSize <= 8)
57652+
return SDValue();
57653+
57654+
APInt ShiftVal;
57655+
if (!X86::isConstantSplat(Amt, ShiftVal))
57656+
return SDValue();
57657+
57658+
uint64_t ModAmt = ShiftVal.urem(EltSize);
57659+
57660+
SDValue Imm = DAG.getTargetConstant(ModAmt, DL, MVT::i8);
57661+
57662+
bool IsFSHR = N->getOpcode() == ISD::FSHR;
57663+
57664+
if (IsFSHR)
57665+
std::swap(Op0, Op1);
57666+
57667+
unsigned Opcode = IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD;
57668+
57669+
return DAG.getNode(Opcode, DL, VT, {Op0, Op1, Imm});
57670+
}
57671+
5762757672
static bool needCarryOrOverflowFlag(SDValue Flags) {
5762857673
assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!");
5762957674

@@ -61228,6 +61273,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6122861273
case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
6122961274
case ISD::FP_TO_SINT_SAT:
6123061275
case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
61276+
case ISD::FSHL: return combineFunnelShift(N, DAG, DCI, Subtarget);
61277+
case ISD::FSHR: return combineFunnelShift(N, DAG, DCI, Subtarget);
6123161278
// clang-format on
6123261279
}
6123361280

llvm/unittests/Target/X86/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(LLVM_LINK_COMPONENTS
2222
)
2323

2424
add_llvm_unittest(X86Tests
25+
X86SelectionDAGTest.cpp
2526
MachineSizeOptsTest.cpp
2627
TernlogTest.cpp
2728
)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
//===- X86SelectionDAGTest.cpp - X86 Funnel Shift Combine Tests -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
10+
#include "llvm/AsmParser/Parser.h"
11+
#include "llvm/CodeGen/MachineModuleInfo.h"
12+
#include "llvm/CodeGen/SelectionDAG.h"
13+
#include "llvm/CodeGen/TargetLowering.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/MC/TargetRegistry.h"
16+
#include "llvm/Support/SourceMgr.h"
17+
#include "llvm/Support/TargetSelect.h"
18+
#include "llvm/Target/TargetMachine.h"
19+
#include "llvm/Target/TargetOptions.h"
20+
#include "gtest/gtest.h"
21+
22+
using namespace llvm;
23+
24+
namespace {
25+
26+
class X86FunnelShiftCombineTest : public testing::Test {
27+
protected:
28+
static void SetUpTestCase() {
29+
LLVMInitializeX86TargetInfo();
30+
LLVMInitializeX86Target();
31+
LLVMInitializeX86TargetMC();
32+
}
33+
34+
void SetUp() override {
35+
Triple TargetTriple("x86_64-unknown-unknown");
36+
std::string Error;
37+
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
38+
if (!T)
39+
GTEST_SKIP();
40+
41+
TargetOptions Options;
42+
// Enable VBMI2 to test funnel shift combines
43+
TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
44+
TargetTriple, "", "+avx512f,+avx512vbmi2", Options, std::nullopt,
45+
std::nullopt, CodeGenOptLevel::Default));
46+
if (!TM)
47+
GTEST_SKIP();
48+
49+
StringRef Assembly = "define void @test() { ret void }";
50+
SMDiagnostic SMError;
51+
M = parseAssemblyString(Assembly, SMError, Context);
52+
ASSERT_TRUE(M && "Could not parse module!");
53+
M->setDataLayout(TM->createDataLayout());
54+
55+
F = M->getFunction("test");
56+
ASSERT_TRUE(F && "Could not get function test!");
57+
58+
MachineModuleInfo MMI(TM.get());
59+
MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
60+
MMI.getContext(), 0);
61+
62+
DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::Default);
63+
ASSERT_TRUE(DAG && "Failed to create SelectionDAG!");
64+
OptimizationRemarkEmitter ORE(F);
65+
DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, MMI,
66+
nullptr);
67+
}
68+
69+
LLVMContext Context;
70+
std::unique_ptr<TargetMachine> TM;
71+
std::unique_ptr<Module> M;
72+
Function *F;
73+
std::unique_ptr<MachineFunction> MF;
74+
std::unique_ptr<SelectionDAG> DAG;
75+
};
76+
77+
// Test that v16i32 is legal for VBMI2 (should be combined)
78+
TEST_F(X86FunnelShiftCombineTest, TestFSHLv16i32Legal) {
79+
MVT VT = MVT::v16i32;
80+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
81+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHL, VT));
82+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHR, VT));
83+
}
84+
85+
// Test that v8i64 is legal for VBMI2 (should be combined)
86+
TEST_F(X86FunnelShiftCombineTest, TestFSHRv8i64Legal) {
87+
MVT VT = MVT::v8i64;
88+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
89+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHL, VT));
90+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHR, VT));
91+
}
92+
93+
// Test that v2i32 is NOT legal for VBMI2 (should NOT be combined)
94+
TEST_F(X86FunnelShiftCombineTest, TestFSHLv2i32NonLegal) {
95+
MVT VT = MVT::v2i32;
96+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
97+
EXPECT_FALSE(TLI.isOperationLegal(ISD::FSHL, VT));
98+
EXPECT_FALSE(TLI.isOperationLegal(ISD::FSHR, VT));
99+
}
100+
101+
// Test that v32i16 is legal for VBMI2 (should be combined)
102+
TEST_F(X86FunnelShiftCombineTest, TestFSHLv32i16Legal) {
103+
MVT VT = MVT::v32i16;
104+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
105+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHL, VT));
106+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHR, VT));
107+
}
108+
109+
// Test that v8i16 with VLX is legal
110+
TEST_F(X86FunnelShiftCombineTest, TestFSHLv8i16WithVLX) {
111+
Triple TargetTriple("x86_64-unknown-unknown");
112+
std::string Error;
113+
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
114+
ASSERT_TRUE(T);
115+
116+
TargetOptions Options;
117+
TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
118+
TargetTriple, "", "+avx512f,+avx512vbmi2,+avx512vl", Options,
119+
std::nullopt, std::nullopt, CodeGenOptLevel::Default));
120+
ASSERT_TRUE(TM);
121+
122+
MachineModuleInfo MMI(TM.get());
123+
MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
124+
MMI.getContext(), 0);
125+
DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::Default);
126+
OptimizationRemarkEmitter ORE(F);
127+
DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, MMI,
128+
nullptr);
129+
130+
MVT VT = MVT::v8i16;
131+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
132+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHL, VT));
133+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHR, VT));
134+
}
135+
136+
// Test that v4i32 with VLX is legal
137+
TEST_F(X86FunnelShiftCombineTest, TestFSHLv4i32WithVLX) {
138+
Triple TargetTriple("x86_64-unknown-unknown");
139+
std::string Error;
140+
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
141+
ASSERT_TRUE(T);
142+
143+
TargetOptions Options;
144+
TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
145+
TargetTriple, "", "+avx512f,+avx512vbmi2,+avx512vl", Options,
146+
std::nullopt, std::nullopt, CodeGenOptLevel::Default));
147+
ASSERT_TRUE(TM);
148+
149+
MachineModuleInfo MMI(TM.get());
150+
MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
151+
MMI.getContext(), 0);
152+
DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::Default);
153+
OptimizationRemarkEmitter ORE(F);
154+
DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, MMI,
155+
nullptr);
156+
157+
MVT VT = MVT::v4i32;
158+
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
159+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHL, VT));
160+
EXPECT_TRUE(TLI.isOperationLegal(ISD::FSHR, VT));
161+
}
162+
163+
} // namespace

0 commit comments

Comments
 (0)