Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
Expand All @@ -190,6 +191,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
Expand All @@ -198,6 +200,8 @@
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "separate-offset-gep"

static cl::opt<bool> DisableSeparateConstOffsetFromGEP(
"disable-separate-const-offset-from-gep", cl::init(false),
cl::desc("Do not separate the constant offset from a GEP instruction"),
Expand Down Expand Up @@ -486,6 +490,39 @@ class SeparateConstOffsetFromGEP {
DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingSubs;
};

/// Helper that rewrites `xor` instructions feeding GEP indices into
/// `or disjoint` instructions whenever the operands can be shown to share no
/// set bits.
///
/// A disjoint `or` behaves like an `add`, so once BaseVal ^ Const has been
/// expressed as BaseVal' | Const' the constant part becomes visible to
/// SeparateConstOffsetFromGEP, which can then split it off the GEP. The
/// pattern commonly shows up when a grid of memory accesses is laid out across
/// a wave and every thread accesses data at several fixed offsets.
class XorToOrDisjointTransformer {
  /// One entry per candidate: the xor instruction and its constant operand.
  using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8>;
  /// Groups candidates by their shared non-constant (first) operand.
  using XorBaseValMap = DenseMap<Value *, XorOpList>;

public:
  XorToOrDisjointTransformer(Function &F, DominatorTree &DT,
                             const DataLayout &DL)
      : F(F), DT(DT), DL(DL) {}

  /// Collects the candidate xors of the function and processes them group by
  /// group. Returns true if the function was modified.
  bool run();

private:
  /// Returns true if \p V is used by at least one GetElementPtr instruction.
  bool hasGEPUser(const Value *V) const;

  /// Handles every xor that shares the non-constant operand
  /// \p OriginalBaseVal. Returns true if processing this group modified the
  /// function.
  bool processXorGroup(Value *OriginalBaseVal, XorOpList &XorsInGroup);

  Function &F;
  DominatorTree &DT;
  const DataLayout &DL;
  /// Candidate xors keyed by their common operand.
  XorBaseValMap XorGroups;
};

} // end anonymous namespace

char SeparateConstOffsetFromGEPLegacyPass::ID = 0;
Expand Down Expand Up @@ -1162,6 +1199,165 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
return true;
}

// Helper function to check if an instruction has at least one GEP user
bool XorToOrDisjointTransformer::hasGEPUser(const Value *V) const {
for (const User *U : V->users()) {
if (isa<GetElementPtrInst>(U)) {
return true;
}
}
return false;
}

