From 162e8b4fd5be4571c15b76d240287703593274ec Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 27 Feb 2025 18:04:32 -0500 Subject: [PATCH 1/7] [DirectX] Working i8 legalization pass --- llvm/lib/Target/DirectX/CMakeLists.txt | 1 + llvm/lib/Target/DirectX/DirectX.h | 6 + .../Target/DirectX/DirectXPassRegistry.def | 1 + .../Target/DirectX/DirectXTargetMachine.cpp | 3 + llvm/lib/Target/DirectX/LegalizeI8Pass.cpp | 127 ++++++++++++++++++ llvm/lib/Target/DirectX/LegalizeI8Pass.h | 23 ++++ llvm/test/CodeGen/DirectX/legalize-i8.ll | 16 +++ 7 files changed, 177 insertions(+) create mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp create mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.h create mode 100644 llvm/test/CodeGen/DirectX/legalize-i8.ll diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index 6904a1c0f1e73..0b3b6a23ce739 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -32,6 +32,7 @@ add_llvm_target(DirectXCodeGen DXILShaderFlags.cpp DXILTranslateMetadata.cpp DXILRootSignature.cpp + LegalizeI8Pass.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index 42aa0da16e8aa..482f3b5a8f694 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -47,6 +47,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &); /// Pass to flatten arrays into a one dimensional DXIL legal form ModulePass *createDXILFlattenArraysLegacyPass(); +/// Initializer I8 legalizationPass +void initializeLegalizeI8LegacyPass(PassRegistry &); + +/// Pass to remove i8 truncations +FunctionPass *createLegalizeI8LegacyPass(); + /// Initializer for DXILOpLowering void initializeDXILOpLoweringLegacyPass(PassRegistry &); diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def index aee0a4ff83d43..297c6c10f68a3 100644 --- 
a/llvm/lib/Target/DirectX/DirectXPassRegistry.def +++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def @@ -38,4 +38,5 @@ MODULE_PASS("print", dxil::RootSignatureAnalysisPrinter(dbg #define FUNCTION_PASS(NAME, CREATE_PASS) #endif FUNCTION_PASS("dxil-resource-access", DXILResourceAccess()) +FUNCTION_PASS("dxil-legalize-i8", LegalizeI8Pass()) #undef FUNCTION_PASS diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index 82dc1c6af562a..ec6b69089fb64 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -25,6 +25,7 @@ #include "DirectX.h" #include "DirectXSubtarget.h" #include "DirectXTargetTransformInfo.h" +#include "LegalizeI8Pass.h" #include "TargetInfo/DirectXTargetInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/Passes.h" @@ -52,6 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { initializeDXILDataScalarizationLegacyPass(*PR); initializeDXILFlattenArraysLegacyPass(*PR); initializeScalarizerLegacyPassPass(*PR); + initializeLegalizeI8LegacyPass(*PR); initializeDXILPrepareModulePass(*PR); initializeEmbedDXILPassPass(*PR); initializeWriteDXILPassPass(*PR); @@ -100,6 +102,7 @@ class DirectXPassConfig : public TargetPassConfig { DxilScalarOptions.ScalarizeLoadStore = true; addPass(createScalarizerPass(DxilScalarOptions)); addPass(createDXILTranslateMetadataLegacyPass()); + addPass(createLegalizeI8LegacyPass()); addPass(createDXILOpLoweringLegacyPass()); addPass(createDXILPrepareModulePass()); } diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp new file mode 100644 index 0000000000000..8faca2be8048c --- /dev/null +++ b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp @@ -0,0 +1,127 @@ +//===- LegalizeI8Pass.cpp - A pass that reverts i8 conversions-*- C++ ---*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// +/// +/// \file This file contains a pass to remove i8 truncations. +/// +//===----------------------------------------------------------------------===// +#include "DirectX.h" +#include "LegalizeI8Pass.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include +#include + +#define DEBUG_TYPE "dxil-legalize-i8" + +using namespace llvm; +namespace { + +class LegalizeI8Legacy : public FunctionPass { + +public: + bool runOnFunction(Function &F) override; + LegalizeI8Legacy() : FunctionPass(ID) {} + + static char ID; // Pass identification. +}; +} // namespace + +static bool fixI8TruncUseChain(Function &F) { + std::stack ToRemove; + std::map ReplacedValues; + + for (auto &I : instructions(F)) { + if (auto *Trunc = dyn_cast(&I)) { + if (Trunc->getDestTy()->isIntegerTy(8)) { + ReplacedValues[Trunc] = Trunc->getOperand(0); + ToRemove.push(Trunc); + } + } else if (I.getType()->isIntegerTy(8)) { + IRBuilder<> Builder(&I); + + std::vector NewOperands; + Type* InstrType = nullptr; + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { + Value *Op = I.getOperand(OpIdx); + if (ReplacedValues.count(Op)) { + InstrType = ReplacedValues[Op]->getType(); + NewOperands.push_back(ReplacedValues[Op]); + } + else if (auto *Imm = dyn_cast(Op)) { + APInt Value = Imm->getValue(); + unsigned NewBitWidth = InstrType->getIntegerBitWidth(); + // Note: options here are sext or sextOrTrunc. + // Since i8 isn't suppport we assume new values + // will always have a higher bitness. 
+ APInt NewValue = Value.sext(NewBitWidth); + NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); + } else { + assert(!Op->getType()->isIntegerTy(8)); + NewOperands.push_back(Op); + } + + } + + Value *NewInst = nullptr; + if (auto *BO = dyn_cast(&I)) + NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); + else if (auto *Cmp = dyn_cast(&I)) + NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); + else if (auto *Cast = dyn_cast(&I)) + NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], Cast->getDestTy()); + else if (auto *UnaryOp = dyn_cast(&I)) + NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]); + + if (NewInst) { + ReplacedValues[&I] = NewInst; + ToRemove.push(&I); + } + } else if (auto *Sext = dyn_cast(&I)) { + if (Sext->getSrcTy()->isIntegerTy(8)) { + ToRemove.push(Sext); + Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]); + } + } + } + + while (!ToRemove.empty()) { + Instruction *I = ToRemove.top(); + I->eraseFromParent(); + ToRemove.pop(); + } + + return true; +} + +PreservedAnalyses LegalizeI8Pass::run(Function &F, FunctionAnalysisManager &FAM) { + bool MadeChanges = fixI8TruncUseChain(F); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + return PA; + } + + bool LegalizeI8Legacy::runOnFunction(Function &F) { + return fixI8TruncUseChain(F); + } + + char LegalizeI8Legacy::ID = 0; + + INITIALIZE_PASS_BEGIN(LegalizeI8Legacy, DEBUG_TYPE, + "DXIL I8 Legalizer", false, false) + INITIALIZE_PASS_END(LegalizeI8Legacy, DEBUG_TYPE, "DXIL I8 Legalizer", + false, false) + +FunctionPass *llvm::createLegalizeI8LegacyPass() { + return new LegalizeI8Legacy(); + } \ No newline at end of file diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.h b/llvm/lib/Target/DirectX/LegalizeI8Pass.h new file mode 100644 index 0000000000000..30ba4a88b8176 --- /dev/null +++ b/llvm/lib/Target/DirectX/LegalizeI8Pass.h @@ -0,0 +1,23 @@ +//===- 
LegalizeI8Pass.h - A pass that reverts i8 conversions-*- C++ -----*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef LLVM_TARGET_DIRECTX_LEGALIZEI8_H +#define LLVM_TARGET_DIRECTX_LEGALIZEI8_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +/// A pass that transforms multidimensional arrays into one-dimensional arrays. +class LegalizeI8Pass : public PassInfoMixin { +public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); +}; +} // namespace llvm + +#endif // LLVM_TARGET_DIRECTX_LEGALIZEI8_H diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll new file mode 100644 index 0000000000000..d9787531ae1c4 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -0,0 +1,16 @@ +; RUN: opt -S -passes='dxil-legalize-i8' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +define i32 @i8trunc(float %0) #0 { + ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8 + ; CHECK: add i32 + ; CHECK: srem i32 + ; CHECK-NOT: %7 = sext i8 %6 to i32 + + %2 = fptosi float %0 to i32 + %3 = srem i32 %2, 8 + %4 = trunc nsw i32 %3 to i8 + %5 = add nsw i8 %4, 1 + %6 = srem i8 %5, 8 + %7 = sext i8 %6 to i32 + ret i32 %7 +} \ No newline at end of file From 679ac35887e0aa69e7a62964a1b0979fdac3d838 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 13 Mar 2025 16:54:16 -0400 Subject: [PATCH 2/7] modify the i8 legalization pass to be a more generic legalization so we can reduce i64 insert/extracts to i32 --- llvm/lib/Target/DirectX/CMakeLists.txt | 2 +- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 201 ++++++++++++++++++ .../{LegalizeI8Pass.h => DXILLegalizePass.h} | 11 +- llvm/lib/Target/DirectX/DirectX.h | 9 +- .../Target/DirectX/DirectXPassRegistry.def | 2 +- 
.../Target/DirectX/DirectXTargetMachine.cpp | 6 +- llvm/lib/Target/DirectX/LegalizeI8Pass.cpp | 127 ----------- .../DirectX/ResourceGlobalElimination.ll | 4 +- .../legalize-i64-extract-insert-elements.ll | 24 +++ llvm/test/CodeGen/DirectX/legalize-i8.ll | 32 ++- llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 + .../DirectX/llc-vector-load-scalarize.ll | 32 +-- .../CodeGen/DirectX/scalarize-two-calls.ll | 16 +- 13 files changed, 294 insertions(+), 173 deletions(-) create mode 100644 llvm/lib/Target/DirectX/DXILLegalizePass.cpp rename llvm/lib/Target/DirectX/{LegalizeI8Pass.h => DXILLegalizePass.h} (55%) delete mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp create mode 100644 llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index 0b3b6a23ce739..13f8adbe4f132 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -32,7 +32,7 @@ add_llvm_target(DirectXCodeGen DXILShaderFlags.cpp DXILTranslateMetadata.cpp DXILRootSignature.cpp - LegalizeI8Pass.cpp + DXILLegalizePass.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp new file mode 100644 index 0000000000000..3db718d56adf1 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -0,0 +1,201 @@ +//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL-*- C++----------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// +/// +/// \file This file contains a pass to remove i8 truncations and i64 extract +/// and insert elements. 
+/// +//===----------------------------------------------------------------------===// +#include "DXILLegalizePass.h" +#include "DirectX.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include +#include +#include +#include + +#define DEBUG_TYPE "dxil-legalize" + +using namespace llvm; +namespace { + +static bool fixI8TruncUseChain(Instruction &I, + std::stack &ToRemove, + std::map &ReplacedValues) { + + if (auto *Trunc = dyn_cast(&I)) { + if (Trunc->getDestTy()->isIntegerTy(8)) { + ReplacedValues[Trunc] = Trunc->getOperand(0); + ToRemove.push(Trunc); + } + } else if (I.getType()->isIntegerTy(8)) { + IRBuilder<> Builder(&I); + + std::vector NewOperands; + Type *InstrType = nullptr; + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { + Value *Op = I.getOperand(OpIdx); + if (ReplacedValues.count(Op)) { + InstrType = ReplacedValues[Op]->getType(); + NewOperands.push_back(ReplacedValues[Op]); + } else if (auto *Imm = dyn_cast(Op)) { + APInt Value = Imm->getValue(); + unsigned NewBitWidth = InstrType->getIntegerBitWidth(); + // Note: options here are sext or sextOrTrunc. + // Since i8 isn't suppport we assume new values + // will always have a higher bitness. 
+ APInt NewValue = Value.sext(NewBitWidth); + NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); + } else { + assert(!Op->getType()->isIntegerTy(8)); + NewOperands.push_back(Op); + } + } + + Value *NewInst = nullptr; + if (auto *BO = dyn_cast(&I)) + NewInst = + Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); + else if (auto *Cmp = dyn_cast(&I)) + NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], + NewOperands[1]); + else if (auto *Cast = dyn_cast(&I)) + NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], + Cast->getDestTy()); + else if (auto *UnaryOp = dyn_cast(&I)) + NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]); + + if (NewInst) { + ReplacedValues[&I] = NewInst; + ToRemove.push(&I); + } + } else if (auto *Sext = dyn_cast(&I)) { + if (Sext->getSrcTy()->isIntegerTy(8)) { + ToRemove.push(Sext); + Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]); + } + } + + return !ToRemove.empty(); +} + +static bool +downcastI64toI32InsertExtractElements(Instruction &I, + std::stack &ToRemove, + std::map &) { + + if (auto *Extract = dyn_cast(&I)) { + Value *Idx = Extract->getIndexOperand(); + auto *CI = dyn_cast(Idx); + if (CI && CI->getBitWidth() == 64) { + IRBuilder<> Builder(Extract); + int64_t IndexValue = CI->getSExtValue(); + auto *Idx32 = + ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); + Value *NewExtract = Builder.CreateExtractElement( + Extract->getVectorOperand(), Idx32, Extract->getName()); + + Extract->replaceAllUsesWith(NewExtract); + ToRemove.push(Extract); + } + } + + if (auto *Insert = dyn_cast(&I)) { + Value *Idx = Insert->getOperand(2); + auto *CI = dyn_cast(Idx); + if (CI && CI->getBitWidth() == 64) { + int64_t IndexValue = CI->getSExtValue(); + auto *Idx32 = + ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); + IRBuilder<> Builder(Insert); + Value *Insert32Index = Builder.CreateInsertElement( + Insert->getOperand(0), 
Insert->getOperand(1), Idx32, + Insert->getName()); + + Insert->replaceAllUsesWith(Insert32Index); + ToRemove.push(Insert); + } + } + + return !ToRemove.empty(); +} + +class DXILLegalizationPipeline { + +public: + DXILLegalizationPipeline() { initializeLegalizationPipeline(); } + + bool runLegalizationPipeline(Function &F) { + std::stack ToRemove; + std::map ReplacedValues; + bool MadeChanges = false; + for (auto &I : instructions(F)) { + for (auto &LegalizationFn : LegalizationPipeline) { + MadeChanges = LegalizationFn(I, ToRemove, ReplacedValues); + } + } + while (!ToRemove.empty()) { + Instruction *I = ToRemove.top(); + I->eraseFromParent(); + ToRemove.pop(); + } + + return MadeChanges; + } + +private: + std::vector &, + std::map &)>> + LegalizationPipeline; + + void initializeLegalizationPipeline() { + LegalizationPipeline.push_back(fixI8TruncUseChain); + LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); + } +}; + +class DXILLegalizeLegacy : public FunctionPass { + +public: + bool runOnFunction(Function &F) override; + DXILLegalizeLegacy() : FunctionPass(ID) {} + + static char ID; // Pass identification. 
+}; +} // namespace + +PreservedAnalyses DXILLegalizePass::run(Function &F, + FunctionAnalysisManager &FAM) { + DXILLegalizationPipeline DXLegalize; + bool MadeChanges = DXLegalize.runLegalizationPipeline(F); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + return PA; +} + +bool DXILLegalizeLegacy::runOnFunction(Function &F) { + DXILLegalizationPipeline DXLegalize; + return DXLegalize.runLegalizationPipeline(F); +} + +char DXILLegalizeLegacy::ID = 0; + +INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, + false) +INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, + false) + +FunctionPass *llvm::createDXILLegalizeLegacyPass() { + return new DXILLegalizeLegacy(); +} diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.h b/llvm/lib/Target/DirectX/DXILLegalizePass.h similarity index 55% rename from llvm/lib/Target/DirectX/LegalizeI8Pass.h rename to llvm/lib/Target/DirectX/DXILLegalizePass.h index 30ba4a88b8176..39ef6f532dca0 100644 --- a/llvm/lib/Target/DirectX/LegalizeI8Pass.h +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.h @@ -1,4 +1,4 @@ -//===- LegalizeI8Pass.h - A pass that reverts i8 conversions-*- C++ -----*-===// +//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL-*- C++------------*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,18 +6,17 @@ // //===---------------------------------------------------------------------===// -#ifndef LLVM_TARGET_DIRECTX_LEGALIZEI8_H -#define LLVM_TARGET_DIRECTX_LEGALIZEI8_H +#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H +#define LLVM_TARGET_DIRECTX_LEGALIZE_H #include "llvm/IR/PassManager.h" namespace llvm { -/// A pass that transforms multidimensional arrays into one-dimensional arrays. 
-class LegalizeI8Pass : public PassInfoMixin { +class DXILLegalizePass : public PassInfoMixin { public: PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); }; } // namespace llvm -#endif // LLVM_TARGET_DIRECTX_LEGALIZEI8_H +#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index 482f3b5a8f694..96a8a08c875f8 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -47,11 +47,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &); /// Pass to flatten arrays into a one dimensional DXIL legal form ModulePass *createDXILFlattenArraysLegacyPass(); -/// Initializer I8 legalizationPass -void initializeLegalizeI8LegacyPass(PassRegistry &); +/// Initializer DXIL legalizationPass +void initializeDXILLegalizeLegacyPass(PassRegistry &); -/// Pass to remove i8 truncations -FunctionPass *createLegalizeI8LegacyPass(); +/// Pass to Legalize DXIL by remove i8 truncations and i64 insert/extract +/// elements +FunctionPass *createDXILLegalizeLegacyPass(); /// Initializer for DXILOpLowering void initializeDXILOpLoweringLegacyPass(PassRegistry &); diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def index 297c6c10f68a3..87d91ead1896f 100644 --- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def +++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def @@ -38,5 +38,5 @@ MODULE_PASS("print", dxil::RootSignatureAnalysisPrinter(dbg #define FUNCTION_PASS(NAME, CREATE_PASS) #endif FUNCTION_PASS("dxil-resource-access", DXILResourceAccess()) -FUNCTION_PASS("dxil-legalize-i8", LegalizeI8Pass()) +FUNCTION_PASS("dxil-legalize", DXILLegalizePass()) #undef FUNCTION_PASS diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index ec6b69089fb64..ce408b4034f83 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ 
b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -15,6 +15,7 @@ #include "DXILDataScalarization.h" #include "DXILFlattenArrays.h" #include "DXILIntrinsicExpansion.h" +#include "DXILLegalizePass.h" #include "DXILOpLowering.h" #include "DXILPrettyPrinter.h" #include "DXILResourceAccess.h" @@ -25,7 +26,6 @@ #include "DirectX.h" #include "DirectXSubtarget.h" #include "DirectXTargetTransformInfo.h" -#include "LegalizeI8Pass.h" #include "TargetInfo/DirectXTargetInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/Passes.h" @@ -53,7 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { initializeDXILDataScalarizationLegacyPass(*PR); initializeDXILFlattenArraysLegacyPass(*PR); initializeScalarizerLegacyPassPass(*PR); - initializeLegalizeI8LegacyPass(*PR); + initializeDXILLegalizeLegacyPass(*PR); initializeDXILPrepareModulePass(*PR); initializeEmbedDXILPassPass(*PR); initializeWriteDXILPassPass(*PR); @@ -101,8 +101,8 @@ class DirectXPassConfig : public TargetPassConfig { ScalarizerPassOptions DxilScalarOptions; DxilScalarOptions.ScalarizeLoadStore = true; addPass(createScalarizerPass(DxilScalarOptions)); + addPass(createDXILLegalizeLegacyPass()); addPass(createDXILTranslateMetadataLegacyPass()); - addPass(createLegalizeI8LegacyPass()); addPass(createDXILOpLoweringLegacyPass()); addPass(createDXILPrepareModulePass()); } diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp deleted file mode 100644 index 8faca2be8048c..0000000000000 --- a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp +++ /dev/null @@ -1,127 +0,0 @@ -//===- LegalizeI8Pass.cpp - A pass that reverts i8 conversions-*- C++ ---*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===---------------------------------------------------------------------===// -//===---------------------------------------------------------------------===// -/// -/// \file This file contains a pass to remove i8 truncations. -/// -//===----------------------------------------------------------------------===// -#include "DirectX.h" -#include "LegalizeI8Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instruction.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include -#include - -#define DEBUG_TYPE "dxil-legalize-i8" - -using namespace llvm; -namespace { - -class LegalizeI8Legacy : public FunctionPass { - -public: - bool runOnFunction(Function &F) override; - LegalizeI8Legacy() : FunctionPass(ID) {} - - static char ID; // Pass identification. -}; -} // namespace - -static bool fixI8TruncUseChain(Function &F) { - std::stack ToRemove; - std::map ReplacedValues; - - for (auto &I : instructions(F)) { - if (auto *Trunc = dyn_cast(&I)) { - if (Trunc->getDestTy()->isIntegerTy(8)) { - ReplacedValues[Trunc] = Trunc->getOperand(0); - ToRemove.push(Trunc); - } - } else if (I.getType()->isIntegerTy(8)) { - IRBuilder<> Builder(&I); - - std::vector NewOperands; - Type* InstrType = nullptr; - for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { - Value *Op = I.getOperand(OpIdx); - if (ReplacedValues.count(Op)) { - InstrType = ReplacedValues[Op]->getType(); - NewOperands.push_back(ReplacedValues[Op]); - } - else if (auto *Imm = dyn_cast(Op)) { - APInt Value = Imm->getValue(); - unsigned NewBitWidth = InstrType->getIntegerBitWidth(); - // Note: options here are sext or sextOrTrunc. - // Since i8 isn't suppport we assume new values - // will always have a higher bitness. 
- APInt NewValue = Value.sext(NewBitWidth); - NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); - } else { - assert(!Op->getType()->isIntegerTy(8)); - NewOperands.push_back(Op); - } - - } - - Value *NewInst = nullptr; - if (auto *BO = dyn_cast(&I)) - NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); - else if (auto *Cmp = dyn_cast(&I)) - NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); - else if (auto *Cast = dyn_cast(&I)) - NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], Cast->getDestTy()); - else if (auto *UnaryOp = dyn_cast(&I)) - NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]); - - if (NewInst) { - ReplacedValues[&I] = NewInst; - ToRemove.push(&I); - } - } else if (auto *Sext = dyn_cast(&I)) { - if (Sext->getSrcTy()->isIntegerTy(8)) { - ToRemove.push(Sext); - Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]); - } - } - } - - while (!ToRemove.empty()) { - Instruction *I = ToRemove.top(); - I->eraseFromParent(); - ToRemove.pop(); - } - - return true; -} - -PreservedAnalyses LegalizeI8Pass::run(Function &F, FunctionAnalysisManager &FAM) { - bool MadeChanges = fixI8TruncUseChain(F); - if (!MadeChanges) - return PreservedAnalyses::all(); - PreservedAnalyses PA; - return PA; - } - - bool LegalizeI8Legacy::runOnFunction(Function &F) { - return fixI8TruncUseChain(F); - } - - char LegalizeI8Legacy::ID = 0; - - INITIALIZE_PASS_BEGIN(LegalizeI8Legacy, DEBUG_TYPE, - "DXIL I8 Legalizer", false, false) - INITIALIZE_PASS_END(LegalizeI8Legacy, DEBUG_TYPE, "DXIL I8 Legalizer", - false, false) - -FunctionPass *llvm::createLegalizeI8LegacyPass() { - return new LegalizeI8Legacy(); - } \ No newline at end of file diff --git a/llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll b/llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll index cd21adc11a9b4..50c5ff92024b0 100644 --- a/llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll +++ 
b/llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll @@ -19,8 +19,8 @@ ; CHECK-LABEL define void @main() define void @main() local_unnamed_addr #0 { entry: - ; DXOP: %In_h.i1 = call %dx.types.Handle @dx.op.createHandle - ; DXOP: %Out_h.i2 = call %dx.types.Handle @dx.op.createHandle + ; DXOP: [[In_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle + ; DXOP: [[Out_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle %In_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false) store target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %In_h.i, ptr @In, align 4 %Out_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 4, i32 1, i32 1, i32 0, i1 false) diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll b/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll new file mode 100644 index 0000000000000..8a59986524c90 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll @@ -0,0 +1,24 @@ +; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +define noundef <4 x float> @float4_extract(<4 x float> noundef %a) { +entry: + ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0 + ; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1 + ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2 + ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3 + ; CHECK: insertelement <4 x float> poison, float [[ee0]], i32 0 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee1]], i32 1 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee2]], i32 2 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee3]], i32 3 + + %a.i0 = extractelement <4 x float> %a, i64 0 + %a.i1 = extractelement <4 x float> %a, i64 1 + %a.i2 = extractelement <4 x float> %a, i64 2 + %a.i3 = 
extractelement <4 x float> %a, i64 3 + + %.upto0 = insertelement <4 x float> poison, float %a.i0, i64 0 + %.upto1 = insertelement <4 x float> %.upto0, float %a.i1, i64 1 + %.upto2 = insertelement <4 x float> %.upto1, float %a.i2, i64 2 + %0 = insertelement <4 x float> %.upto2, float %a.i3, i64 3 + ret <4 x float> %0 +} diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll index d9787531ae1c4..d7cea585a056b 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -1,9 +1,20 @@ -; RUN: opt -S -passes='dxil-legalize-i8' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s define i32 @i8trunc(float %0) #0 { ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8 ; CHECK: add i32 - ; CHECK: srem i32 + ; CHECK-NEXT: srem i32 + ; CHECK-NEXT: sub i32 + ; CHECK-NEXT: mul i32 + ; CHECK-NEXT: udiv i32 + ; CHECK-NEXT: sdiv i32 + ; CHECK-NEXT: urem i32 + ; CHECK-NEXT: and i32 + ; CHECK-NEXT: or i32 + ; CHECK-NEXT: xor i32 + ; CHECK-NEXT: shl i32 + ; CHECK-NEXT: lshr i32 + ; CHECK-NEXT: ashr i32 ; CHECK-NOT: %7 = sext i8 %6 to i32 %2 = fptosi float %0 to i32 @@ -11,6 +22,17 @@ define i32 @i8trunc(float %0) #0 { %4 = trunc nsw i32 %3 to i8 %5 = add nsw i8 %4, 1 %6 = srem i8 %5, 8 - %7 = sext i8 %6 to i32 - ret i32 %7 -} \ No newline at end of file + %7 = sub i8 %6, 1 + %8 = mul i8 %7, 1 + %9 = udiv i8 %8, 1 + %10 = sdiv i8 %9, 1 + %11 = urem i8 %10, 1 + %12 = and i8 %11, 1 + %13 = or i8 %12, 1 + %14 = xor i8 %13, 1 + %15 = shl i8 %14, 1 + %16 = lshr i8 %15, 1 + %17 = ashr i8 %16, 1 + %18 = sext i8 %17 to i32 + ret i32 %18 +} diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index 3a9af4d744f98..ee70cec534bc5 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -21,6 +21,7 @@ ; CHECK-NEXT: DXIL Resource 
Access ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Scalarize vector operations +; CHECK-NEXT: DXIL Legalizer ; CHECK-NEXT: DXIL Resource Binding Analysis ; CHECK-NEXT: DXIL Module Metadata analysis ; CHECK-NEXT: DXIL Shader Flag Analysis diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll index 4e522c6ef5da7..7e5a92e1311f8 100644 --- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll +++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll @@ -44,10 +44,10 @@ define <4 x i32> @load_array_vec_test() #0 { ; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]] ; CHECK-NEXT: [[DOTI210:%.*]] = add i32 [[TMP6]], [[DOTI25]] ; CHECK-NEXT: [[DOTI311:%.*]] = add i32 [[TMP8]], [[DOTI37]] -; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i64 0 -; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i64 1 -; CHECK-NEXT: [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i64 2 -; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i64 3 +; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i32 0 +; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i32 1 +; CHECK-NEXT: [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i32 2 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3 ; CHECK-NEXT: ret <4 x i32> [[TMP16]] ; %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4 @@ -68,10 +68,10 @@ define <4 x i32> @load_vec_test() #0 { ; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4 ; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @vecData.scalarized, i32 3) 
to ptr addrspace(3) ; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4 -; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i64 0 -; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[TMP4]], i64 1 -; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[TMP6]], i64 2 -; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[TMP8]], i64 3 +; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[TMP4]], i32 1 +; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[TMP6]], i32 2 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[TMP8]], i32 3 ; CHECK-NEXT: ret <4 x i32> [[TMP9]] ; %1 = load <4 x i32>, <4 x i32> addrspace(3)* @"vecData", align 4 @@ -93,10 +93,10 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 { ; CHECK-NEXT: [[TMP5:%.*]] = bitcast ptr [[DOTFLAT]] to ptr ; CHECK-NEXT: [[DOTFLAT_I3:%.*]] = getelementptr i32, ptr [[TMP5]], i32 3 ; CHECK-NEXT: [[DOTI3:%.*]] = load i32, ptr [[DOTFLAT_I3]], align 4 -; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i64 0 -; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[DOTI1]], i64 1 -; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[DOTI2]], i64 2 -; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[DOTI3]], i64 3 +; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[DOTI1]], i32 1 +; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[DOTI2]], i32 2 +; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[DOTI3]], i32 3 ; CHECK-NEXT: ret <4 x i32> [[TMP6]] ; %3 = getelementptr inbounds [3 x <4 x 
i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index @@ -127,10 +127,10 @@ define <4 x i32> @multid_load_test() #0 { ; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]] ; CHECK-NEXT: [[DOTI210:%.*]] = add i32 [[TMP6]], [[DOTI25]] ; CHECK-NEXT: [[DOTI311:%.*]] = add i32 [[TMP8]], [[DOTI37]] -; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i64 0 -; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i64 1 -; CHECK-NEXT: [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i64 2 -; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i64 3 +; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i32 0 +; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i32 1 +; CHECK-NEXT: [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i32 2 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3 ; CHECK-NEXT: ret <4 x i32> [[TMP16]] ; %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4 diff --git a/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll b/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll index 0546a5505416f..7e8f58c0576f0 100644 --- a/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll +++ b/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll @@ -3,22 +3,22 @@ ; CHECK: target triple = "dxilv1.3-pc-shadermodel6.3-library" ; CHECK-LABEL: cos_sin_float_test define noundef <4 x float> @cos_sin_float_test(<4 x float> noundef %a) #0 { - ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i64 0 + ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0 ; CHECK: [[ie0:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee0]]) - ; CHECK: [[ee1:%.*]] = extractelement <4 x 
float> %a, i64 1 + ; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1 ; CHECK: [[ie1:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee1]]) - ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i64 2 + ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2 ; CHECK: [[ie2:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee2]]) - ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i64 3 + ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3 ; CHECK: [[ie3:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee3]]) ; CHECK: [[ie4:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie0]]) ; CHECK: [[ie5:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie1]]) ; CHECK: [[ie6:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie2]]) ; CHECK: [[ie7:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie3]]) - ; CHECK: insertelement <4 x float> poison, float [[ie4]], i64 0 - ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie5]], i64 1 - ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie6]], i64 2 - ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie7]], i64 3 + ; CHECK: insertelement <4 x float> poison, float [[ie4]], i32 0 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie5]], i32 1 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie6]], i32 2 + ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie7]], i32 3 %2 = tail call <4 x float> @llvm.sin.v4f32(<4 x float> %a) %3 = tail call <4 x float> @llvm.cos.v4f32(<4 x float> %2) ret <4 x float> %3 From 7d9323134674f57aec26c344c977a75de514a379 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 13 Mar 2025 20:00:07 -0400 Subject: [PATCH 3/7] fix an issue where we didn't have a type for immediates if they were the first argument --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 23 ++++----- llvm/test/CodeGen/DirectX/legalize-i8.ll | 49 ++++++++++++++++++++ 2 files changed, 61 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 3db718d56adf1..0abba50eee873 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -33,16 +33,19 @@ static bool fixI8TruncUseChain(Instruction &I, std::stack &ToRemove, std::map &ReplacedValues) { + auto *Cmp = dyn_cast(&I); + if (auto *Trunc = dyn_cast(&I)) { if (Trunc->getDestTy()->isIntegerTy(8)) { ReplacedValues[Trunc] = Trunc->getOperand(0); ToRemove.push(Trunc); } - } else if (I.getType()->isIntegerTy(8)) { + } else if (I.getType()->isIntegerTy(8) || + (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) { IRBuilder<> Builder(&I); std::vector NewOperands; - Type *InstrType = nullptr; + Type *InstrType = IntegerType::get(I.getContext(), 32); for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); if (ReplacedValues.count(Op)) { @@ -66,23 +69,21 @@ static bool fixI8TruncUseChain(Instruction &I, if (auto *BO = dyn_cast(&I)) NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); - else if (auto *Cmp = dyn_cast(&I)) + else if (Cmp) { NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); - else if (auto *Cast = dyn_cast(&I)) - NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], - Cast->getDestTy()); - else if (auto *UnaryOp = dyn_cast(&I)) + Cmp->replaceAllUsesWith(NewInst); + } else if (auto *UnaryOp = dyn_cast(&I)) NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]); if (NewInst) { ReplacedValues[&I] = NewInst; ToRemove.push(&I); } - } else if (auto *Sext = dyn_cast(&I)) { - if (Sext->getSrcTy()->isIntegerTy(8)) { - ToRemove.push(Sext); - Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]); + } else if (auto *Cast = dyn_cast(&I)) { + if (Cast->getSrcTy()->isIntegerTy(8)) { + ToRemove.push(Cast); + Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); } } diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll 
b/llvm/test/CodeGen/DirectX/legalize-i8.ll index d7cea585a056b..a18be375ddf63 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -1,6 +1,16 @@ ; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +define i32 @removal_only_test(i32 %a) { + ; CHECK-LABEL: define i32 @removal_only_test( + ; CHECK-SAME: i32 [[A:%.*]]) { + ; CHECK: ret i32 [[A]] + %1 = trunc nsw i32 %a to i8 + %3 = sext i8 %1 to i32 + ret i32 %3 +} + define i32 @i8trunc(float %0) #0 { + ; CHECK-LABEL: define i32 @i8trunc( ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8 ; CHECK: add i32 ; CHECK-NEXT: srem i32 @@ -36,3 +46,42 @@ define i32 @i8trunc(float %0) #0 { %18 = sext i8 %17 to i32 ret i32 %18 } + +define i32 @cast_removal_test(i32 %a) { + ; CHECK-LABEL: define i32 @cast_removal_test( + ; CHECK-SAME: i32 [[A:%.*]]) { + ; CHECK-NOT: trunc + ; CHECK-NOT: zext i8 + ; CHECK-NOT: sext i8 + ; CHECK: add i32 [[A]], [[A]] + %1 = trunc nsw i32 %a to i8 + %2 = zext i8 %1 to i32 + %3 = sext i8 %1 to i32 + %4 = add i32 %2, %3 + ret i32 %4 +} + +define i1 @trunc_cmp_test(i32 %a, i32 %b) { + ; CHECK-LABEL: define i1 @trunc_cmp_test( + ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) { + ; CHECK: icmp slt i32 [[A]], [[B]] + ; CHECK: icmp sgt i32 [[A]], [[B]] + %1 = trunc nsw i32 %a to i8 + %2 = trunc nsw i32 %b to i8 + %3 = icmp slt i8 %1, %2 + %4 = icmp sgt i8 %1, %2 + %5 = and i1 %3, %4 + ret i1 %5 +} + +define i32 @first_operand_imm_test(i32 %a) { + ; CHECK-LABEL: define i32 @first_operand_imm_test( + ; CHECK-SAME: i32 [[A:%.*]]) { + ; CHECK-NOT: trunc + ; CHECK: sub i32 0, [[A]] + ; CHECK-NOT: sext i8 + %1 = trunc nsw i32 %a to i8 + %2 = sub i8 0, %1 + %3 = sext i8 %2 to i32 + ret i32 %3 +} From 0a561f58a02820dadc27ba68a048d2d8eca6787a Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Fri, 14 Mar 2025 01:28:24 -0400 Subject: [PATCH 4/7] allow non i32 imm values --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 10 ++++++--- 
llvm/test/CodeGen/DirectX/legalize-i8.ll | 22 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 0abba50eee873..a9f9334910a43 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -48,10 +48,14 @@ static bool fixI8TruncUseChain(Instruction &I, Type *InstrType = IntegerType::get(I.getContext(), 32); for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); - if (ReplacedValues.count(Op)) { + if (ReplacedValues.count(Op)) InstrType = ReplacedValues[Op]->getType(); + } + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { + Value *Op = I.getOperand(OpIdx); + if (ReplacedValues.count(Op)) NewOperands.push_back(ReplacedValues[Op]); - } else if (auto *Imm = dyn_cast(Op)) { + else if (auto *Imm = dyn_cast(Op)) { APInt Value = Imm->getValue(); unsigned NewBitWidth = InstrType->getIntegerBitWidth(); // Note: options here are sext or sextOrTrunc. 
@@ -142,7 +146,7 @@ class DXILLegalizationPipeline { bool MadeChanges = false; for (auto &I : instructions(F)) { for (auto &LegalizationFn : LegalizationPipeline) { - MadeChanges = LegalizationFn(I, ToRemove, ReplacedValues); + MadeChanges |= LegalizationFn(I, ToRemove, ReplacedValues); } } while (!ToRemove.empty()) { diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll index a18be375ddf63..1b7eb2c552cf1 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -85,3 +85,25 @@ define i32 @first_operand_imm_test(i32 %a) { %3 = sext i8 %2 to i32 ret i32 %3 } + +define i16 @i16_test(i16 %a) { + ; CHECK-LABEL: define i16 @i16_test( + ; CHECK-SAME: i16 [[A:%.*]]) { + ; CHECK-NOT: trunc + ; CHECK: sub i16 0, [[A]] + ; CHECK-NOT: sext i8 + %1 = trunc nsw i16 %a to i8 + %2 = sub i8 0, %1 + %3 = sext i8 %2 to i16 + ret i16 %3 +} + +define i32 @all_imm() { + ; CHECK-LABEL: define i32 @all_imm( + ; CHECK-NOT: trunc + ; CHECK-NOT: sext i8 + ; CHECK: ret i32 -1 + %1 = sub i8 0, 1 + %2 = sext i8 %1 to i32 + ret i32 %2 +} From 64937d2199d1057036befd550f1427fbdd4774ec Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Fri, 14 Mar 2025 03:12:36 -0400 Subject: [PATCH 5/7] remove unary since fneg only case. copy NUW or NSW over from orig binOp to new binOp. 
--- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 14 ++++++++++---- llvm/test/CodeGen/DirectX/legalize-i8.ll | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index a9f9334910a43..b8de0c1965068 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -70,15 +70,21 @@ static bool fixI8TruncUseChain(Instruction &I, } Value *NewInst = nullptr; - if (auto *BO = dyn_cast(&I)) + if (auto *BO = dyn_cast(&I)) { NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); - else if (Cmp) { + + if (auto *OBO = dyn_cast(&I)) { + if (OBO->hasNoSignedWrap()) + cast(NewInst)->setHasNoSignedWrap(); + if (OBO->hasNoUnsignedWrap()) + cast(NewInst)->setHasNoUnsignedWrap(); + } + } else if (Cmp) { NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); Cmp->replaceAllUsesWith(NewInst); - } else if (auto *UnaryOp = dyn_cast(&I)) - NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]); + } if (NewInst) { ReplacedValues[&I] = NewInst; diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll index 1b7eb2c552cf1..d17157c78e3c2 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -12,7 +12,7 @@ define i32 @removal_only_test(i32 %a) { define i32 @i8trunc(float %0) #0 { ; CHECK-LABEL: define i32 @i8trunc( ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8 - ; CHECK: add i32 + ; CHECK: add nsw i32 ; CHECK-NEXT: srem i32 ; CHECK-NEXT: sub i32 ; CHECK-NEXT: mul i32 From cbbdd62a404451b649fb8bef03690d2cc2617a60 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 17 Mar 2025 14:36:43 -0400 Subject: [PATCH 6/7] address pr comments --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 4 ++-- llvm/lib/Target/DirectX/DXILLegalizePass.h | 2 +- llvm/test/CodeGen/DirectX/legalize-i8.ll | 1 - 3 files 
changed, 3 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index b8de0c1965068..ee59be414cbb1 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -1,4 +1,4 @@ -//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL-*- C++----------*-===// +//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -59,7 +59,7 @@ static bool fixI8TruncUseChain(Instruction &I, APInt Value = Imm->getValue(); unsigned NewBitWidth = InstrType->getIntegerBitWidth(); // Note: options here are sext or sextOrTrunc. - // Since i8 isn't suppport we assume new values + // Since i8 isn't supported, we assume new values // will always have a higher bitness. APInt NewValue = Value.sext(NewBitWidth); NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.h b/llvm/lib/Target/DirectX/DXILLegalizePass.h index 39ef6f532dca0..9d6d1cd19081d 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.h +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.h @@ -1,4 +1,4 @@ -//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL-*- C++------------*-===// +//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll index d17157c78e3c2..2602be778cd86 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -100,7 +100,6 @@ define i16 @i16_test(i16 %a) { define i32 @all_imm() { ; CHECK-LABEL: define i32 @all_imm( - ; CHECK-NOT: trunc ; CHECK-NOT: sext i8 ; CHECK: ret i32 -1 %1 = sub i8 0, 1 From f8be37f050ced0344c633be1c6b826972234c357 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 17 Mar 2025 15:25:22 -0400 Subject: [PATCH 7/7] address pr comments --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index ee59be414cbb1..f9a494ce63dd3 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -29,7 +29,7 @@ using namespace llvm; namespace { -static bool fixI8TruncUseChain(Instruction &I, +static void fixI8TruncUseChain(Instruction &I, std::stack &ToRemove, std::map &ReplacedValues) { @@ -96,11 +96,9 @@ static bool fixI8TruncUseChain(Instruction &I, Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); } } - - return !ToRemove.empty(); } -static bool +static void downcastI64toI32InsertExtractElements(Instruction &I, std::stack &ToRemove, std::map &) { @@ -137,8 +135,6 @@ downcastI64toI32InsertExtractElements(Instruction &I, ToRemove.push(Insert); } } - - return !ToRemove.empty(); } class DXILLegalizationPipeline { @@ -149,12 +145,13 @@ class DXILLegalizationPipeline { bool runLegalizationPipeline(Function &F) { std::stack ToRemove; std::map ReplacedValues; - bool MadeChanges = false; for (auto &I : instructions(F)) { for (auto &LegalizationFn : LegalizationPipeline) { - MadeChanges |= LegalizationFn(I, ToRemove, ReplacedValues); + LegalizationFn(I, ToRemove, ReplacedValues); } } + bool MadeChanges = 
!ToRemove.empty(); + while (!ToRemove.empty()) { Instruction *I = ToRemove.top(); I->eraseFromParent(); @@ -165,7 +162,7 @@ class DXILLegalizationPipeline { } private: - std::vector &, + std::vector &, std::map &)>> LegalizationPipeline;