diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index f9a494ce63dd3..395311e430fbb 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -5,12 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// -//===---------------------------------------------------------------------===// -/// -/// \file This file contains a pass to remove i8 truncations and i64 extract -/// and insert elements. -/// -//===----------------------------------------------------------------------===// + #include "DXILLegalizePass.h" #include "DirectX.h" #include "llvm/IR/Function.h" @@ -20,37 +15,24 @@ #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include -#include -#include -#include #define DEBUG_TYPE "dxil-legalize" using namespace llvm; -namespace { static void fixI8TruncUseChain(Instruction &I, - std::stack &ToRemove, - std::map &ReplacedValues) { - - auto *Cmp = dyn_cast(&I); + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { - if (auto *Trunc = dyn_cast(&I)) { - if (Trunc->getDestTy()->isIntegerTy(8)) { - ReplacedValues[Trunc] = Trunc->getOperand(0); - ToRemove.push(Trunc); - } - } else if (I.getType()->isIntegerTy(8) || - (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) { - IRBuilder<> Builder(&I); - - std::vector NewOperands; + auto ProcessOperands = [&](SmallVector &NewOperands) { Type *InstrType = IntegerType::get(I.getContext(), 32); + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); if (ReplacedValues.count(Op)) InstrType = ReplacedValues[Op]->getType(); } + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); if (ReplacedValues.count(Op)) @@ -61,6 +43,8 @@ static void fixI8TruncUseChain(Instruction &I, // Note: options here are sext or sextOrTrunc. // Since i8 isn't supported, we assume new values // will always have a higher bitness. + assert(NewBitWidth > Value.getBitWidth() && + "Replacement's BitWidth should be larger than Current."); APInt NewValue = Value.sext(NewBitWidth); NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); } else { @@ -68,31 +52,50 @@ static void fixI8TruncUseChain(Instruction &I, NewOperands.push_back(Op); } } - - Value *NewInst = nullptr; - if (auto *BO = dyn_cast(&I)) { - NewInst = - Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); - - if (auto *OBO = dyn_cast(&I)) { - if (OBO->hasNoSignedWrap()) - cast(NewInst)->setHasNoSignedWrap(); - if (OBO->hasNoUnsignedWrap()) - cast(NewInst)->setHasNoUnsignedWrap(); - } - } else if (Cmp) { - NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], - NewOperands[1]); - Cmp->replaceAllUsesWith(NewInst); + }; + IRBuilder<> Builder(&I); + if (auto *Trunc = dyn_cast(&I)) { + if (Trunc->getDestTy()->isIntegerTy(8)) { + ReplacedValues[Trunc] = Trunc->getOperand(0); + ToRemove.push_back(Trunc); + return; } + } - if (NewInst) { - ReplacedValues[&I] = NewInst; - ToRemove.push(&I); + if (auto *BO = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewInst = + Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); + if (auto *OBO = dyn_cast(&I)) { + if (OBO->hasNoSignedWrap()) + cast(NewInst)->setHasNoSignedWrap(); + if (OBO->hasNoUnsignedWrap()) + cast(NewInst)->setHasNoUnsignedWrap(); } - } else if (auto *Cast = dyn_cast(&I)) { + ReplacedValues[BO] = NewInst; + ToRemove.push_back(BO); + return; + } + + if (auto *Cmp = dyn_cast(&I)) { + if (!Cmp->getOperand(0)->getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewInst = + Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); + Cmp->replaceAllUsesWith(NewInst); + ReplacedValues[Cmp] = NewInst; + ToRemove.push_back(Cmp); + return; + } + + if (auto *Cast = dyn_cast(&I)) { if (Cast->getSrcTy()->isIntegerTy(8)) { - ToRemove.push(Cast); + ToRemove.push_back(Cast); Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); } } @@ -100,8 +103,8 @@ static void fixI8TruncUseChain(Instruction &I, static void downcastI64toI32InsertExtractElements(Instruction &I, - std::stack &ToRemove, - std::map &) { + SmallVectorImpl &ToRemove, + DenseMap &) { if (auto *Extract = dyn_cast(&I)) { Value *Idx = Extract->getIndexOperand(); @@ -115,7 +118,7 @@ downcastI64toI32InsertExtractElements(Instruction &I, Extract->getVectorOperand(), Idx32, Extract->getName()); Extract->replaceAllUsesWith(NewExtract); - ToRemove.push(Extract); + ToRemove.push_back(Extract); } } @@ -132,38 +135,35 @@ downcastI64toI32InsertExtractElements(Instruction &I, Insert->getName()); Insert->replaceAllUsesWith(Insert32Index); - ToRemove.push(Insert); + ToRemove.push_back(Insert); } } } +namespace { class DXILLegalizationPipeline { public: DXILLegalizationPipeline() { initializeLegalizationPipeline(); } bool runLegalizationPipeline(Function &F) { - std::stack ToRemove; - std::map ReplacedValues; + SmallVector ToRemove; + DenseMap ReplacedValues; for (auto &I : instructions(F)) { - for (auto &LegalizationFn : LegalizationPipeline) { + for (auto &LegalizationFn : LegalizationPipeline) LegalizationFn(I, ToRemove, ReplacedValues); - } } - bool MadeChanges = !ToRemove.empty(); - while (!ToRemove.empty()) { - Instruction *I = ToRemove.top(); - I->eraseFromParent(); - ToRemove.pop(); - } + for (auto *Inst : reverse(ToRemove)) + Inst->eraseFromParent(); - return MadeChanges; + return !ToRemove.empty(); } private: - std::vector &, - std::map &)>> + SmallVector< + std::function &, + DenseMap &)>> LegalizationPipeline; void initializeLegalizationPipeline() {