Skip to content

Commit f39d9f8

Browse files
author
chenqian
committed
[Pass] add RISCVESP32P4ConditionSplit Pass
1 parent e6715b2 commit f39d9f8

File tree

5 files changed

+593
-206
lines changed

5 files changed

+593
-206
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ add_llvm_target(RISCVCodeGen
4343
RISCVSplitLoopByLength.cpp
4444
RISCVCustomLICM.cpp
4545
RISCVLoopUnrollAndRemainder.cpp
46+
RISCVESP32P4ConditionSplit.cpp
4647
RISCVEsp32P4MemIntrin.cpp
4748
RISCVESP32P4LoopVectorizeExtractor.cpp
4849
RISCVIndirectBranchTracking.cpp
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
//===- RISCVESP32P4ConditionSplit.cpp - Condition Split Pass -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file implements the RISCVESP32P4ConditionSplit pass.
11+
///
12+
/// This pass splits right-shift branches in matrix multiplication functions
13+
/// to create separate paths for SIMD-optimizable cases (k % 8 == 0) and
14+
/// scalar fallback cases. This enables subsequent SIMD optimization passes
15+
/// to target the aligned path specifically.
16+
///
17+
/// Transformation:
18+
/// if (final_shift <= 0) { /* right shift */ }
19+
/// =>
20+
/// if (final_shift <= 0) {
21+
/// if (k % 8 == 0) { /* SIMD path */ }
22+
/// else { /* scalar path */ }
23+
/// }
24+
///
25+
//===----------------------------------------------------------------------===//
26+
27+
#include "RISCVESP32P4ConditionSplit.h"
28+
29+
#include "llvm/Analysis/LoopInfo.h"
30+
#include "llvm/IR/BasicBlock.h"
31+
#include "llvm/IR/Constants.h"
32+
#include "llvm/IR/Function.h"
33+
#include "llvm/IR/IRBuilder.h"
34+
#include "llvm/IR/Instructions.h"
35+
#include "llvm/IR/PatternMatch.h"
36+
#include "llvm/Support/Debug.h"
37+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38+
#include "llvm/Transforms/Utils/Cloning.h"
39+
40+
using namespace llvm;
41+
using namespace llvm::PatternMatch;
42+
43+
#define DEBUG_TYPE "riscv-esp32p4-condition-split"
44+
45+
// Command line option to enable/disable RISCVESP32P4ConditionSplit
46+
cl::opt<bool> llvm::EnableRISCVESP32P4ConditionSplit(
47+
"riscv-esp32p4-condition-split", cl::init(false),
48+
cl::desc("Enable RISC-V ESP32-P4 condition split for matrix functions"));
49+
50+
namespace {
51+
52+
/// Check if the function has a triple-nested loop structure.
53+
/// This is used as additional validation for matrix multiplication patterns.
54+
static bool hasTripleNestedLoopStructure(const Function &F, LoopInfo &LI) {
55+
LLVM_DEBUG(dbgs() << "Analyzing loop structure for function: " << F.getName()
56+
<< "\n");
57+
58+
unsigned MaxDepth = 0;
59+
60+
// Helper lambda to recursively find maximum loop depth
61+
std::function<void(const Loop *)> visitLoop = [&](const Loop *L) {
62+
MaxDepth = std::max(MaxDepth, L->getLoopDepth());
63+
for (const Loop *SubLoop : L->getSubLoops())
64+
visitLoop(SubLoop);
65+
};
66+
67+
// Visit all top-level loops
68+
for (const Loop *L : LI)
69+
visitLoop(L);
70+
71+
LLVM_DEBUG(dbgs() << "Maximum loop depth found: " << MaxDepth << "\n");
72+
return MaxDepth >= 3;
73+
}
74+
75+
/// Find the alloca instruction for the 'k' variable in matrix multiplication.
76+
/// Returns nullptr if not found.
77+
static AllocaInst *findKVariableAlloca(Function &F) {
78+
for (BasicBlock &BB : F) {
79+
for (Instruction &I : BB) {
80+
if (auto *AI = dyn_cast<AllocaInst>(&I)) {
81+
if (AI->getName().contains("k.addr")) {
82+
LLVM_DEBUG(dbgs() << "Found k variable alloca: " << *AI << "\n");
83+
return AI;
84+
}
85+
}
86+
}
87+
}
88+
89+
LLVM_DEBUG(dbgs() << "No k variable alloca found\n");
90+
return nullptr;
91+
}
92+
93+
/// Find the comparison instruction for 'final_shift > 0' pattern.
94+
/// This identifies the branch we want to split for SIMD optimization.
95+
static ICmpInst *findFinalShiftComparison(Function &F) {
96+
for (BasicBlock &BB : F) {
97+
for (Instruction &I : BB) {
98+
auto *CmpI = dyn_cast<ICmpInst>(&I);
99+
if (!CmpI || CmpI->getPredicate() != ICmpInst::ICMP_SGT)
100+
continue;
101+
102+
// Check if this compares final_shift against zero
103+
auto *LoadI = dyn_cast<LoadInst>(CmpI->getOperand(0));
104+
auto *ZeroConst = dyn_cast<ConstantInt>(CmpI->getOperand(1));
105+
106+
if (!LoadI || !ZeroConst || !ZeroConst->isZero())
107+
continue;
108+
109+
auto *AI = dyn_cast<AllocaInst>(LoadI->getPointerOperand());
110+
if (AI && AI->getName().contains("final_shift")) {
111+
LLVM_DEBUG(dbgs() << "Found final_shift comparison: " << *CmpI << "\n");
112+
return CmpI;
113+
}
114+
}
115+
}
116+
117+
LLVM_DEBUG(dbgs() << "No final_shift comparison found\n");
118+
return nullptr;
119+
}
120+
121+
/// Find the branch instruction that uses the given comparison.
122+
static BranchInst *findConditionalBranch(ICmpInst *CmpInst) {
123+
for (User *U : CmpInst->users()) {
124+
if (auto *BI = dyn_cast<BranchInst>(U)) {
125+
if (BI->isConditional()) {
126+
LLVM_DEBUG(dbgs() << "Found conditional branch: " << *BI << "\n");
127+
return BI;
128+
}
129+
}
130+
}
131+
132+
LLVM_DEBUG(dbgs() << "No conditional branch found for comparison\n");
133+
return nullptr;
134+
}
135+
136+
/// Create the k % 8 == 0 alignment check in the given basic block.
137+
static Value *createAlignmentCheck(IRBuilder<> &Builder, AllocaInst *KAddr) {
138+
// Load k value
139+
LoadInst *KLoad = Builder.CreateLoad(Builder.getInt32Ty(), KAddr, "k.val");
140+
141+
// Calculate k % 8
142+
Value *EightConst = ConstantInt::get(Builder.getInt32Ty(), 8);
143+
Value *RemainderVal = Builder.CreateSRem(KLoad, EightConst, "k.rem8");
144+
145+
// Check k % 8 == 0
146+
Value *ZeroConst = ConstantInt::get(Builder.getInt32Ty(), 0);
147+
Value *IsAligned =
148+
Builder.CreateICmpEQ(RemainderVal, ZeroConst, "k.is_aligned");
149+
150+
LLVM_DEBUG(dbgs() << "Created alignment check: k % 8 == 0\n");
151+
return IsAligned;
152+
}
153+
154+
/// Clone instructions from source block to target block, excluding terminators.
155+
/// Updates the provided value map for use relationship fixing.
156+
static void cloneInstructionsToBlock(BasicBlock *SourceBB, BasicBlock *TargetBB,
157+
ValueToValueMapTy &VMap) {
158+
IRBuilder<> Builder(TargetBB);
159+
160+
for (Instruction &I : *SourceBB) {
161+
if (I.isTerminator())
162+
continue;
163+
164+
Instruction *ClonedInst = I.clone();
165+
Builder.Insert(ClonedInst);
166+
VMap[&I] = ClonedInst;
167+
}
168+
}
169+
170+
/// Fix operand references in the cloned instructions using the value map.
171+
static void fixClonedInstructionOperands(BasicBlock *BB,
172+
const ValueToValueMapTy &VMap) {
173+
for (Instruction &I : *BB) {
174+
for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) {
175+
Value *Op = I.getOperand(Idx);
176+
auto It = VMap.find(Op);
177+
if (It != VMap.end())
178+
I.setOperand(Idx, It->second);
179+
}
180+
}
181+
}
182+
183+
/// Update PHI nodes in the successor block to use the new predecessor.
184+
static void updateSuccessorPhiNodes(BasicBlock *SuccessorBB,
185+
BasicBlock *OldPred, BasicBlock *NewPred) {
186+
for (Instruction &I : *SuccessorBB) {
187+
auto *PHI = dyn_cast<PHINode>(&I);
188+
if (!PHI)
189+
break; // PHI nodes are always at the beginning
190+
191+
int Idx = PHI->getBasicBlockIndex(OldPred);
192+
if (Idx >= 0) {
193+
PHI->setIncomingBlock(Idx, NewPred);
194+
LLVM_DEBUG(dbgs() << "Updated PHI node predecessor from "
195+
<< OldPred->getName() << " to " << NewPred->getName()
196+
<< "\n");
197+
}
198+
}
199+
}
200+
201+
/// Perform the condition splitting transformation.
202+
/// Splits the false branch of final_shift comparison into SIMD and scalar
203+
/// paths.
204+
static bool performConditionSplit(Function &F, ICmpInst *FinalShiftCmp,
205+
AllocaInst *KAddr) {
206+
// Find the conditional branch that uses the comparison
207+
BranchInst *Branch = findConditionalBranch(FinalShiftCmp);
208+
if (!Branch)
209+
return false;
210+
211+
BasicBlock *OriginalFalseBB = Branch->getSuccessor(1); // Right shift path
212+
LLVM_DEBUG(dbgs() << "Splitting false branch: " << OriginalFalseBB->getName()
213+
<< "\n");
214+
215+
// Create new basic blocks for the transformation
216+
LLVMContext &Ctx = F.getContext();
217+
BasicBlock *CondCheckBB = BasicBlock::Create(Ctx, "k.align.check", &F);
218+
BasicBlock *SIMDPathBB = BasicBlock::Create(Ctx, "simd.path", &F);
219+
BasicBlock *ScalarPathBB = BasicBlock::Create(Ctx, "scalar.path", &F);
220+
BasicBlock *MergeBB = BasicBlock::Create(Ctx, "split.merge", &F);
221+
222+
// Redirect original branch to new condition check
223+
Branch->setSuccessor(1, CondCheckBB);
224+
225+
// Create k % 8 == 0 check
226+
IRBuilder<> CondBuilder(CondCheckBB);
227+
Value *IsAligned = createAlignmentCheck(CondBuilder, KAddr);
228+
CondBuilder.CreateCondBr(IsAligned, SIMDPathBB, ScalarPathBB);
229+
230+
// Clone original block content to both paths
231+
ValueToValueMapTy SIMDMap, ScalarMap;
232+
cloneInstructionsToBlock(OriginalFalseBB, SIMDPathBB, SIMDMap);
233+
cloneInstructionsToBlock(OriginalFalseBB, ScalarPathBB, ScalarMap);
234+
235+
// Fix operand references in cloned instructions
236+
fixClonedInstructionOperands(SIMDPathBB, SIMDMap);
237+
fixClonedInstructionOperands(ScalarPathBB, ScalarMap);
238+
239+
// Add branches to merge block
240+
IRBuilder<> SIMDBuilder(SIMDPathBB);
241+
IRBuilder<> ScalarBuilder(ScalarPathBB);
242+
SIMDBuilder.CreateBr(MergeBB);
243+
ScalarBuilder.CreateBr(MergeBB);
244+
245+
// Handle original successor
246+
IRBuilder<> MergeBuilder(MergeBB);
247+
BasicBlock *OriginalSuccessor = nullptr;
248+
249+
if (auto *Term = OriginalFalseBB->getTerminator()) {
250+
if (auto *Br = dyn_cast<BranchInst>(Term)) {
251+
if (!Br->isConditional())
252+
OriginalSuccessor = Br->getSuccessor(0);
253+
}
254+
}
255+
256+
if (OriginalSuccessor) {
257+
MergeBuilder.CreateBr(OriginalSuccessor);
258+
updateSuccessorPhiNodes(OriginalSuccessor, OriginalFalseBB, MergeBB);
259+
} else {
260+
MergeBuilder.CreateUnreachable();
261+
}
262+
263+
// Clean up: remove the original block
264+
OriginalFalseBB->eraseFromParent();
265+
266+
LLVM_DEBUG(
267+
dbgs() << "Successfully performed condition split transformation\n");
268+
return true;
269+
}
270+
271+
} // anonymous namespace
272+
273+
namespace {
274+
275+
/// Check if the function matches the expected matrix multiplication pattern.
276+
/// This validates that the function has the necessary components for
277+
/// transformation.
278+
static bool isMatrixMultiplicationCandidate(Function &F, LoopInfo &LI,
279+
AllocaInst **KAddr,
280+
ICmpInst **FinalShiftCmp) {
281+
// Find k variable alloca
282+
*KAddr = findKVariableAlloca(F);
283+
if (!*KAddr) {
284+
LLVM_DEBUG(
285+
dbgs()
286+
<< "Cannot find k variable - not a matrix multiplication function\n");
287+
return false;
288+
}
289+
290+
// Find final_shift comparison
291+
*FinalShiftCmp = findFinalShiftComparison(F);
292+
if (!*FinalShiftCmp) {
293+
LLVM_DEBUG(
294+
dbgs()
295+
<< "Cannot find final_shift comparison - not the target pattern\n");
296+
return false;
297+
}
298+
299+
// Check loop structure (optional validation)
300+
bool HasTripleNestedLoops = hasTripleNestedLoopStructure(F, LI);
301+
LLVM_DEBUG(dbgs() << "Triple nested loops check: "
302+
<< (HasTripleNestedLoops ? "PASS" : "FAIL") << "\n");
303+
304+
if (!HasTripleNestedLoops) {
305+
LLVM_DEBUG(dbgs() << "Warning: Loop structure validation failed, "
306+
<< "but proceeding based on pattern match\n");
307+
}
308+
309+
return true;
310+
}
311+
312+
} // anonymous namespace
313+
314+
/// New pass manager entry point.
///
/// Does nothing unless EnableRISCVESP32P4ConditionSplit is set. Otherwise the
/// function is checked against the matrix multiplication pattern and, on a
/// match, its `final_shift` false branch is split on `k % 8 == 0`.
///
/// \returns PreservedAnalyses::all() when the IR was not modified and
///          PreservedAnalyses::none() when the split was applied.
PreservedAnalyses
RISCVESP32P4ConditionSplitPass::run(Function &F, FunctionAnalysisManager &AM) {
  LLVM_DEBUG(dbgs() << "Running RISCVESP32P4ConditionSplitPass on function: "
                    << F.getName() << "\n");

  // Early exit if pass is disabled
  if (!EnableRISCVESP32P4ConditionSplit) {
    LLVM_DEBUG(dbgs() << "Pass is disabled via command line option\n");
    return PreservedAnalyses::all();
  }

  // Get required analysis
  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);

  // Validate function matches target pattern
  AllocaInst *KAddr = nullptr;
  ICmpInst *FinalShiftCmp = nullptr;

  if (!isMatrixMultiplicationCandidate(F, LI, &KAddr, &FinalShiftCmp)) {
    LLVM_DEBUG(
        dbgs() << "Function does not match matrix multiplication pattern\n");
    return PreservedAnalyses::all();
  }

  // Perform the transformation
  bool Changed = performConditionSplit(F, FinalShiftCmp, KAddr);

  if (!Changed) {
    LLVM_DEBUG(dbgs() << "Failed to apply condition split transformation\n");
    return PreservedAnalyses::all();
  }

  LLVM_DEBUG(dbgs() << "Successfully applied condition split transformation to "
                    << F.getName() << "\n");

  // The split creates new basic blocks and erases one without updating any
  // analysis. LoopInfo in particular is not told about the new blocks and may
  // still refer to the erased one, so it must not be reported as preserved.
  return PreservedAnalyses::none();
}

0 commit comments

Comments
 (0)