Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21340,6 +21340,15 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known = Known.sext(BitWidth);
break;
}
case RISCVISD::SRLW: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::lshr(Known.trunc(32), Known2.trunc(5).zext(32));
// Restore the original width by sign extending.
Known = Known.sext(BitWidth);
break;
}
case RISCVISD::CTZW: {
KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros();
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Target/RISCV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ set(LLVM_LINK_COMPONENTS
add_llvm_target_unittest(RISCVTests
MCInstrAnalysisTest.cpp
RISCVInstrInfoTest.cpp
RISCVSelectionDAGTest.cpp
)
110 changes: 110 additions & 0 deletions llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//===----------------------------------------------------------------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RISCVISelLowering.h"
#include "RISCVSelectionDAGInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "gtest/gtest.h"

namespace llvm {

class RISCVSelectionDAGTest : public testing::Test {

protected:
static void SetUpTestCase() {
LLVMInitializeRISCVTargetInfo();
LLVMInitializeRISCVTarget();
LLVMInitializeRISCVTargetMC();
}

void SetUp() override {
StringRef Assembly = "define void @f() { ret void }";

Triple TargetTriple("riscv64", "unknown", "linux");

std::string Error;
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);

TargetOptions Options;
TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
TargetTriple, "generic", "", Options, std::nullopt, std::nullopt,
CodeGenOptLevel::Default));

SMDiagnostic SMError;
M = parseAssemblyString(Assembly, SMError, Context);
if (!M)
report_fatal_error(SMError.getMessage());
M->setDataLayout(TM->createDataLayout());

F = M->getFunction("f");
if (!F)
report_fatal_error("Function 'f' not found");

MachineModuleInfo MMI(TM.get());

MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
MMI.getContext(), /*FunctionNum*/ 0);

DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
if (!DAG)
report_fatal_error("SelectionDAG allocation failed");

OptimizationRemarkEmitter ORE(F);
DAG->init(*MF, ORE, /*LibInfo*/ nullptr, /*AA*/ nullptr,
/*AC*/ nullptr, /*MDT*/ nullptr, /*MSDT*/ nullptr, MMI, nullptr);
}

LLVMContext Context;
std::unique_ptr<TargetMachine> TM;
std::unique_ptr<Module> M;
Function *F = nullptr;
std::unique_ptr<MachineFunction> MF;
std::unique_ptr<SelectionDAG> DAG;
};

/// SRLW: Logical Shift Right
TEST_F(RISCVSelectionDAGTest, computeKnownBits_SRLW) {
// Following DAG is created from this IR snippet:
//
// define i64 @f(i32 %x, i32 %y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could also add a LIT codegen test instead of unittest using this LLVM IR as input. Personally I thought that would be more succinct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to write a LIT test, but couldn't get my head around it. My rationale to write a unittest was that I am adding a small portion of code for Known Bits Analysis of the SRLW, so I have to check whether computeKnownBits for SRLW returns appropriately. Is there a way to test for the values of known bits through the llvm IR, so a lit test can be written? As I understand, a lit test would be able to check if the SRLW instruction has been generated or not (and in the expected format). This, I believe already happens in the RISCV codegen. Can you show me an example to understand this clearly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't test the KnownBits directly, but you can test that it enables optimizations. Here's the recent PR I wrote for sign bits of SRAW. #155564

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, would you recommend I work on a lit test too? What about the current unittest (drop it/keep it)?

// %a = and i32 %x, 2147483647 ; zeros the MSB for %x
// %b = lshr i32 %a, %y
// %c = zext i32 %b to i64 ; makes the most significant 32 bits 0
// ret i64 %c
// }
SDLoc Loc;
auto IntVT = EVT::getIntegerVT(Context, 32);
auto Int64VT = EVT::getIntegerVT(Context, 64);
auto Px = DAG->getRegister(0, IntVT);
auto Py = DAG->getConstant(2147483647, Loc, IntVT);
auto N1 = DAG->getNode(ISD::AND, Loc, IntVT, Px, Py);
auto Qx = DAG->getRegister(0, Int64VT);
auto N2 = DAG->getNode(RISCVISD::SRLW, Loc, Int64VT, N1, Qx);
auto N3 = DAG->getNode(ISD::ZERO_EXTEND, Loc, Int64VT, N2);
// N1 = 0???????????????????????????????
// N2 = 0???????????????????????????????
// N3 = 000000000000000000000000000000000???????????????????????????????
// After zero extend, we expect 33 most significant zeros to be known:
// 32 from sign extension and 1 from AND operation
KnownBits Known = DAG->computeKnownBits(N3);
EXPECT_EQ(Known.Zero, APInt(64, -2147483648));
EXPECT_EQ(Known.One, APInt(64, 0));
}

} // end namespace llvm