From 723e1837593131c10d6ed096674a273d5c530532 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 17 Mar 2025 21:44:37 -0400 Subject: [PATCH 1/4] [DirectX] Address PR comments to #131221 --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 55 +++++++------------- 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index f9a494ce63dd3..317bff40caf7e 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -5,12 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// -//===---------------------------------------------------------------------===// -/// -/// \file This file contains a pass to remove i8 truncations and i64 extract -/// and insert elements. -/// -//===----------------------------------------------------------------------===// + #include "DXILLegalizePass.h" #include "DirectX.h" #include "llvm/IR/Function.h" @@ -20,31 +15,27 @@ #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include -#include -#include -#include #define DEBUG_TYPE "dxil-legalize" using namespace llvm; namespace { -static void fixI8TruncUseChain(Instruction &I, - std::stack &ToRemove, - std::map &ReplacedValues) { +void fixI8TruncUseChain(Instruction &I, SmallVector &ToRemove, + DenseMap &ReplacedValues) { auto *Cmp = dyn_cast(&I); if (auto *Trunc = dyn_cast(&I)) { if (Trunc->getDestTy()->isIntegerTy(8)) { ReplacedValues[Trunc] = Trunc->getOperand(0); - ToRemove.push(Trunc); + ToRemove.push_back(Trunc); } } else if (I.getType()->isIntegerTy(8) || (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) { IRBuilder<> Builder(&I); - std::vector NewOperands; + SmallVector NewOperands; Type *InstrType = IntegerType::get(I.getContext(), 32); for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); @@ -88,20 +79,19 @@ static void fixI8TruncUseChain(Instruction &I, if (NewInst) { ReplacedValues[&I] = NewInst; - ToRemove.push(&I); + ToRemove.push_back(&I); } } else if (auto *Cast = dyn_cast(&I)) { if (Cast->getSrcTy()->isIntegerTy(8)) { - ToRemove.push(Cast); + ToRemove.push_back(Cast); Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); } } } -static void -downcastI64toI32InsertExtractElements(Instruction &I, - std::stack &ToRemove, - std::map &) { +void downcastI64toI32InsertExtractElements(Instruction &I, + SmallVector &ToRemove, + DenseMap &) { if (auto *Extract = dyn_cast(&I)) { Value *Idx = Extract->getIndexOperand(); @@ -115,7 +105,7 @@ downcastI64toI32InsertExtractElements(Instruction &I, Extract->getVectorOperand(), Idx32, Extract->getName()); Extract->replaceAllUsesWith(NewExtract); - ToRemove.push(Extract); + ToRemove.push_back(Extract); } } @@ -132,7 +122,7 @@ downcastI64toI32InsertExtractElements(Instruction &I, Insert->getName()); Insert->replaceAllUsesWith(Insert32Index); - ToRemove.push(Insert); + ToRemove.push_back(Insert); } } } @@ -143,27 +133,22 @@ class DXILLegalizationPipeline { DXILLegalizationPipeline() { initializeLegalizationPipeline(); } bool runLegalizationPipeline(Function &F) { - std::stack ToRemove; - std::map ReplacedValues; + SmallVector ToRemove; + DenseMap ReplacedValues; for (auto &I : instructions(F)) { - for (auto &LegalizationFn : LegalizationPipeline) { + for (auto &LegalizationFn : LegalizationPipeline) LegalizationFn(I, ToRemove, ReplacedValues); - } } - bool MadeChanges = !ToRemove.empty(); - while (!ToRemove.empty()) { - Instruction *I = ToRemove.top(); - I->eraseFromParent(); - ToRemove.pop(); - } + for (auto *Inst : reverse(ToRemove)) + Inst->eraseFromParent(); - return MadeChanges; + return !ToRemove.empty(); } private: - std::vector &, - std::map &)>> + SmallVector &, + DenseMap &)>> LegalizationPipeline; void initializeLegalizationPipeline() { From 137873b88b4d51906beef1b8e0cc7cf7e8d3d96b Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 17 Mar 2025 22:16:52 -0400 Subject: [PATCH 2/4] address pr comments --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 317bff40caf7e..44cd92cfc51f8 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -19,10 +19,10 @@ #define DEBUG_TYPE "dxil-legalize" using namespace llvm; -namespace { -void fixI8TruncUseChain(Instruction &I, SmallVector &ToRemove, - DenseMap &ReplacedValues) { +static void fixI8TruncUseChain(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { auto *Cmp = dyn_cast(&I); @@ -89,9 +89,10 @@ void fixI8TruncUseChain(Instruction &I, SmallVector &ToRemove, } } -void downcastI64toI32InsertExtractElements(Instruction &I, - SmallVector &ToRemove, - DenseMap &) { +static void +downcastI64toI32InsertExtractElements(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &) { if (auto *Extract = dyn_cast(&I)) { Value *Idx = Extract->getIndexOperand(); @@ -127,6 +128,7 @@ void downcastI64toI32InsertExtractElements(Instruction &I, } } +namespace { class DXILLegalizationPipeline { public: @@ -147,8 +149,9 @@ class DXILLegalizationPipeline { } private: - SmallVector &, - DenseMap &)>> + SmallVector< + std::function &, + DenseMap &)>> LegalizationPipeline; void initializeLegalizationPipeline() { From adb61f8e189ac7a28d135aa0d7fad8265d5ff45c Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 20 Mar 2025 14:02:13 -0400 Subject: [PATCH 3/4] refactor fixI8TruncUseChain to remove nesting --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 76 +++++++++++--------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 44cd92cfc51f8..a9edf77aa08af 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -24,24 +24,15 @@ static void fixI8TruncUseChain(Instruction &I, SmallVectorImpl &ToRemove, DenseMap &ReplacedValues) { - auto *Cmp = dyn_cast(&I); - - if (auto *Trunc = dyn_cast(&I)) { - if (Trunc->getDestTy()->isIntegerTy(8)) { - ReplacedValues[Trunc] = Trunc->getOperand(0); - ToRemove.push_back(Trunc); - } - } else if (I.getType()->isIntegerTy(8) || - (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) { - IRBuilder<> Builder(&I); - - SmallVector NewOperands; + auto ProcessOperands = [&](SmallVector &NewOperands) { Type *InstrType = IntegerType::get(I.getContext(), 32); + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); if (ReplacedValues.count(Op)) InstrType = ReplacedValues[Op]->getType(); } + for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); if (ReplacedValues.count(Op)) @@ -52,6 +43,8 @@ static void fixI8TruncUseChain(Instruction &I, // Note: options here are sext or sextOrTrunc. // Since i8 isn't supported, we assume new values // will always have a higher bitness. + assert(NewBitWidth > Value.getBitWidth() && + "Replacement's BitWidth should be larger than Current."); APInt NewValue = Value.sext(NewBitWidth); NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); } else { @@ -59,29 +52,46 @@ static void fixI8TruncUseChain(Instruction &I, NewOperands.push_back(Op); } } - - Value *NewInst = nullptr; - if (auto *BO = dyn_cast(&I)) { - NewInst = - Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); - - if (auto *OBO = dyn_cast(&I)) { - if (OBO->hasNoSignedWrap()) - cast(NewInst)->setHasNoSignedWrap(); - if (OBO->hasNoUnsignedWrap()) - cast(NewInst)->setHasNoUnsignedWrap(); - } - } else if (Cmp) { - NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], - NewOperands[1]); - Cmp->replaceAllUsesWith(NewInst); + }; + IRBuilder<> Builder(&I); + if (auto *Trunc = dyn_cast(&I)) { + if (Trunc->getDestTy()->isIntegerTy(8)) { + ReplacedValues[Trunc] = Trunc->getOperand(0); + ToRemove.push_back(Trunc); } - - if (NewInst) { - ReplacedValues[&I] = NewInst; - ToRemove.push_back(&I); + } + Value *NewInst = nullptr; + if (auto *BO = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + NewInst = + Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); + if (auto *OBO = dyn_cast(&I)) { + if (OBO->hasNoSignedWrap()) + cast(NewInst)->setHasNoSignedWrap(); + if (OBO->hasNoUnsignedWrap()) + cast(NewInst)->setHasNoUnsignedWrap(); } - } else if (auto *Cast = dyn_cast(&I)) { + } + + if (auto *Cmp = dyn_cast(&I)) { + if (!Cmp->getOperand(0)->getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + NewInst = + Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); + Cmp->replaceAllUsesWith(NewInst); + } + if (NewInst) { + ReplacedValues[&I] = NewInst; + ToRemove.push_back(&I); + return; + } + + if (auto *Cast = dyn_cast(&I)) { if (Cast->getSrcTy()->isIntegerTy(8)) { ToRemove.push_back(Cast); Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); From 142def2ae72f88cd738b283b77e8c612d13f2239 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 20 Mar 2025 17:29:45 -0400 Subject: [PATCH 4/4] address pr feedback --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index a9edf77aa08af..395311e430fbb 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -58,15 +58,16 @@ static void fixI8TruncUseChain(Instruction &I, if (Trunc->getDestTy()->isIntegerTy(8)) { ReplacedValues[Trunc] = Trunc->getOperand(0); ToRemove.push_back(Trunc); + return; } } - Value *NewInst = nullptr; + if (auto *BO = dyn_cast(&I)) { if (!I.getType()->isIntegerTy(8)) return; SmallVector NewOperands; ProcessOperands(NewOperands); - NewInst = + Value *NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); if (auto *OBO = dyn_cast(&I)) { if (OBO->hasNoSignedWrap()) @@ -74,6 +75,9 @@ static void fixI8TruncUseChain(Instruction &I, if (OBO->hasNoUnsignedWrap()) cast(NewInst)->setHasNoUnsignedWrap(); } + ReplacedValues[BO] = NewInst; + ToRemove.push_back(BO); + return; } if (auto *Cmp = dyn_cast(&I)) { @@ -81,13 +85,11 @@ static void fixI8TruncUseChain(Instruction &I, return; SmallVector NewOperands; ProcessOperands(NewOperands); - NewInst = + Value *NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); Cmp->replaceAllUsesWith(NewInst); - } - if (NewInst) { - ReplacedValues[&I] = NewInst; - ToRemove.push_back(&I); + ReplacedValues[Cmp] = NewInst; + ToRemove.push_back(Cmp); return; }