From 4cbaf28e9c35cd0dda35a7eb0e6ff4189bc3b4df Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Fri, 25 Apr 2025 16:51:23 -0400 Subject: [PATCH 1/3] [DirectX] Legalize i8 allocas --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 96 +++++++++++++++++-- .../CodeGen/DirectX/legalize-i8-alloca.ll | 53 ++++++++++ 2 files changed, 140 insertions(+), 9 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index b62ff4c52f70cf..f4e443543c7282 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include @@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I, ToRemove.push_back(FI); } -static void fixI8TruncUseChain(Instruction &I, - SmallVectorImpl &ToRemove, - DenseMap &ReplacedValues) { +static void fixI8UseChain(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { auto ProcessOperands = [&](SmallVector &NewOperands) { Type *InstrType = IntegerType::get(I.getContext(), 32); for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { Value *Op = I.getOperand(OpIdx); - if (ReplacedValues.count(Op)) + if (ReplacedValues.count(Op) && + ReplacedValues[Op]->getType()->isIntegerTy()) InstrType = ReplacedValues[Op]->getType(); } @@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I, } } + if (auto *Store = dyn_cast(&I)) { + if (!Store->getValueOperand()->getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]); + ReplacedValues[Store] = NewStore; + ToRemove.push_back(Store); + return; + } + + if (auto *Load = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Type *ElementType = NewOperands[0]->getType(); + if (auto *AI = dyn_cast(NewOperands[0])) + ElementType = AI->getAllocatedType(); + LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]); + ReplacedValues[Load] = NewLoad; + ToRemove.push_back(Load); + return; + } + if (auto *BO = dyn_cast(&I)) { if (!I.getType()->isIntegerTy(8)) return; @@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I, Value *NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); if (auto *OBO = dyn_cast(&I)) { - if (OBO->hasNoSignedWrap()) - cast(NewInst)->setHasNoSignedWrap(); - if (OBO->hasNoUnsignedWrap()) - cast(NewInst)->setHasNoUnsignedWrap(); + auto *NewBO = dyn_cast(NewInst); + if (NewBO && OBO->hasNoSignedWrap()) + NewBO->setHasNoSignedWrap(); + if (NewBO && OBO->hasNoUnsignedWrap()) + NewBO->setHasNoUnsignedWrap(); } ReplacedValues[BO] = NewInst; ToRemove.push_back(BO); return; } + if (auto *Sel = dyn_cast(&I)) { + if (!I.getType()->isIntegerTy(8)) + return; + SmallVector NewOperands; + ProcessOperands(NewOperands); + Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1], + NewOperands[2]); + ReplacedValues[Sel] = NewInst; + ToRemove.push_back(Sel); + return; + } + if (auto *Cmp = dyn_cast(&I)) { if (!Cmp->getOperand(0)->getType()->isIntegerTy(8)) return; @@ -105,6 +145,7 @@ static void fixI8TruncUseChain(Instruction &I, } if (auto *Cast = dyn_cast(&I)) { + if (Cast->getSrcTy()->isIntegerTy(8)) { ToRemove.push_back(Cast); Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); @@ -112,6 +153,42 @@ static void fixI8TruncUseChain(Instruction &I, } } +static void upcastI8AllocasAndUses(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &ReplacedValues) { + auto *AI = dyn_cast(&I); + if (!AI || !AI->getAllocatedType()->isIntegerTy(8)) + return; + + std::optional TargetType; + bool Conflict = false; + for (User *U : AI->users()) { + auto *Load = dyn_cast(U); + if (!Load) + continue; + for (User *LU : Load->users()) { + auto *Cast = dyn_cast(LU); + if (!Cast) + continue; + Type *T = Cast->getType(); + if (!TargetType) + TargetType = T; + + if (TargetType.value() != T) { + Conflict = true; + break; + } + } + } + if (!TargetType || Conflict) + return; + + IRBuilder<> Builder(AI); + AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value()); + ReplacedValues[AI] = NewAlloca; + ToRemove.push_back(AI); +} + static void downcastI64toI32InsertExtractElements(Instruction &I, SmallVectorImpl &ToRemove, @@ -178,7 +255,8 @@ class DXILLegalizationPipeline { LegalizationPipeline; void initializeLegalizationPipeline() { - LegalizationPipeline.push_back(fixI8TruncUseChain); + LegalizationPipeline.push_back(upcastI8AllocasAndUses); + LegalizationPipeline.push_back(fixI8UseChain); LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); LegalizationPipeline.push_back(legalizeFreeze); } diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll new file mode 100644 index 00000000000000..a34b9be300a381 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll @@ -0,0 +1,53 @@ +; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +define void @const_i8_store() { + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + store i8 1, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void @const_add_i8_store() { + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %add_i8 = add nsw i8 3, 1 + store i8 %add_i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void @var_i8_store(i1 %cmp.i8) { + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void +} + +define void @conflicting_cast(i1 %cmp.i8) { + %accum.i.flat = alloca [2 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i16 + %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0 + store i16 %z, ptr %gep1, align 2 + %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1 + store i16 %z, ptr %gep2, align 2 + %z2 = zext i8 %i8.load to i32 + %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1 + store i32 %z2, ptr %gep3, align 4 + ret void +} \ No newline at end of file From 3fcad1f755c0f42b71773ac5c4ac311da762895f Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Fri, 25 Apr 2025 17:14:05 -0400 Subject: [PATCH 2/3] instead of detecting the conflicts lets pick the smallest value for the alloca then keep the cast but change the input type. --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 46 ++++--- .../CodeGen/DirectX/legalize-i8-alloca.ll | 128 ++++++++++++------ 2 files changed, 116 insertions(+), 58 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index f4e443543c7282..b7b209fcecbc92 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -145,11 +145,23 @@ static void fixI8UseChain(Instruction &I, } if (auto *Cast = dyn_cast(&I)) { - - if (Cast->getSrcTy()->isIntegerTy(8)) { - ToRemove.push_back(Cast); - Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); + if (!Cast->getSrcTy()->isIntegerTy(8)) + return; + + ToRemove.push_back(Cast); + auto* Replacement =ReplacedValues[Cast->getOperand(0)]; + if (Cast->getType() == Replacement->getType()) { + Cast->replaceAllUsesWith(Replacement); + return; } + Value* AdjustedCast = nullptr; + if (Cast->getOpcode() == Instruction::ZExt) + AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType()); + if (Cast->getOpcode() == Instruction::SExt) + AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType()); + + if(AdjustedCast) + Cast->replaceAllUsesWith(AdjustedCast); } } @@ -160,8 +172,9 @@ static void upcastI8AllocasAndUses(Instruction &I, if (!AI || !AI->getAllocatedType()->isIntegerTy(8)) return; - std::optional TargetType; - bool Conflict = false; + Type *SmallestType = nullptr; + + // Gather all cast targets for (User *U : AI->users()) { auto *Load = dyn_cast(U); if (!Load) @@ -170,21 +183,20 @@ static void upcastI8AllocasAndUses(Instruction &I, auto *Cast = dyn_cast(LU); if (!Cast) continue; - Type *T = Cast->getType(); - if (!TargetType) - TargetType = T; - - if (TargetType.value() != T) { - Conflict = true; - break; - } + Type *Ty = Cast->getType(); + if (!SmallestType || + Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits()) + SmallestType = Ty; } } - if (!TargetType || Conflict) - return; + if (!SmallestType) + return; // no valid casts found + + // Replace alloca IRBuilder<> Builder(AI); - AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value()); + auto *NewAlloca = + Builder.CreateAlloca(SmallestType); ReplacedValues[AI] = NewAlloca; ToRemove.push_back(AI); } diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll index a34b9be300a381..529a69fca5d342 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll @@ -1,53 +1,99 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s define void @const_i8_store() { - %accum.i.flat = alloca [1 x i32], align 4 - %i = alloca i8, align 4 - store i8 1, ptr %i - %i8.load = load i8, ptr %i - %z = zext i8 %i8.load to i32 - %gep = getelementptr i32, ptr %accum.i.flat, i32 0 - store i32 %z, ptr %gep, align 4 - ret void +; CHECK-LABEL: define void @const_i8_store() { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + store i8 1, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void } define void @const_add_i8_store() { - %accum.i.flat = alloca [1 x i32], align 4 - %i = alloca i8, align 4 - %add_i8 = add nsw i8 3, 1 - store i8 %add_i8, ptr %i - %i8.load = load i8, ptr %i - %z = zext i8 %i8.load to i32 - %gep = getelementptr i32, ptr %accum.i.flat, i32 0 - store i32 %z, ptr %gep, align 4 - ret void +; CHECK-LABEL: define void @const_add_i8_store() { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %add_i8 = add nsw i8 3, 1 + store i8 %add_i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void } define void @var_i8_store(i1 %cmp.i8) { - %accum.i.flat = alloca [1 x i32], align 4 - %i = alloca i8, align 4 - %select.i8 = select i1 %cmp.i8, i8 1, i8 2 - store i8 %select.i8, ptr %i - %i8.load = load i8, ptr %i - %z = zext i8 %i8.load to i32 - %gep = getelementptr i32, ptr %accum.i.flat, i32 0 - store i32 %z, ptr %gep, align 4 - ret void +; CHECK-LABEL: define void @var_i8_store( +; CHECK-SAME: i1 [[CMP_I8:%.*]]) { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4 +; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [1 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i32 + %gep = getelementptr i32, ptr %accum.i.flat, i32 0 + store i32 %z, ptr %gep, align 4 + ret void } define void @conflicting_cast(i1 %cmp.i8) { - %accum.i.flat = alloca [2 x i32], align 4 - %i = alloca i8, align 4 - %select.i8 = select i1 %cmp.i8, i8 1, i8 2 - store i8 %select.i8, ptr %i - %i8.load = load i8, ptr %i - %z = zext i8 %i8.load to i16 - %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0 - store i16 %z, ptr %gep1, align 2 - %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1 - store i16 %z, ptr %gep2, align 2 - %z2 = zext i8 %i8.load to i32 - %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1 - store i32 %z2, ptr %gep3, align 4 - ret void -} \ No newline at end of file +; CHECK-LABEL: define void @conflicting_cast( +; CHECK-SAME: i1 [[CMP_I8:%.*]]) { +; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2 +; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0 +; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2 +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1 +; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2 +; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32 +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1 +; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4 +; CHECK-NEXT: ret void +; + %accum.i.flat = alloca [2 x i32], align 4 + %i = alloca i8, align 4 + %select.i8 = select i1 %cmp.i8, i8 1, i8 2 + store i8 %select.i8, ptr %i + %i8.load = load i8, ptr %i + %z = zext i8 %i8.load to i16 + %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0 + store i16 %z, ptr %gep1, align 2 + %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1 + store i16 %z, ptr %gep2, align 2 + %z2 = zext i8 %i8.load to i32 + %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1 + store i32 %z2, ptr %gep3, align 4 + ret void +} From 6e3a11db8f8b9954839e4f4852b0946d5bf6915a Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Fri, 25 Apr 2025 17:22:20 -0400 Subject: [PATCH 3/3] fix formatting --- llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index b7b209fcecbc92..7da5a71ab729b9 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -147,20 +147,20 @@ static void fixI8UseChain(Instruction &I, if (auto *Cast = dyn_cast(&I)) { if (!Cast->getSrcTy()->isIntegerTy(8)) return; - + ToRemove.push_back(Cast); - auto* Replacement =ReplacedValues[Cast->getOperand(0)]; + auto *Replacement = ReplacedValues[Cast->getOperand(0)]; if (Cast->getType() == Replacement->getType()) { Cast->replaceAllUsesWith(Replacement); return; } - Value* AdjustedCast = nullptr; + Value *AdjustedCast = nullptr; if (Cast->getOpcode() == Instruction::ZExt) AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType()); if (Cast->getOpcode() == Instruction::SExt) AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType()); - - if(AdjustedCast) + + if (AdjustedCast) Cast->replaceAllUsesWith(AdjustedCast); } } @@ -195,8 +195,7 @@ static void upcastI8AllocasAndUses(Instruction &I, // Replace alloca IRBuilder<> Builder(AI); - auto *NewAlloca = - Builder.CreateAlloca(SmallestType); + auto *NewAlloca = Builder.CreateAlloca(SmallestType); ReplacedValues[AI] = NewAlloca; ToRemove.push_back(AI); }