diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 23883c936a20d..3e21f3c109456 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -87,20 +87,63 @@ static void fixI8UseChain(Instruction &I, return; } - if (auto *Load = dyn_cast(&I)) { - if (!I.getType()->isIntegerTy(8)) - return; + if (auto *Load = dyn_cast(&I); + Load && I.getType()->isIntegerTy(8)) { SmallVector NewOperands; ProcessOperands(NewOperands); Type *ElementType = NewOperands[0]->getType(); if (auto *AI = dyn_cast(NewOperands[0])) ElementType = AI->getAllocatedType(); + if (auto *GEP = dyn_cast(NewOperands[0])) { + ElementType = GEP->getSourceElementType(); + if (ElementType->isArrayTy()) + ElementType = ElementType->getArrayElementType(); + } LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]); ReplacedValues[Load] = NewLoad; ToRemove.push_back(Load); return; } + if (auto *Load = dyn_cast(&I); + Load && isa(Load->getPointerOperand())) { + auto *CE = dyn_cast(Load->getPointerOperand()); + if (!(CE->getOpcode() == Instruction::GetElementPtr)) + return; + auto *GEP = dyn_cast(CE); + if (!GEP->getSourceElementType()->isIntegerTy(8)) + return; + + Type *ElementType = Load->getType(); + ConstantInt *Offset = dyn_cast(GEP->getOperand(1)); + uint32_t ByteOffset = Offset->getZExtValue(); + uint32_t ElemSize = Load->getDataLayout().getTypeAllocSize(ElementType); + uint32_t Index = ByteOffset / ElemSize; + + Value *PtrOperand = GEP->getPointerOperand(); + Type *GEPType = GEP->getPointerOperandType(); + + if (auto *GV = dyn_cast(PtrOperand)) + GEPType = GV->getValueType(); + if (auto *AI = dyn_cast(PtrOperand)) + GEPType = AI->getAllocatedType(); + + if (auto *ArrTy = dyn_cast(GEPType)) + GEPType = ArrTy; + else + GEPType = ArrayType::get(ElementType, 1); // its a scalar + + Value *NewGEP = Builder.CreateGEP( + GEPType, PtrOperand, {Builder.getInt32(0), Builder.getInt32(Index)}, + GEP->getName(), GEP->getNoWrapFlags()); + + LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewGEP); + ReplacedValues[Load] = NewLoad; + Load->replaceAllUsesWith(NewLoad); + ToRemove.push_back(Load); + return; + } + if (auto *BO = dyn_cast(&I)) { if (!I.getType()->isIntegerTy(8)) return; @@ -155,6 +198,7 @@ static void fixI8UseChain(Instruction &I, Cast->replaceAllUsesWith(Replacement); return; } + Value *AdjustedCast = nullptr; if (Cast->getOpcode() == Instruction::ZExt) AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType()); @@ -164,6 +208,45 @@ static void fixI8UseChain(Instruction &I, if (AdjustedCast) Cast->replaceAllUsesWith(AdjustedCast); } + if (auto *GEP = dyn_cast(&I)) { + if (!GEP->getType()->isPointerTy() || + !GEP->getSourceElementType()->isIntegerTy(8)) + return; + + Value *BasePtr = GEP->getPointerOperand(); + if (ReplacedValues.count(BasePtr)) + BasePtr = ReplacedValues[BasePtr]; + + Type *ElementType = BasePtr->getType(); + + if (auto *AI = dyn_cast(BasePtr)) + ElementType = AI->getAllocatedType(); + if (auto *GV = dyn_cast(BasePtr)) + ElementType = GV->getValueType(); + + Type *GEPType = ElementType; + if (auto *ArrTy = dyn_cast(ElementType)) + ElementType = ArrTy->getArrayElementType(); + else + GEPType = ArrayType::get(ElementType, 1); // its a scalar + + ConstantInt *Offset = dyn_cast(GEP->getOperand(1)); + // Note: i8 to i32 offset conversion without emitting IR requires constant + // ints. Since offset conversion is common, we can safely assume Offset is + // always a ConstantInt, so no need to have a conditional bail out on + // nullptr, instead assert this is the case. + assert(Offset && "Offset is expected to be a ConstantInt"); + uint32_t ByteOffset = Offset->getZExtValue(); + uint32_t ElemSize = GEP->getDataLayout().getTypeAllocSize(ElementType); + assert(ElemSize > 0 && "ElementSize must be set"); + uint32_t Index = ByteOffset / ElemSize; + Value *NewGEP = Builder.CreateGEP( + GEPType, BasePtr, {Builder.getInt32(0), Builder.getInt32(Index)}, + GEP->getName(), GEP->getNoWrapFlags()); + ReplacedValues[GEP] = NewGEP; + GEP->replaceAllUsesWith(NewGEP); + ToRemove.push_back(GEP); + } } static void upcastI8AllocasAndUses(Instruction &I, @@ -175,15 +258,12 @@ static void upcastI8AllocasAndUses(Instruction &I, Type *SmallestType = nullptr; - for (User *U : AI->users()) { - auto *Load = dyn_cast(U); - if (!Load) - continue; + auto ProcessLoad = [&](LoadInst *Load) { for (User *LU : Load->users()) { Type *Ty = nullptr; - if (auto *Cast = dyn_cast(LU)) + if (CastInst *Cast = dyn_cast(LU)) Ty = Cast->getType(); - if (CallInst *CI = dyn_cast(LU)) { + else if (CallInst *CI = dyn_cast(LU)) { if (CI->getIntrinsicID() == Intrinsic::memset) Ty = Type::getInt32Ty(CI->getContext()); } @@ -195,6 +275,17 @@ static void upcastI8AllocasAndUses(Instruction &I, Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits()) SmallestType = Ty; } + }; + + for (User *U : AI->users()) { + if (auto *Load = dyn_cast(U)) + ProcessLoad(Load); + else if (auto *GEP = dyn_cast(U)) { + for (User *GU : GEP->users()) { + if (auto *Load = dyn_cast(GU)) + ProcessLoad(Load); + } + } } if (!SmallestType) diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll index 2602be778cd86..f8aa2c5ecd932 100644 --- a/llvm/test/CodeGen/DirectX/legalize-i8.ll +++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll @@ -106,3 +106,75 @@ define i32 @all_imm() { %2 = sext i8 %1 to i32 ret i32 %2 } + +define i32 @scalar_i8_geps() { + ; CHECK-LABEL: define i32 @scalar_i8_geps( + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca i32, align 4 + ; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds nuw [1 x i32], ptr [[ALLOCA]], i32 0, i32 0 + ; CHECK: [[LOAD:%.*]] = load i32, ptr [[GEP]], align 4 + ; CHECK-NEXT: ret i32 [[LOAD]] + %1 = alloca i8, align 4 + %2 = getelementptr inbounds nuw i8, ptr %1, i32 0 + %3 = load i8, ptr %2 + %4 = sext i8 %3 to i32 + ret i32 %4 +} + +define i32 @i8_geps_index0() { + ; CHECK-LABEL: define i32 @i8_geps_index0( + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [2 x i32], align 8 + ; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds nuw [2 x i32], ptr [[ALLOCA]], i32 0, i32 0 + ; CHECK: [[LOAD:%.*]] = load i32, ptr [[GEP]], align 4 + ; CHECK-NEXT: ret i32 [[LOAD]] + %1 = alloca [2 x i32], align 8 + %2 = getelementptr inbounds nuw i8, ptr %1, i32 0 + %3 = load i8, ptr %2 + %4 = sext i8 %3 to i32 + ret i32 %4 +} + +define i32 @i8_geps_index1() { + ; CHECK-LABEL: define i32 @i8_geps_index1( + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [2 x i32], align 8 + ; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds nuw [2 x i32], ptr [[ALLOCA]], i32 0, i32 1 + ; CHECK: [[LOAD:%.*]] = load i32, ptr [[GEP]], align 4 + ; CHECK-NEXT: ret i32 [[LOAD]] + %1 = alloca [2 x i32], align 8 + %2 = getelementptr inbounds nuw i8, ptr %1, i32 4 + %3 = load i8, ptr %2 + %4 = sext i8 %3 to i32 + ret i32 %4 +} + +define i32 @i8_gep_store() { + ; CHECK-LABEL: define i32 @i8_gep_store( + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [2 x i32], align 8 + ; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds nuw [2 x i32], ptr [[ALLOCA]], i32 0, i32 1 + ; CHECK-NEXT: store i32 1, ptr [[GEP]], align 4 + ; CHECK: [[LOAD:%.*]] = load i32, ptr [[GEP]], align 4 + ; CHECK-NEXT: ret i32 [[LOAD]] + %1 = alloca [2 x i32], align 8 + %2 = getelementptr inbounds nuw i8, ptr %1, i32 4 + store i8 1, ptr %2 + %3 = load i8, ptr %2 + %4 = sext i8 %3 to i32 + ret i32 %4 +} + +@g = local_unnamed_addr addrspace(3) global [2 x float] zeroinitializer, align 4 +define float @i8_gep_global_index() { + ; CHECK-LABEL: define float @i8_gep_global_index( + ; CHECK-NEXT: [[LOAD:%.*]] = load float, ptr addrspace(3) getelementptr inbounds nuw ([2 x float], ptr addrspace(3) @g, i32 0, i32 1), align 4 + ; CHECK-NEXT: ret float [[LOAD]] + %1 = getelementptr inbounds nuw i8, ptr addrspace(3) @g, i32 4 + %2 = load float, ptr addrspace(3) %1, align 4 + ret float %2 +} + +define float @i8_gep_global_constexpr() { + ; CHECK-LABEL: define float @i8_gep_global_constexpr( + ; CHECK-NEXT: [[LOAD:%.*]] = load float, ptr addrspace(3) getelementptr inbounds nuw ([2 x float], ptr addrspace(3) @g, i32 0, i32 1), align 4 + ; CHECK-NEXT: ret float [[LOAD]] + %1 = load float, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g, i32 4), align 4 + ret float %1 +}