From b0be55c61fba57cf4ad883b65636d6d68d407a99 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Tue, 25 Feb 2025 23:44:27 -0500 Subject: [PATCH] [DirectX] initialize registers properties by calling addRegisterClass and computeRegisterProperties This fixes #126784 for the DirectX backend. This bug was marked critical for DX so it is going to go in first. At least one register class needs to be added via addRegisterClass for RegClassForVT to be valid. Further for costing information used by loop unroll and other optimizations to be valid we need to call computeRegisterProperties. This change does both of these. The test cases confirm that we can fetch costing information off of `getRegisterInfo` and that `DirectXTargetLowering` maps i32 typed registers to DXILClassRegClass. --- .../Target/DirectX/DirectXTargetMachine.cpp | 5 +- llvm/unittests/Target/DirectX/CMakeLists.txt | 1 + .../Target/DirectX/RegisterCostTests.cpp | 65 +++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 llvm/unittests/Target/DirectX/RegisterCostTests.cpp diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index a76c07f784177..dda650b0f6e15 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -187,4 +187,7 @@ DirectXTargetMachine::getTargetTransformInfo(const Function &F) const { DirectXTargetLowering::DirectXTargetLowering(const DirectXTargetMachine &TM, const DirectXSubtarget &STI) - : TargetLowering(TM) {} + : TargetLowering(TM) { + addRegisterClass(MVT::i32, &dxil::DXILClassRegClass); + computeRegisterProperties(STI.getRegisterInfo()); +} diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt index fd0d5a0dd52c1..5087727ff9800 100644 --- a/llvm/unittests/Target/DirectX/CMakeLists.txt +++ b/llvm/unittests/Target/DirectX/CMakeLists.txt @@ -16,4 +16,5 @@ add_llvm_target_unittest(DirectXTests CBufferDataLayoutTests.cpp PointerTypeAnalysisTests.cpp UniqueResourceFromUseTests.cpp + RegisterCostTests.cpp ) diff --git a/llvm/unittests/Target/DirectX/RegisterCostTests.cpp b/llvm/unittests/Target/DirectX/RegisterCostTests.cpp new file mode 100644 index 0000000000000..ebf740ed73e98 --- /dev/null +++ b/llvm/unittests/Target/DirectX/RegisterCostTests.cpp @@ -0,0 +1,65 @@ +//===- llvm/unittests/Target/DirectX/RegisterCostTests.cpp ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "DirectXInstrInfo.h" +#include "DirectXTargetLowering.h" +#include "DirectXTargetMachine.h" +#include "TargetInfo/DirectXTargetInfo.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/MC/MCTargetOptions.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" + +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::dxil; + +namespace { +class RegisterCostTests : public testing::Test { +protected: + DirectXInstrInfo DXInstInfo; + DirectXRegisterInfo RI; + DirectXTargetLowering *DL; + + virtual void SetUp() { + LLVMInitializeDirectXTargetMC(); + Target T = getTheDirectXTarget(); + RegisterTargetMachine X(T); + Triple TT("dxil-pc-shadermodel6.3-library"); + StringRef CPU = ""; + StringRef FS = ""; + DirectXTargetMachine TM(T, TT, CPU, FS, TargetOptions(), Reloc::Static, + CodeModel::Small, CodeGenOptLevel::Default, false); + LLVMContext Context; + Function *F = + Function::Create(FunctionType::get(Type::getVoidTy(Context), false), + Function::ExternalLinkage, 0); + DL = new DirectXTargetLowering(TM, *TM.getSubtargetImpl(*F)); + delete F; + } + virtual void TearDown() { delete DL; } +}; + +TEST_F(RegisterCostTests, TestRepRegClassForVTSet) { + const TargetRegisterClass *RC = DL->getRepRegClassFor(MVT::i32); + EXPECT_EQ(&dxil::DXILClassRegClass, RC); +} + +TEST_F(RegisterCostTests, TestTrivialCopyCostGetter) { + + const TargetRegisterClass *RC = DXInstInfo.getRegisterInfo().getRegClass(0); + unsigned Cost = RC->getCopyCost(); + EXPECT_EQ(1u, Cost); + + RC = RI.getRegClass(0); + Cost = RC->getCopyCost(); + EXPECT_EQ(1u, Cost); +} +} // namespace