Skip to content

Commit 539f561

Browse files
admitricigcbot
authored andcommitted
Improve handling the remated instructions in CodeScheduling
CloneAddressArithmetic marks rematted instructions with metadata Use the metadata in RematChainsAnalysis pass to mark the patterns that are safe to consider in the scheduling. Use the estimation of the target instructions (because it's usually a load) in the RegisterPressureTracker of the scheduling and schedule the remat chain as a whole.
1 parent 64d0420 commit 539f561

File tree

10 files changed

+516
-36
lines changed

10 files changed

+516
-36
lines changed

IGC/Compiler/CISACodeGen/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ set(IGC_BUILD__SRC__CISACodeGen_Common
8484
"${CMAKE_CURRENT_SOURCE_DIR}/RayTracingStatefulPass.cpp"
8585
"${CMAKE_CURRENT_SOURCE_DIR}/RegisterEstimator.cpp"
8686
"${CMAKE_CURRENT_SOURCE_DIR}/RegisterPressureEstimate.cpp"
87+
"${CMAKE_CURRENT_SOURCE_DIR}/RematChainsAnalysis.cpp"
8788
"${CMAKE_CURRENT_SOURCE_DIR}/RemoveLoopDependency.cpp"
8889
"${CMAKE_CURRENT_SOURCE_DIR}/ResolvePredefinedConstant.cpp"
8990
"${CMAKE_CURRENT_SOURCE_DIR}/ResourceLoopAnalysis.cpp"
@@ -192,6 +193,7 @@ set(IGC_BUILD__HDR__CISACodeGen_Common
192193
"${CMAKE_CURRENT_SOURCE_DIR}/PushAnalysis.hpp"
193194
"${CMAKE_CURRENT_SOURCE_DIR}/RayTracingShaderLowering.hpp"
194195
"${CMAKE_CURRENT_SOURCE_DIR}/RayTracingStatefulPass.h"
196+
"${CMAKE_CURRENT_SOURCE_DIR}/RematChainsAnalysis.hpp"
195197
"${CMAKE_CURRENT_SOURCE_DIR}/RegisterEstimator.hpp"
196198
"${CMAKE_CURRENT_SOURCE_DIR}/RegisterPressureEstimate.hpp"
197199
"${CMAKE_CURRENT_SOURCE_DIR}/RemoveLoopDependency.hpp"

IGC/Compiler/CISACodeGen/CodeScheduling.cpp

Lines changed: 150 additions & 34 deletions
Large diffs are not rendered by default.

IGC/Compiler/CISACodeGen/CodeScheduling.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ SPDX-License-Identifier: MIT
1111
#include "Compiler/CISACodeGen/WIAnalysis.hpp"
1212
#include "Compiler/CISACodeGen/IGCLivenessAnalysis.h"
1313
#include "Compiler/CISACodeGen/VectorShuffleAnalysis.hpp"
14+
#include "Compiler/CISACodeGen/RematChainsAnalysis.hpp"
1415
#include "Compiler/CISACodeGen/TranslationTable.hpp"
1516
#include "Compiler/CodeGenContextWrapper.hpp"
1617
#include "Compiler/MetaDataUtilsWrapper.h"
@@ -28,6 +29,7 @@ class CodeScheduling : public llvm::FunctionPass {
2829
// llvm::LoopInfo* LI = nullptr;
2930
llvm::AliasAnalysis *AA = nullptr;
3031
VectorShuffleAnalysis *VSA = nullptr;
32+
RematChainsAnalysis *RCA = nullptr;
3133
WIAnalysisRunner *WI = nullptr;
3234
// IGCMD::MetaDataUtils* MDUtils = nullptr;
3335
IGCLivenessAnalysis *RPE = nullptr;
@@ -51,13 +53,15 @@ class CodeScheduling : public llvm::FunctionPass {
5153
AU.addRequired<IGCFunctionExternalRegPressureAnalysis>();
5254
AU.addRequired<CodeGenContextWrapper>();
5355
AU.addRequired<VectorShuffleAnalysis>();
56+
AU.addRequired<RematChainsAnalysis>();
5457

5558
// AU.addPreserved<llvm::DominatorTreeWrapperPass>();
5659
// AU.addPreserved<llvm::LoopInfoWrapperPass>();
5760
// AU.addPreserved<llvm::AAResultsWrapperPass>();
5861
AU.addPreserved<IGCLivenessAnalysis>();
5962
AU.addPreserved<IGCFunctionExternalRegPressureAnalysis>();
6063
AU.addPreserved<VectorShuffleAnalysis>();
64+
AU.addPreserved<RematChainsAnalysis>();
6165
}
6266

6367
private:

IGC/Compiler/CISACodeGen/RematAddressArithmetic.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ void CloneAddressArithmetic::rematWholeChain(llvm::Instruction *I, RematChain &C
371371
Clone->setOperand(i, OldToNew[OldOp]);
372372
}
373373

374+
MDNode *Node = MDNode::get(I->getContext(), MDString::get(I->getContext(), "remat"));
375+
Clone->setMetadata("remat", Node);
374376
Clone->setName("remat");
375377
Clone->insertBefore(I);
376378
}
@@ -441,6 +443,8 @@ bool CloneAddressArithmetic::rematerialize(RematSet &ToProcess, unsigned int Flo
441443
PRINT_LOG(" --> ");
442444

443445
auto Clone = El->clone();
446+
MDNode *Node = MDNode::get(El->getContext(), MDString::get(El->getContext(), "remat"));
447+
Clone->setMetadata("remat", Node);
444448
Clone->setName("cloned_" + El->getName());
445449
Clone->insertBefore(UserInst);
446450
*Use = Clone;
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*========================== begin_copyright_notice ============================
2+
3+
Copyright (C) 2025 Intel Corporation
4+
5+
SPDX-License-Identifier: MIT
6+
7+
============================= end_copyright_notice ===========================*/
8+
9+
#include "Compiler/CISACodeGen/RematChainsAnalysis.hpp"
10+
#include "Compiler/IGCPassSupport.h"
11+
#include "Compiler/CodeGenContextWrapper.hpp"
12+
#include "Compiler/CodeGenPublic.h"
13+
#include "common/debug/Debug.hpp"
14+
15+
#include "GenISAIntrinsics/GenIntrinsicInst.h"
16+
#include "llvmWrapper/IR/Function.h"
17+
#include "llvmWrapper/IR/Value.h"
18+
#include "llvmWrapper/IR/DerivedTypes.h"
19+
20+
using namespace llvm;
21+
using namespace IGC::Debug;
22+
23+
char RematChainsAnalysis::ID = 0;
24+
25+
// Register pass to igc-opt
26+
#define PASS_FLAG "igc-remat-chain-analysis"
27+
#define PASS_DESCRIPTION "Recognizes rematerialization chain patterns"
28+
#define PASS_CFG_ONLY false
29+
#define PASS_ANALYSIS true
30+
IGC_INITIALIZE_PASS_BEGIN(RematChainsAnalysis, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
31+
IGC_INITIALIZE_PASS_END(RematChainsAnalysis, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
32+
33+
RematChainsAnalysis::RematChainsAnalysis() : FunctionPass(ID) {
34+
initializeRematChainsAnalysisPass(*PassRegistry::getPassRegistry());
35+
};
36+
37+
static bool hasRematMetadata(llvm::Value *V) {
38+
if (auto *I = dyn_cast<Instruction>(V)) {
39+
if (auto *MD = I->getMetadata("remat")) {
40+
return true;
41+
}
42+
}
43+
return false;
44+
}
45+
46+
Value *getAddressOperand(llvm::Instruction *I) {
47+
if (!I) return nullptr;
48+
49+
// Check if the instruction is a Load or Store and return the address operand
50+
if (auto *LI = dyn_cast<LoadInst>(I)) {
51+
return LI->getPointerOperand();
52+
} else if (auto *SI = dyn_cast<StoreInst>(I)) {
53+
return SI->getPointerOperand();
54+
}
55+
56+
// If it's not a Load or Store, return nullptr
57+
return nullptr;
58+
}
59+
60+
// There are many passes that could change the remat chain,
61+
// So we consider not only the instructions that has remat metadata,
62+
// But also the instructions that have only one undroppable user,
63+
// which will be the load/store or the instruction that is rematerialized
64+
RematChainSet getRematChain(Value *V, Instruction *User) {
65+
if (!V) return {};
66+
Instruction *I = dyn_cast<Instruction>(V);
67+
if (!I) return {};
68+
if (!User) return {};
69+
70+
if (!isa<IntToPtrInst>(I) && !isa<AddrSpaceCastInst>(I)
71+
&& !isa<BitCastInst>(I) && !isa<GetElementPtrInst>(I)
72+
&& !isa<BinaryOperator>(I) && !isa<UnaryOperator>(I)) {
73+
return {};
74+
}
75+
76+
if (I->getParent() != User->getParent()) {
77+
return {};
78+
}
79+
80+
RematChainSet Chain;
81+
if ((IGCLLVM::getUniqueUndroppableUser(I) == User) || (hasRematMetadata(I))) {
82+
Chain.insert(I);
83+
84+
for (auto &Op : I->operands()) {
85+
Value *OpV = Op.get();
86+
if (auto *OpI = dyn_cast<Instruction>(OpV)) {
87+
auto SubChain = getRematChain(OpI, I);
88+
Chain.insert(SubChain.begin(), SubChain.end());
89+
}
90+
}
91+
}
92+
93+
return Chain;
94+
}
95+
96+
bool RematChainsAnalysis::runOnFunction(llvm::Function &F) {
97+
for (auto &BB : F) {
98+
for (Instruction &I : BB) {
99+
Value *AddrOperand = getAddressOperand(&I);
100+
if (!AddrOperand)
101+
continue;
102+
103+
Instruction *AI = dyn_cast<Instruction>(AddrOperand);
104+
if (!AI)
105+
continue;
106+
107+
RematChainSet Chain = getRematChain(AddrOperand, &I);
108+
109+
if (!Chain.empty()) {
110+
RematChainPatterns.push_back(std::make_unique<RematChainPattern>(Chain, AI, &I));
111+
for (auto *Inst : Chain) {
112+
ValueToRematChainMap[Inst] = RematChainPatterns.back().get();
113+
}
114+
}
115+
}
116+
for (Instruction &I : BB) {
117+
if (RematChainPattern *Pattern = getRematChainPattern(&I)) {
118+
if (Pattern->getFirstInst() == nullptr) {
119+
Pattern->setFirstInst(&I);
120+
}
121+
}
122+
}
123+
}
124+
125+
#ifdef _DEBUG
126+
for (const auto &Pattern : RematChainPatterns) {
127+
IGC_ASSERT(Pattern->getFirstInst() != nullptr && "Remat chain pattern must have a first instruction");
128+
IGC_ASSERT(Pattern->getLastInst() != nullptr && "Remat chain pattern must have a last instruction");
129+
IGC_ASSERT(Pattern->getRematTargetInst() != nullptr && "Remat chain pattern must have a remat target instruction");
130+
}
131+
#endif
132+
133+
return true;
134+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*========================== begin_copyright_notice ============================
2+
3+
Copyright (C) 2025 Intel Corporation
4+
5+
SPDX-License-Identifier: MIT
6+
7+
============================= end_copyright_notice ===========================*/
8+
#pragma once
9+
10+
#include "Probe/Assertion.h"
11+
12+
#include "common/LLVMWarningsPush.hpp"
13+
#include <llvm/ADT/SmallPtrSet.h>
14+
#include <llvm/ADT/SmallVector.h>
15+
#include "common/LLVMWarningsPop.hpp"
16+
17+
#include "Compiler/CodeGenPublic.h"
18+
#include "Compiler/IGCPassSupport.h"
19+
20+
using namespace IGC;
21+
22+
namespace IGC {
23+
24+
typedef llvm::SmallSet<llvm::Instruction *, 8> RematChainSet;
25+
26+
class RematChainPattern {
27+
public:
28+
RematChainPattern(RematChainSet RematChain, llvm::Instruction *LastInst, llvm::Instruction *RematChainUser)
29+
: RematChain(RematChain), LastInstruction(LastInst), RematChainUser(RematChainUser) {
30+
IGC_ASSERT(!RematChain.empty() && "Remat chain cannot be empty");
31+
IGC_ASSERT(RematChainUser && "Remat chain user cannot be null");
32+
}
33+
34+
llvm::Instruction *getFirstInst() const {
35+
return FirstInstruction;
36+
}
37+
38+
llvm::Instruction *getLastInst() const {
39+
return LastInstruction;
40+
}
41+
42+
llvm::Instruction *getRematTargetInst() const {
43+
return RematChainUser;
44+
}
45+
46+
RematChainSet getRematChain() const {
47+
return RematChain;
48+
}
49+
50+
bool isRematInst(llvm::Value *V) const {
51+
if (auto *Inst = llvm::dyn_cast<llvm::Instruction>(V)) {
52+
return RematChain.contains(Inst);
53+
}
54+
return false;
55+
}
56+
57+
void setFirstInst(llvm::Instruction *Inst) {
58+
IGC_ASSERT(RematChain.count(Inst) && "First instruction must be part of the remat chain");
59+
FirstInstruction = Inst;
60+
}
61+
62+
private:
63+
RematChainSet RematChain;
64+
llvm::Instruction *LastInstruction;
65+
llvm::Instruction *FirstInstruction;
66+
llvm::Instruction *RematChainUser;
67+
};
68+
69+
class RematChainsAnalysis : public llvm::FunctionPass {
70+
public:
71+
static char ID;
72+
virtual llvm::StringRef getPassName() const override { return "RematChainsAnalysis"; };
73+
RematChainsAnalysis();
74+
virtual ~RematChainsAnalysis() {}
75+
RematChainsAnalysis(const RematChainsAnalysis &) = delete;
76+
virtual bool runOnFunction(llvm::Function &F) override;
77+
virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { AU.setPreservesAll(); }
78+
79+
RematChainPattern *getRematChainPattern(llvm::Value *V) {
80+
auto It = ValueToRematChainMap.find(V);
81+
if (It != ValueToRematChainMap.end())
82+
return It->second;
83+
return nullptr;
84+
}
85+
86+
private:
87+
std::vector<std::unique_ptr<RematChainPattern>> RematChainPatterns;
88+
llvm::DenseMap<llvm::Value *, RematChainPattern *> ValueToRematChainMap;
89+
};
90+
91+
}; // namespace IGC

IGC/Compiler/CISACodeGen/VectorShuffleAnalysis.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,19 @@ VectorShuffleAnalysis::tryCreatingDestVectorForVectorization(llvm::InsertElement
124124
return nullptr;
125125

126126
llvm::Value *Scalar = CurrentIE->getOperand(1);
127-
if (isa<ExtractElementInst>(Scalar))
128-
return nullptr;
129127

130128
if (!Scalar->getType()->isSingleValueType())
131129
return nullptr;
132130

131+
if (Instruction *EE = dyn_cast<ExtractElementInst>(Scalar)) {
132+
// allow only vector of 1 element
133+
Type *EEVectorType = EE->getOperand(0)->getType();
134+
auto EEVectorTypeVec = dyn_cast<IGCLLVM::FixedVectorType>(EEVectorType);
135+
IGC_ASSERT(EEVectorTypeVec);
136+
if (EEVectorTypeVec->getNumElements() != 1)
137+
return nullptr;
138+
}
139+
133140
ShuffleMask[IdxVal] = IdxVal;
134141
IEs.push_back(CurrentIE);
135142
Scalars.push_back(Scalar);

IGC/Compiler/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ void initializeUnreachableHandlingPass(llvm::PassRegistry &);
263263
void initializeBreakdownIntrinsicPassPass(llvm::PassRegistry &);
264264
void initializeCatchAllLineNumberPass(llvm::PassRegistry &);
265265
void initializeDebugInfoPassPass(llvm::PassRegistry &);
266+
void initializeRematChainsAnalysisPass(llvm::PassRegistry &);
266267
void initializeVectorShuffleAnalysisPass(llvm::PassRegistry &);
267268
void initializeIGCLivenessAnalysisPass(llvm::PassRegistry &);
268269
void initializeIGCRegisterPressurePrinterPass(llvm::PassRegistry &);

0 commit comments

Comments
 (0)