// Patch overview (text before the first "diff --git" line is not part of any hunk).
//
// Files touched:
//   * llvm/lib/Target/DirectX/DXILLegalizePass.cpp
//   * llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll (new file)
//
// What the hunks show:
//   * Adds #include "llvm/IR/Instructions.h".
//   * Renames fixI8TruncUseChain to fixI8UseChain. ProcessOperands now only adopts
//     a replaced operand's type when that type is an integer type.
//   * fixI8UseChain gains branches that rebuild i8 stores (Builder.CreateStore),
//     i8 loads (Builder.CreateLoad) and i8 selects (Builder.CreateSelect) from the
//     operands returned by ProcessOperands, record the new value in ReplacedValues
//     and queue the old instruction in ToRemove.
//   * The nsw/nuw flags are now copied only when dyn_cast on the new binary op
//     succeeds (previously an unconditional cast).
//   * Casts whose source is i8 are replaced by the mapped value directly when the
//     types match, otherwise through CreateZExtOrTrunc / CreateSExtOrTrunc.
//   * New upcastI8AllocasAndUses: for an instruction whose allocated type is i8,
//     walks users of the alloca, then users of those users, picks the cast result
//     type with the fewest primitive bits, creates an alloca of that type, maps the
//     old alloca to it in ReplacedValues and queues the old one in ToRemove.
//   * upcastI8AllocasAndUses is registered ahead of fixI8UseChain in
//     initializeLegalizationPipeline().
//
// NOTE(review): this copy of the patch looks extraction-damaged -- line breaks are
//   collapsed and the contents of angle brackets are missing ("#include" with no
//   header, "dyn_cast(&I)", "SmallVectorImpl &ToRemove", "DenseMap &ReplacedValues").
//   The concrete instruction classes are therefore inferred from variable names
//   (Store, Load, Sel, AI, Cast) -- confirm against the upstream commit before use.
// NOTE(review): in the cast branch, Replacement = ReplacedValues[Cast->getOperand(0)]
//   is dereferenced (Replacement->getType()) with no null check; confirm every i8
//   cast source is guaranteed to have a map entry by this point.
// NOTE(review): in the cast branch, ToRemove.push_back(Cast) runs before the opcode
//   checks; for opcodes other than ZExt/SExt with mismatched types AdjustedCast stays
//   nullptr and no replaceAllUsesWith happens -- confirm what the ToRemove consumer
//   does with a cast that still has uses.
// NOTE(review): in the load branch, ElementType defaults to NewOperands[0]->getType(),
//   i.e. the type of the pointer operand itself, and is only corrected when that
//   operand passes the dyn_cast to AI -- confirm loads from non-alloca pointers.
// NOTE(review): CreateStore/CreateLoad are called without forwarding any alignment
//   or other attributes from the original instruction -- confirm this is intended.
// NOTE(review): the conflicting_cast CHECK lines expect
//   "store i32 [[TMP2]], ptr [[TMP1]], align 4" where [[TMP1]] is "alloca i16, align 2",
//   i.e. a 32-bit store into a 16-bit slot -- confirm this is the desired output.
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index b62ff4c52f70c..7da5a71ab729b 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include @@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I, ToRemove.push_back(FI); } -static void fixI8TruncUseChain(Instruction &I, - SmallVectorImpl &ToRemove, - DenseMap &ReplacedValues) { +static void fixI8UseChain(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { auto ProcessOperands = [&](SmallVector &NewOperands) { Type *InstrType = IntegerType::get(I.getContext(), 32); for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); - if (ReplacedValues.count(Op)) + if (ReplacedValues.count(Op) && + ReplacedValues[Op]->getType()->isIntegerTy()) InstrType = ReplacedValues[Op]->getType(); } @@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I, } } + if (auto *Store = dyn_cast(&I)) { + if (!Store->getValueOperand()->getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]); + ReplacedValues[Store] = NewStore; + ToRemove.push_back(Store); + return; + } + + if (auto *Load = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Type *ElementType = NewOperands[0]->getType(); + if (auto *AI = dyn_cast(NewOperands[0])) + ElementType = AI->getAllocatedType(); + LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]); + ReplacedValues[Load] = NewLoad; + ToRemove.push_back(Load); + return; + } + if (auto *BO = dyn_cast(&I)) { if 
(!I.getType()->isIntegerTy(8)) return; @@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I, Value *NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); if (auto *OBO = dyn_cast(&I)) { - if (OBO->hasNoSignedWrap()) - cast(NewInst)->setHasNoSignedWrap(); - if (OBO->hasNoUnsignedWrap()) - cast(NewInst)->setHasNoUnsignedWrap(); + auto *NewBO = dyn_cast(NewInst); + if (NewBO && OBO->hasNoSignedWrap()) + NewBO->setHasNoSignedWrap(); + if (NewBO && OBO->hasNoUnsignedWrap()) + NewBO->setHasNoUnsignedWrap(); } ReplacedValues[BO] = NewInst; ToRemove.push_back(BO); return; } + if (auto *Sel = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1], + NewOperands[2]); + ReplacedValues[Sel] = NewInst; + ToRemove.push_back(Sel); + return; + } + if (auto *Cmp = dyn_cast(&I)) { if (!Cmp->getOperand(0)->getType()->isIntegerTy(8)) return; @@ -105,13 +145,61 @@ static void fixI8TruncUseChain(Instruction &I, } if (auto *Cast = dyn_cast(&I)) { - if (Cast->getSrcTy()->isIntegerTy(8)) { - ToRemove.push_back(Cast); - Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); + if (!Cast->getSrcTy()->isIntegerTy(8)) + return; + + ToRemove.push_back(Cast); + auto *Replacement = ReplacedValues[Cast->getOperand(0)]; + if (Cast->getType() == Replacement->getType()) { + Cast->replaceAllUsesWith(Replacement); + return; } + Value *AdjustedCast = nullptr; + if (Cast->getOpcode() == Instruction::ZExt) + AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType()); + if (Cast->getOpcode() == Instruction::SExt) + AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType()); + + if (AdjustedCast) + Cast->replaceAllUsesWith(AdjustedCast); } } +static void upcastI8AllocasAndUses(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { + auto *AI = dyn_cast(&I); + if (!AI || 
!AI->getAllocatedType()->isIntegerTy(8)) + return; + + Type *SmallestType = nullptr; + + // Gather all cast targets + for (User *U : AI->users()) { + auto *Load = dyn_cast(U); + if (!Load) + continue; + for (User *LU : Load->users()) { + auto *Cast = dyn_cast(LU); + if (!Cast) + continue; + Type *Ty = Cast->getType(); + if (!SmallestType || + Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits()) + SmallestType = Ty; + } + } + + if (!SmallestType) + return; // no valid casts found + + // Replace alloca + IRBuilder<> Builder(AI); + auto *NewAlloca = Builder.CreateAlloca(SmallestType); + ReplacedValues[AI] = NewAlloca; + ToRemove.push_back(AI); +} + static void downcastI64toI32InsertExtractElements(Instruction &I, SmallVectorImpl &ToRemove, @@ -178,7 +266,8 @@ class DXILLegalizationPipeline { LegalizationPipeline; void initializeLegalizationPipeline() { - LegalizationPipeline.push_back(fixI8TruncUseChain); + LegalizationPipeline.push_back(upcastI8AllocasAndUses); + LegalizationPipeline.push_back(fixI8UseChain); LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); LegalizationPipeline.push_back(legalizeFreeze); } diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll new file mode 100644 index 0000000000000..529a69fca5d34 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll @@ -0,0 +1,99 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +define void @const_i8_store() { +; CHECK-LABEL: define void @const_i8_store() { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; 
CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + store i8 1, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void @const_add_i8_store() { +; CHECK-LABEL: define void @const_add_i8_store() { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %add_i8 = add nsw i8 3, 1 + store i8 %add_i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void @var_i8_store(i1 %cmp.i8) { +; CHECK-LABEL: define void @var_i8_store( +; CHECK-SAME: i1 [[CMP_I8:%.*]]) { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void 
@conflicting_cast(i1 %cmp.i8) { +; CHECK-LABEL: define void @conflicting_cast( +; CHECK-SAME: i1 [[CMP_I8:%.*]]) { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2 +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1 +; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2 +; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32 +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1 +; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [2 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i16 + %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0 + store i16 %z, ptr %gep1, align 2 + %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1 + store i16 %z, ptr %gep2, align 2 + %z2 = zext i8 %i8.load to i32 + %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1 + store i32 %z2, ptr %gep3, align 4 + ret void +}