// Rewrites the XORs of one group (all `xor Base, C_i` with the same Base) so
// that every XOR is expressed relative to the XOR with the smallest constant:
//
//   xor Base, C_i  ==>  or disjoint (xor Base, C_min), (C_i - C_min)
//
// The rewrite is performed only when
//   1. (C_i - C_min) and C_min share no set bits, and
//   2. every bit of (C_i - C_min) is known to be zero in `xor Base, C_min`.
//
// If the XOR with the smallest constant does not dominate all other XORs of
// the group, a clone of it is placed right after the definition of the base
// value (or at the top of the entry block when the base is not an
// instruction) and used as the common base instead.
//
// \param OriginalBaseVal the map key under which the group was collected. It
//        is intentionally not dereferenced: a previously processed group may
//        have erased that value (replacing its uses), so the current base is
//        re-read from the XORs themselves.
// \param XorsInGroup     the (xor, constant) pairs of the group; sorted in
//        place by constant.
// \returns true if at least one XOR was replaced.
bool XorToOrDisjointTransformer::processXorGroup(Value *OriginalBaseVal,
                                                 XorOpList &XorsInGroup) {
  if (XorsInGroup.size() <= 1)
    return false;

  // All XORs of the group still share operand 0, even if the value they were
  // grouped under has been replaced in the meantime.
  (void)OriginalBaseVal;
  Value *BaseVal = XorsInGroup.front().first->getOperand(0);

  // Sort XorsInGroup by the constant offset value in increasing order.
  llvm::sort(XorsInGroup, [](const auto &A, const auto &B) {
    return A.second.ult(B.second);
  });

  // The "base" XOR is the one with the smallest constant. It has to dominate
  // every other XOR of the group, since they are rewritten to use it.
  BinaryOperator *XorWithSmallConst = XorsInGroup.front().first;
  BinaryOperator *ClonedXor = nullptr;

  for (const auto &Entry : llvm::drop_begin(XorsInGroup)) {
    BinaryOperator *CurrentXor = Entry.first;
    if (DT.dominates(XorWithSmallConst, CurrentXor))
      continue;

    // Find a position directly after the definition of the base value. That
    // position dominates every use of the base and hence every XOR here.
    std::optional<BasicBlock::iterator> InsertPt;
    if (auto *BaseInst = dyn_cast<Instruction>(BaseVal)) {
      // Values defined by terminators (invoke/callbr) have no position that
      // is guaranteed to dominate all of their uses; PHIs need the clone to
      // be placed after the PHI section of the block.
      if (!BaseInst->isTerminator())
        InsertPt = BaseInst->getInsertionPointAfterDef();
    } else {
      // Arguments and constants are available from the function entry.
      InsertPt = F.getEntryBlock().getFirstInsertionPt();
    }

    if (!InsertPt) {
      LLVM_DEBUG(dbgs() << DEBUG_TYPE
                        << ": No valid insertion point for a dominating clone"
                        << " of " << *XorWithSmallConst << " in function "
                        << F.getName() << "\n");
      return false;
    }

    LLVM_DEBUG(
        dbgs() << DEBUG_TYPE
               << ": Cloning and inserting XOR with smallest constant ("
               << *XorWithSmallConst << ") as it does not dominate "
               << *CurrentXor << " in function " << F.getName() << "\n");

    ClonedXor = cast<BinaryOperator>(XorWithSmallConst->clone());
    ClonedXor->setName(XorWithSmallConst->getName() + ".dom_clone");
    ClonedXor->insertBefore(*InsertPt);
    LLVM_DEBUG(dbgs() << "  Cloned Inst: " << *ClonedXor << "\n");
    XorWithSmallConst = ClonedXor;
    break;
  }

  // The clone (if any) carries the same constant as the original, so the
  // sorted list already holds the smallest constant.
  const APInt &SmallestConst = XorsInGroup.front().second;

  // The known bits of the base XOR do not depend on the rewrites below, so
  // compute them once for the whole group.
  KnownBits KnownBaseBits(XorWithSmallConst->getType()->getScalarSizeInBits());
  computeKnownBits(XorWithSmallConst, KnownBaseBits, DL, 0, nullptr,
                   XorWithSmallConst, &DT);

  SmallVector<Instruction *, 8> InstructionsToErase;

  // Main transformation loop: Iterate over the original XORs in the sorted
  // group.
  for (const auto &XorEntry : XorsInGroup) {
    BinaryOperator *XorInst = XorEntry.first; // Original XOR instruction
    const APInt &ConstOffsetVal = XorEntry.second;

    // Do not process the one with smallest constant as it is the base. When a
    // clone is used as base, the original smallest XOR is rewritten too.
    if (XorInst == XorWithSmallConst)
      continue;

    // Disjointness Check 1: the residual constant must not overlap the
    // smallest constant, otherwise C_i != C_min ^ (C_i - C_min).
    const APInt NewConstVal = ConstOffsetVal - SmallestConst;
    if (NewConstVal.intersects(SmallestConst)) {
      LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Cannot transform XOR in function "
                        << F.getName() << ":\n"
                        << "  New Const: " << NewConstVal << "\n"
                        << "  Smallest Const: " << SmallestConst << "\n"
                        << "  are not disjoint \n");
      continue;
    }

    // Disjointness Check 2: every bit of the residual constant must be known
    // zero in the base XOR for the `or` to be disjoint.
    if (!NewConstVal.isSubsetOf(KnownBaseBits.Zero)) {
      LLVM_DEBUG(
          dbgs() << DEBUG_TYPE
                 << ": Cannot transform XOR (not proven disjoint) in function "
                 << F.getName() << ":\n"
                 << "  Xor: " << *XorInst << "\n"
                 << "  Base Val: " << *XorWithSmallConst << "\n"
                 << "  New Const: " << NewConstVal << "\n");
      continue;
    }

    LLVM_DEBUG(dbgs() << DEBUG_TYPE
                      << ": Transforming XOR to OR (disjoint) in function "
                      << F.getName() << ":\n"
                      << "  Xor: " << *XorInst << "\n"
                      << "  Base Val: " << *XorWithSmallConst << "\n"
                      << "  New Const: " << NewConstVal << "\n");

    auto *NewOrInst = BinaryOperator::CreateDisjointOr(
        XorWithSmallConst, ConstantInt::get(XorInst->getType(), NewConstVal),
        XorInst->getName() + ".or_disjoint", XorInst->getIterator());

    NewOrInst->copyMetadata(*XorInst);
    XorInst->replaceAllUsesWith(NewOrInst);
    LLVM_DEBUG(dbgs() << "  New Inst: " << *NewOrInst << "\n");
    InstructionsToErase.push_back(XorInst); // Mark original XOR for deletion
  }

  for (Instruction *I : InstructionsToErase)
    I->eraseFromParent();

  // Every replacement uses the base XOR, so an unused clone means nothing was
  // rewritten; remove it again instead of leaving dead code behind.
  if (ClonedXor && ClonedXor->use_empty())
    ClonedXor->eraseFromParent();

  return !InstructionsToErase.empty();
}

// Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
// the base for memory operations. This transformation is true under the
// following conditions
// Check 1 - B and C are disjoint.
// Check 2 - XOR(A,C) and B are disjoint.
//
// This transformation is beneficial particularly for GEPs because:
// 1. OR operations often map better to addressing modes than XOR
// 2. Disjoint OR operations preserve the semantics of the original XOR
// 3. This can enable further optimizations in the GEP offset folding pipeline
bool XorToOrDisjointTransformer::run() {
bool Changed = false;

// Collect all candidate XORs
for (Instruction &I : instructions(F)) {
if (auto *XorOp = dyn_cast<BinaryOperator>(&I)) {
if (XorOp->getOpcode() == Instruction::Xor) {
Value *Op0 = XorOp->getOperand(0);
ConstantInt *C1 = nullptr;
// Match: xor Op0, Constant
if (match(XorOp->getOperand(1), m_ConstantInt(C1))) {
if (hasGEPUser(XorOp)) {
XorGroups[Op0].push_back({XorOp, C1->getValue()});
}
}
}
}
}

if (XorGroups.empty())
return false;

// Process each group of XORs
for (auto &GroupPair : XorGroups) {
Value *OriginalBaseVal = GroupPair.first;
XorOpList &XorsInGroup = GroupPair.second;
if (processXorGroup(OriginalBaseVal, XorsInGroup))
Changed = true;
}

return Changed;
}

bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
Expand All @@ -1181,6 +1377,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {

DL = &F.getDataLayout();
bool Changed = false;

// Decompose xor in to "or disjoint" if possible.
XorToOrDisjointTransformer XorTransformer(F, *DT, *DL);
Changed |= XorTransformer.run();

for (BasicBlock &B : F) {
if (!DT->isReachableFromEntry(&B))
continue;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
; RUN: -S < %s | FileCheck %s


; Test a simple case of the xor to "or disjoint" transformation: all xors share
; %base and appear in increasing order of their constants, so the first xor
; (smallest constant) serves as the base for the other two.
define void @test_basic_transformation(ptr %ptr, i64 %input) {
; CHECK-LABEL: define void @test_basic_transformation(
; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192
; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 32
; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 2048
; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 4096
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]]
; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2
; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2
; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2
; CHECK-NEXT: ret void
;
entry:
; -8192 == ~8191, so the low 13 bits of %base are known to be zero.
%base = and i64 %input, -8192 ; Clear low bits
; 2080 == 32 + 2048 and 4128 == 32 + 4096; neither 2048 nor 4096 overlaps 32
; or any possibly-set bit of %addr1.
%addr1 = xor i64 %base, 32
%addr2 = xor i64 %base, 2080
%addr3 = xor i64 %base, 4128
%gep1 = getelementptr i8, ptr %ptr, i64 %addr1
%gep2 = getelementptr i8, ptr %ptr, i64 %addr2
%gep3 = getelementptr i8, ptr %ptr, i64 %addr3
%val1 = load half, ptr %gep1
%val2 = load half, ptr %gep2
%val3 = load half, ptr %gep3
ret void
}


; Test the transformation when the xors appear in decreasing order of their
; constants: the xor with the smallest constant (%addr3) comes last and does
; not dominate the others, so a clone of it is inserted right after %base and
; all three original xors are rewritten on top of that clone.
define void @test_descending_offset_transformation(ptr %ptr, i64 %input) {
; CHECK-LABEL: define void @test_descending_offset_transformation(
; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192
; CHECK-NEXT: [[ADDR3_DOM_CLONE:%.*]] = xor i64 [[BASE]], 32
; CHECK-NEXT: [[ADDR1_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 4096
; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 2048
; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 0
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1_OR_DISJOINT]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]]
; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2
; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2
; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2
; CHECK-NEXT: ret void
;
entry:
; -8192 == ~8191, so the low 13 bits of %base are known to be zero.
%base = and i64 %input, -8192 ; Clear low bits
; Same constants as in the basic test, but listed largest first.
%addr1 = xor i64 %base, 4128
%addr2 = xor i64 %base, 2080
%addr3 = xor i64 %base, 32
%gep1 = getelementptr i8, ptr %ptr, i64 %addr1
%gep2 = getelementptr i8, ptr %ptr, i64 %addr2
%gep3 = getelementptr i8, ptr %ptr, i64 %addr3
%val1 = load half, ptr %gep1
%val2 = load half, ptr %gep2
%val3 = load half, ptr %gep3
ret void
}


; Test that %addr2 is not transformed to or disjoint: relative to the smallest
; constant (32) its residual constant is 64 - 32 == 32, which overlaps 32.
; %addr3 (2080 == 32 + 2048) is still transformed.
define void @test_no_transfomation(ptr %ptr, i64 %input) {
; CHECK-LABEL: define void @test_no_transfomation(
; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192
; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 32
; CHECK-NEXT: [[ADDR2:%.*]] = xor i64 [[BASE]], 64
; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 2048
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]]
; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2
; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2
; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2
; CHECK-NEXT: ret void
;
entry:
; -8192 == ~8191, so the low 13 bits of %base are known to be zero.
%base = and i64 %input, -8192 ; Clear low bits
%addr1 = xor i64 %base, 32
%addr2 = xor i64 %base, 64 ; Should not be transformed
%addr3 = xor i64 %base, 2080
%gep1 = getelementptr i8, ptr %ptr, i64 %addr1
%gep2 = getelementptr i8, ptr %ptr, i64 %addr2
%gep3 = getelementptr i8, ptr %ptr, i64 %addr3
%val1 = load half, ptr %gep1
%val2 = load half, ptr %gep2
%val3 = load half, ptr %gep3
ret void
}


; Test case with xor instructions in different basic blocks: the xor with the
; smallest constant (%addr1) lives in the entry block and dominates the xors in
; %then, %else and %merge, so no clone is needed and all of them are rewritten
; on top of %addr1.
define void @test_dom_tree(ptr %ptr, i64 %input, i1 %cond) {
; CHECK-LABEL: define void @test_dom_tree(
; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]], i1 [[COND:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192
; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 16
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]]
; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2
; CHECK-NEXT: br i1 [[COND]], label %[[THEN:.*]], label %[[ELSE:.*]]
; CHECK: [[THEN]]:
; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 32
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]]
; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2
; CHECK-NEXT: br label %[[MERGE:.*]]
; CHECK: [[ELSE]]:
; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 96
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]]
; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2
; CHECK-NEXT: br label %[[MERGE]]
; CHECK: [[MERGE]]:
; CHECK-NEXT: [[ADDR4_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 224
; CHECK-NEXT: [[GEP4:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR4_OR_DISJOINT]]
; CHECK-NEXT: [[VAL4:%.*]] = load half, ptr [[GEP4]], align 2
; CHECK-NEXT: ret void
;
entry:
; -8192 == ~8191, so the low 13 bits of %base are known to be zero.
%base = and i64 %input, -8192 ; Clear low bits
%addr1 = xor i64 %base,16
%gep1 = getelementptr i8, ptr %ptr, i64 %addr1
%val1 = load half, ptr %gep1
br i1 %cond, label %then, label %else

then:
; 48 == 16 + 32
%addr2 = xor i64 %base, 48
%gep2 = getelementptr i8, ptr %ptr, i64 %addr2
%val2 = load half, ptr %gep2
br label %merge

else:
; 112 == 16 + 96
%addr3 = xor i64 %base, 112
%gep3 = getelementptr i8, ptr %ptr, i64 %addr3
%val3 = load half, ptr %gep3
br label %merge

merge:
; 240 == 16 + 224
%addr4 = xor i64 %base, 240
%gep4 = getelementptr i8, ptr %ptr, i64 %addr4
%val4 = load half, ptr %gep4
ret void
}