diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 06708cec00cec..61c5301ed5051 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -14,11 +14,13 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -127,71 +129,75 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) { } bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { - unsigned NumOperands = LI.getNumOperands(); - for (unsigned I = 0; I < NumOperands; ++I) { - Value *CurrOpperand = LI.getOperand(I); - ConstantExpr *CE = dyn_cast(CurrOpperand); - if (CE && CE->getOpcode() == Instruction::GetElementPtr) { - GetElementPtrInst *OldGEP = - cast(CE->getAsInstruction()); - OldGEP->insertBefore(LI.getIterator()); - IRBuilder<> Builder(&LI); - LoadInst *NewLoad = - Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); - NewLoad->setAlignment(LI.getAlign()); - LI.replaceAllUsesWith(NewLoad); - LI.eraseFromParent(); - visitGetElementPtrInst(*OldGEP); - return true; - } - if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) - LI.setOperand(I, NewGlobal); + Value *PtrOperand = LI.getPointerOperand(); + ConstantExpr *CE = dyn_cast(PtrOperand); + if (CE && CE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); + OldGEP->insertBefore(LI.getIterator()); + IRBuilder<> Builder(&LI); + LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); + NewLoad->setAlignment(LI.getAlign()); + LI.replaceAllUsesWith(NewLoad); + LI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; } + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) + LI.setOperand(LI.getPointerOperandIndex(), NewGlobal); return false; } bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { - unsigned NumOperands = SI.getNumOperands(); - for (unsigned I = 0; I < NumOperands; ++I) { - Value *CurrOpperand = SI.getOperand(I); - ConstantExpr *CE = dyn_cast(CurrOpperand); - if (CE && CE->getOpcode() == Instruction::GetElementPtr) { - GetElementPtrInst *OldGEP = - cast(CE->getAsInstruction()); - OldGEP->insertBefore(SI.getIterator()); - IRBuilder<> Builder(&SI); - StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); - NewStore->setAlignment(SI.getAlign()); - SI.replaceAllUsesWith(NewStore); - SI.eraseFromParent(); - visitGetElementPtrInst(*OldGEP); - return true; - } - if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) - SI.setOperand(I, NewGlobal); + + Value *PtrOperand = SI.getPointerOperand(); + ConstantExpr *CE = dyn_cast(PtrOperand); + if (CE && CE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); + OldGEP->insertBefore(SI.getIterator()); + IRBuilder<> Builder(&SI); + StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); + NewStore->setAlignment(SI.getAlign()); + SI.replaceAllUsesWith(NewStore); + SI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; } + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) + SI.setOperand(SI.getPointerOperandIndex(), NewGlobal); + return false; } bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { - - unsigned NumOperands = GEPI.getNumOperands(); - GlobalVariable *NewGlobal = nullptr; - for (unsigned I = 0; I < NumOperands; ++I) { - Value *CurrOpperand = GEPI.getOperand(I); - NewGlobal = lookupReplacementGlobal(CurrOpperand); - if (NewGlobal) - break; + Value *PtrOperand = GEPI.getPointerOperand(); + Type *OrigGEPType = GEPI.getPointerOperandType(); + Type *NewGEPType = OrigGEPType; + bool NeedsTransform = false; + + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) { + NewGEPType = NewGlobal->getValueType(); + PtrOperand = NewGlobal; + NeedsTransform = true; + } else if (AllocaInst *Alloca = dyn_cast(PtrOperand)) { + Type *AllocatedType = Alloca->getAllocatedType(); + // OrigGEPType might just be a pointer lets make sure + // to add the allocated type so we have a size + if (AllocatedType != OrigGEPType) { + NewGEPType = AllocatedType; + NeedsTransform = true; + } } - if (!NewGlobal) + + // Note: We bail if this isn't a gep touched via alloca or global + // transformations + if (!NeedsTransform) return false; IRBuilder<> Builder(&GEPI); SmallVector Indices(GEPI.indices()); - Value *NewGEP = - Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices, - GEPI.getName(), GEPI.getNoWrapFlags()); + Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices, + GEPI.getName(), GEPI.getNoWrapFlags()); GEPI.replaceAllUsesWith(NewGEP); GEPI.eraseFromParent(); return true; diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll index 4829f3a31791f..b589136d6965c 100644 --- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll +++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll @@ -1,10 +1,25 @@ -; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK -; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK +; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK +; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK ; CHECK-LABEL: alloca_2d__vec_test define void @alloca_2d__vec_test() local_unnamed_addr #2 { ; SCHECK: alloca [2 x [4 x i32]], align 16 ; FCHECK: alloca [8 x i32], align 16 + ; CHECK: ret void %1 = alloca [2 x <4 x i32>], align 16 ret void } + +; CHECK-LABEL: alloca_2d_gep_test +define void @alloca_2d_gep_test() { + ; SCHECK: [[alloca_val:%.*]] = alloca [2 x [2 x i32]], align 16 + ; FCHECK: [[alloca_val:%.*]] = alloca [4 x i32], align 16 + ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]] + ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]] + ; CHECK: ret void + %1 = alloca [2 x <2 x i32>], align 16 + %2 = tail call i32 @llvm.dx.thread.id(i32 0) + %3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2 + ret void +}