Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

GOp->replaceAllUsesWith(NewGEP);

if (auto *CE = dyn_cast<ConstantExpr>(GOp))
CE->destroyConstant();
else if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little confused what is going on here. In the diff it looks like we always did OldGEPI->eraseFromParent(); was this previously defined somewhere above? Is the diff off? Why do we have to dyn_cast GOp? Shouldn't that already be a GEP?

Copy link
Contributor Author

@Icohedron Icohedron Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GOp is indeed a GEPOperator, but it could be a ConstantExpr GEP or an Instruction.
The original implementation here checked if GOp was a ConstantExpr or an Instruction and then called destroyConstant or eraseFromParent respectively.
The change/diff is that destroyConstant is no longer called on GOp if it is a ConstantExpr. But eraseFromParent will still be called if GOp is an Instruction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see this is clearer when you don't combine the +/- diffs

OldGEPI->eraseFromParent();

return true;
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ static bool finalizeLinkage(Module &M) {

// Convert private globals and external globals with no usage to internal
// linkage.
for (GlobalVariable &GV : M.globals())
for (GlobalVariable &GV : M.globals()) {
GV.removeDeadConstantUsers();
if (GV.hasPrivateLinkage() || (GV.hasExternalLinkage() && GV.use_empty())) {
GV.setLinkage(GlobalValue::InternalLinkage);
MadeChange = true;
}
}

SmallVector<Function *> Funcs;

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2113,7 +2113,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
}
break;
case Instruction::GetElementPtr: {
Code = bitc::CST_CODE_CE_GEP;
Code = bitc::CST_CODE_CE_GEP_OLD;
const auto *GO = cast<GEPOperator>(C);
if (GO->isInBounds())
Code = bitc::CST_CODE_CE_INBOUNDS_GEP;
Expand Down
27 changes: 22 additions & 5 deletions llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,39 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"

using namespace llvm;
using namespace llvm::dxil;

namespace {

Type *classifyFunctionType(const Function &F, PointerTypeMap &Map);

// Classifies the type of the value passed in by walking the value's users to
// find a typed instruction to materialize a type from.
Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
assert(V->getType()->isPointerTy() &&
"classifyPointerType called with non-pointer");

// A CallInst will trigger this case, and we want to classify its Function
// operand as a Function rather than a generic Value.
if (const Function *F = dyn_cast<Function>(V))
return classifyFunctionType(*F, Map);

// There can potentially be dead constants hanging off of the globals we do
// not want to deal with. So we remove them here.
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
GV->removeDeadConstantUsers();

auto It = Map.find(V);
if (It != Map.end())
return It->second;

Type *PointeeTy = nullptr;
if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
if (!Inst->getResultElementType()->isPointerTy())
PointeeTy = Inst->getResultElementType();
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
if (!GEP->getResultElementType()->isPointerTy())
PointeeTy = GEP->getResultElementType();
} else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
PointeeTy = Inst->getAllocatedType();
} else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
Expand All @@ -49,8 +63,8 @@ Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
// When store value is ptr type, cannot get more type info.
if (NewPointeeTy->isPointerTy())
continue;
} else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
NewPointeeTy = Inst->getSourceElementType();
} else if (const auto *GEP = dyn_cast<GEPOperator>(User)) {
NewPointeeTy = GEP->getSourceElementType();
}
if (NewPointeeTy) {
// HLSL doesn't support pointers, so it is unlikely to get more than one
Expand Down Expand Up @@ -204,6 +218,9 @@ PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
for (const auto &I : B) {
if (I.getType()->isPointerTy())
classifyPointerType(&I, Map);
for (const auto &O : I.operands())
if (O.get()->getType()->isPointerTy())
classifyPointerType(O.get(), Map);
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions llvm/test/tools/dxil-dis/constantexpr-gep.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; RUN: llc --filetype=obj %s -o - | dxil-dis -o - | FileCheck %s
target triple = "dxil-unknown-shadermodel6.7-library"

; CHECK: [[GLOBAL:@.*]] = unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4
@g = local_unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4

define i32 @fn() #0 {
; CHECK-LABEL: define i32 @fn()
; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 1), align 4
; CHECK-NEXT: ret i32 [[LOAD]]
;
%gep = getelementptr [10 x i32], ptr addrspace(3) @g, i32 0, i32 1
%ld = load i32, ptr addrspace(3) %gep, align 4
ret i32 %ld
}

define i32 @fn2() #0 {
; CHECK-LABEL: define i32 @fn2()
; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 2), align 4
; CHECK-NEXT: ret i32 [[LOAD]]
;
%ld = load i32, ptr addrspace(3) getelementptr ([10 x i32], ptr addrspace(3) @g, i32 0, i32 2), align 4
ret i32 %ld
}

attributes #0 = { "hlsl.export" }
28 changes: 28 additions & 0 deletions llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "DirectXIRPasses/PointerTypeAnalysis.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -123,6 +124,33 @@ TEST(PointerTypeAnalysis, DiscoverGEP) {
EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}

TEST(PointerTypeAnalysis, DiscoverConstantExprGEP) {
StringRef Assembly = R"(
@g = internal global [10 x i32] zeroinitializer
define i32 @test() {
%i = load i32, ptr getelementptr ([10 x i32], ptr @g, i64 0, i64 1)
ret i32 %i
}
)";

LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";

PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 3u);
Type *I32Ty = Type::getInt32Ty(Context);
Type *I32Ptr = TypedPointerType::get(I32Ty, 0);
Type *I32ArrPtr = TypedPointerType::get(ArrayType::get(I32Ty, 10), 0);
Type *FnTy = FunctionType::get(I32Ty, {}, false);

EXPECT_THAT(Map, Contains(Pair(IsA<GlobalVariable>(), I32ArrPtr)));
EXPECT_THAT(Map,
Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
EXPECT_THAT(Map, Contains(Pair(IsA<ConstantExpr>(), I32Ptr)));
}

TEST(PointerTypeAnalysis, TraceIndirect) {
StringRef Assembly = R"(
define i64 @test(ptr %p) {
Expand Down