diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index feecfc0880e25..d507d71b99fc9 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -343,9 +343,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { GOp->replaceAllUsesWith(NewGEP); - if (auto *CE = dyn_cast(GOp)) - CE->destroyConstant(); - else if (auto *OldGEPI = dyn_cast(GOp)) + if (auto *OldGEPI = dyn_cast(GOp)) OldGEPI->eraseFromParent(); return true; diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp index 13e3408815bba..aa16e795dc768 100644 --- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp +++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp @@ -22,11 +22,13 @@ static bool finalizeLinkage(Module &M) { // Convert private globals and external globals with no usage to internal // linkage. - for (GlobalVariable &GV : M.globals()) + for (GlobalVariable &GV : M.globals()) { + GV.removeDeadConstantUsers(); if (GV.hasPrivateLinkage() || (GV.hasExternalLinkage() && GV.use_empty())) { GV.setLinkage(GlobalValue::InternalLinkage); MadeChange = true; } + } SmallVector Funcs; diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp index 1d79c3018439e..bc1a3a7995bda 100644 --- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp +++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp @@ -2113,7 +2113,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, } break; case Instruction::GetElementPtr: { - Code = bitc::CST_CODE_CE_GEP; + Code = bitc::CST_CODE_CE_GEP_OLD; const auto *GO = cast(C); if (GO->isInBounds()) Code = bitc::CST_CODE_CE_INBOUNDS_GEP; diff --git a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp index f99bb4f4eaee1..c2e139edc6bd1 100644 --- a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp +++ b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp @@ -15,25 +15,39 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" using namespace llvm; using namespace llvm::dxil; namespace { +Type *classifyFunctionType(const Function &F, PointerTypeMap &Map); + // Classifies the type of the value passed in by walking the value's users to // find a typed instruction to materialize a type from. Type *classifyPointerType(const Value *V, PointerTypeMap &Map) { assert(V->getType()->isPointerTy() && "classifyPointerType called with non-pointer"); + + // A CallInst will trigger this case, and we want to classify its Function + // operand as a Function rather than a generic Value. + if (const Function *F = dyn_cast(V)) + return classifyFunctionType(*F, Map); + + // There can potentially be dead constants hanging off of the globals we do + // not want to deal with. So we remove them here. + if (const GlobalVariable *GV = dyn_cast(V)) + GV->removeDeadConstantUsers(); + auto It = Map.find(V); if (It != Map.end()) return It->second; Type *PointeeTy = nullptr; - if (auto *Inst = dyn_cast(V)) { - if (!Inst->getResultElementType()->isPointerTy()) - PointeeTy = Inst->getResultElementType(); + if (auto *GEP = dyn_cast(V)) { + if (!GEP->getResultElementType()->isPointerTy()) + PointeeTy = GEP->getResultElementType(); } else if (auto *Inst = dyn_cast(V)) { PointeeTy = Inst->getAllocatedType(); } else if (auto *GV = dyn_cast(V)) { @@ -49,8 +63,8 @@ Type *classifyPointerType(const Value *V, PointerTypeMap &Map) { // When store value is ptr type, cannot get more type info. if (NewPointeeTy->isPointerTy()) continue; - } else if (const auto *Inst = dyn_cast(User)) { - NewPointeeTy = Inst->getSourceElementType(); + } else if (const auto *GEP = dyn_cast(User)) { + NewPointeeTy = GEP->getSourceElementType(); } if (NewPointeeTy) { // HLSL doesn't support pointers, so it is unlikely to get more than one @@ -204,6 +218,9 @@ PointerTypeMap PointerTypeAnalysis::run(const Module &M) { for (const auto &I : B) { if (I.getType()->isPointerTy()) classifyPointerType(&I, Map); + for (const auto &O : I.operands()) + if (O.get()->getType()->isPointerTy()) + classifyPointerType(O.get(), Map); } } } diff --git a/llvm/test/tools/dxil-dis/constantexpr-gep.ll b/llvm/test/tools/dxil-dis/constantexpr-gep.ll new file mode 100644 index 0000000000000..59251474f1a4b --- /dev/null +++ b/llvm/test/tools/dxil-dis/constantexpr-gep.ll @@ -0,0 +1,35 @@ +; RUN: llc --filetype=obj %s -o - | dxil-dis -o - | FileCheck %s +target triple = "dxil-unknown-shadermodel6.7-library" + +; CHECK: [[GLOBAL:@.*]] = unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4 +@g = local_unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4 + +define i32 @fn() #0 { +; CHECK-LABEL: define i32 @fn() +; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 1), align 4 +; CHECK-NEXT: ret i32 [[LOAD]] +; + %gep = getelementptr [10 x i32], ptr addrspace(3) @g, i32 0, i32 1 + %ld = load i32, ptr addrspace(3) %gep, align 4 + ret i32 %ld +} + +define i32 @fn2() #0 { +; CHECK-LABEL: define i32 @fn2() +; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 2), align 4 +; CHECK-NEXT: ret i32 [[LOAD]] +; + %ld = load i32, ptr addrspace(3) getelementptr ([10 x i32], ptr addrspace(3) @g, i32 0, i32 2), align 4 + ret i32 %ld +} + +define i32 @fn3() #0 { +; CHECK-LABEL: define i32 @fn3() +; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 3), align 4 +; CHECK-NEXT: ret i32 [[LOAD]] +; + %ld = load i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @g, i32 12), align 4 + ret i32 %ld +} + +attributes #0 = { "hlsl.export" } diff --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp index 9d41e94bb0bae..6ae139e076281 100644 --- a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp +++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp @@ -8,6 +8,7 @@ #include "DirectXIRPasses/PointerTypeAnalysis.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -123,6 +124,33 @@ TEST(PointerTypeAnalysis, DiscoverGEP) { EXPECT_THAT(Map, Contains(Pair(IsA(), I64Ptr))); } +TEST(PointerTypeAnalysis, DiscoverConstantExprGEP) { + StringRef Assembly = R"( + @g = internal global [10 x i32] zeroinitializer + define i32 @test() { + %i = load i32, ptr getelementptr ([10 x i32], ptr @g, i64 0, i64 1) + ret i32 %i + } + )"; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + ASSERT_TRUE(M) << "Bad assembly?"; + + PointerTypeMap Map = PointerTypeAnalysis::run(*M); + ASSERT_EQ(Map.size(), 3u); + Type *I32Ty = Type::getInt32Ty(Context); + Type *I32Ptr = TypedPointerType::get(I32Ty, 0); + Type *I32ArrPtr = TypedPointerType::get(ArrayType::get(I32Ty, 10), 0); + Type *FnTy = FunctionType::get(I32Ty, {}, false); + + EXPECT_THAT(Map, Contains(Pair(IsA(), I32ArrPtr))); + EXPECT_THAT(Map, + Contains(Pair(IsA(), TypedPointerType::get(FnTy, 0)))); + EXPECT_THAT(Map, Contains(Pair(IsA(), I32Ptr))); +} + TEST(PointerTypeAnalysis, TraceIndirect) { StringRef Assembly = R"( define i64 @test(ptr %p) {