diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 3e21f3c109456..79b0cf261ba31 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -8,6 +8,8 @@ #include "DXILLegalizePass.h" #include "DirectX.h" +#include "llvm/ADT/APInt.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -510,6 +512,55 @@ static void updateFnegToFsub(Instruction &I, ToRemove.push_back(&I); } +static void +legalizeGetHighLowi64Bytes(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { + if (auto *BitCast = dyn_cast(&I)) { + if (BitCast->getDestTy() == + FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) && + BitCast->getSrcTy()->isIntegerTy(64)) { + ToRemove.push_back(BitCast); + ReplacedValues[BitCast] = BitCast->getOperand(0); + return; + } + } + + if (auto *Extract = dyn_cast(&I)) { + if (!dyn_cast(Extract->getVectorOperand())) + return; + auto *VecTy = dyn_cast(Extract->getVectorOperandType()); + if (VecTy && VecTy->getElementType()->isIntegerTy(32) && + VecTy->getNumElements() == 2) { + if (auto *Index = dyn_cast(Extract->getIndexOperand())) { + unsigned Idx = Index->getZExtValue(); + IRBuilder<> Builder(&I); + + auto *Replacement = ReplacedValues[Extract->getVectorOperand()]; + assert(Replacement && "The BitCast replacement should have been set " + "before working on ExtractElementInst."); + if (Idx == 0) { + Value *LowBytes = Builder.CreateTrunc( + Replacement, Type::getInt32Ty(I.getContext())); + ReplacedValues[Extract] = LowBytes; + } else { + assert(Idx == 1); + Value *LogicalShiftRight = Builder.CreateLShr( + Replacement, + ConstantInt::get( + Replacement->getType(), + APInt(Replacement->getType()->getIntegerBitWidth(), 32))); + Value *HighBytes = Builder.CreateTrunc( + LogicalShiftRight, Type::getInt32Ty(I.getContext())); + ReplacedValues[Extract] = HighBytes; + } + ToRemove.push_back(Extract); + Extract->replaceAllUsesWith(ReplacedValues[Extract]); + } + } + } +} + namespace { class DXILLegalizationPipeline { @@ -517,33 +568,49 @@ class DXILLegalizationPipeline { DXILLegalizationPipeline() { initializeLegalizationPipeline(); } bool runLegalizationPipeline(Function &F) { + bool MadeChange = false; SmallVector ToRemove; DenseMap ReplacedValues; - for (auto &I : instructions(F)) { - for (auto &LegalizationFn : LegalizationPipeline) - LegalizationFn(I, ToRemove, ReplacedValues); - } + for (int Stage = 0; Stage < NumStages; ++Stage) { + ToRemove.clear(); + ReplacedValues.clear(); + for (auto &I : instructions(F)) { + for (auto &LegalizationFn : LegalizationPipeline[Stage]) + LegalizationFn(I, ToRemove, ReplacedValues); + } - for (auto *Inst : reverse(ToRemove)) - Inst->eraseFromParent(); + for (auto *Inst : reverse(ToRemove)) + Inst->eraseFromParent(); - return !ToRemove.empty(); + MadeChange |= !ToRemove.empty(); + } + return MadeChange; } private: - SmallVector< + enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages }; + + using LegalizationFnTy = std::function &, - DenseMap &)>> - LegalizationPipeline; + DenseMap &)>; + + SmallVector LegalizationPipeline[NumStages]; void initializeLegalizationPipeline() { - LegalizationPipeline.push_back(upcastI8AllocasAndUses); - LegalizationPipeline.push_back(fixI8UseChain); - LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); - LegalizationPipeline.push_back(legalizeFreeze); - LegalizationPipeline.push_back(legalizeMemCpy); - LegalizationPipeline.push_back(removeMemSet); - LegalizationPipeline.push_back(updateFnegToFsub); + LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses); + LegalizationPipeline[Stage1].push_back(fixI8UseChain); + LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes); + LegalizationPipeline[Stage1].push_back(legalizeFreeze); + LegalizationPipeline[Stage1].push_back(legalizeMemCpy); + LegalizationPipeline[Stage1].push_back(removeMemSet); + LegalizationPipeline[Stage1].push_back(updateFnegToFsub); + // Note: legalizeGetHighLowi64Bytes and + // downcastI64toI32InsertExtractElements both modify extractelement, so they + // must run staggered stages. legalizeGetHighLowi64Bytes runs first b\c it + // removes extractelements, reducing the number that + // downcastI64toI32InsertExtractElements needs to handle. + LegalizationPipeline[Stage2].push_back( + downcastI64toI32InsertExtractElements); } }; diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-split.ll b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-split.ll new file mode 100644 index 0000000000000..17fd3bf54acda --- /dev/null +++ b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-split.ll @@ -0,0 +1,18 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +define void @split_via_extract(i64 noundef %a) { +; CHECK-LABEL: define void @split_via_extract( +; CHECK-SAME: i64 noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[A]] to i32 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A]], 32 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32 +; CHECK-NEXT: ret void +; +entry: + %vecA = bitcast i64 %a to <2 x i32> + %low = extractelement <2 x i32> %vecA, i32 0 ; low 32 bits + %high = extractelement <2 x i32> %vecA, i32 1 ; high 32 bits + ret void +}