|
| 1 | +//===- RISCVESP32P4ConditionSplit.cpp - Condition Split Pass -*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +/// |
| 9 | +/// \file |
| 10 | +/// This file implements the RISCVESP32P4ConditionSplit pass. |
| 11 | +/// |
| 12 | +/// This pass splits right-shift branches in matrix multiplication functions |
| 13 | +/// to create separate paths for SIMD-optimizable cases (k % 8 == 0) and |
| 14 | +/// scalar fallback cases. This enables subsequent SIMD optimization passes |
| 15 | +/// to target the aligned path specifically. |
| 16 | +/// |
| 17 | +/// Transformation: |
| 18 | +/// if (final_shift <= 0) { /* right shift */ } |
| 19 | +/// => |
| 20 | +/// if (final_shift <= 0) { |
| 21 | +/// if (k % 8 == 0) { /* SIMD path */ } |
| 22 | +/// else { /* scalar path */ } |
| 23 | +/// } |
| 24 | +/// |
| 25 | +//===----------------------------------------------------------------------===// |
| 26 | + |
| 27 | +#include "RISCVESP32P4ConditionSplit.h" |
| 28 | + |
| 29 | +#include "llvm/Analysis/LoopInfo.h" |
| 30 | +#include "llvm/IR/BasicBlock.h" |
| 31 | +#include "llvm/IR/Constants.h" |
| 32 | +#include "llvm/IR/Function.h" |
| 33 | +#include "llvm/IR/IRBuilder.h" |
| 34 | +#include "llvm/IR/Instructions.h" |
| 35 | +#include "llvm/IR/PatternMatch.h" |
| 36 | +#include "llvm/Support/Debug.h" |
| 37 | +#include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 38 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 39 | + |
| 40 | +using namespace llvm; |
| 41 | +using namespace llvm::PatternMatch; |
| 42 | + |
| 43 | +#define DEBUG_TYPE "riscv-esp32p4-condition-split" |
| 44 | + |
| 45 | +// Command line option to enable/disable RISCVESP32P4ConditionSplit |
| 46 | +cl::opt<bool> llvm::EnableRISCVESP32P4ConditionSplit( |
| 47 | + "riscv-esp32p4-condition-split", cl::init(false), |
| 48 | + cl::desc("Enable RISC-V ESP32-P4 condition split for matrix functions")); |
| 49 | + |
| 50 | +namespace { |
| 51 | + |
| 52 | +/// Check if the function has a triple-nested loop structure. |
| 53 | +/// This is used as additional validation for matrix multiplication patterns. |
| 54 | +static bool hasTripleNestedLoopStructure(const Function &F, LoopInfo &LI) { |
| 55 | + LLVM_DEBUG(dbgs() << "Analyzing loop structure for function: " << F.getName() |
| 56 | + << "\n"); |
| 57 | + |
| 58 | + unsigned MaxDepth = 0; |
| 59 | + |
| 60 | + // Helper lambda to recursively find maximum loop depth |
| 61 | + std::function<void(const Loop *)> visitLoop = [&](const Loop *L) { |
| 62 | + MaxDepth = std::max(MaxDepth, L->getLoopDepth()); |
| 63 | + for (const Loop *SubLoop : L->getSubLoops()) |
| 64 | + visitLoop(SubLoop); |
| 65 | + }; |
| 66 | + |
| 67 | + // Visit all top-level loops |
| 68 | + for (const Loop *L : LI) |
| 69 | + visitLoop(L); |
| 70 | + |
| 71 | + LLVM_DEBUG(dbgs() << "Maximum loop depth found: " << MaxDepth << "\n"); |
| 72 | + return MaxDepth >= 3; |
| 73 | +} |
| 74 | + |
| 75 | +/// Find the alloca instruction for the 'k' variable in matrix multiplication. |
| 76 | +/// Returns nullptr if not found. |
| 77 | +static AllocaInst *findKVariableAlloca(Function &F) { |
| 78 | + for (BasicBlock &BB : F) { |
| 79 | + for (Instruction &I : BB) { |
| 80 | + if (auto *AI = dyn_cast<AllocaInst>(&I)) { |
| 81 | + if (AI->getName().contains("k.addr")) { |
| 82 | + LLVM_DEBUG(dbgs() << "Found k variable alloca: " << *AI << "\n"); |
| 83 | + return AI; |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + LLVM_DEBUG(dbgs() << "No k variable alloca found\n"); |
| 90 | + return nullptr; |
| 91 | +} |
| 92 | + |
| 93 | +/// Find the comparison instruction for 'final_shift > 0' pattern. |
| 94 | +/// This identifies the branch we want to split for SIMD optimization. |
| 95 | +static ICmpInst *findFinalShiftComparison(Function &F) { |
| 96 | + for (BasicBlock &BB : F) { |
| 97 | + for (Instruction &I : BB) { |
| 98 | + auto *CmpI = dyn_cast<ICmpInst>(&I); |
| 99 | + if (!CmpI || CmpI->getPredicate() != ICmpInst::ICMP_SGT) |
| 100 | + continue; |
| 101 | + |
| 102 | + // Check if this compares final_shift against zero |
| 103 | + auto *LoadI = dyn_cast<LoadInst>(CmpI->getOperand(0)); |
| 104 | + auto *ZeroConst = dyn_cast<ConstantInt>(CmpI->getOperand(1)); |
| 105 | + |
| 106 | + if (!LoadI || !ZeroConst || !ZeroConst->isZero()) |
| 107 | + continue; |
| 108 | + |
| 109 | + auto *AI = dyn_cast<AllocaInst>(LoadI->getPointerOperand()); |
| 110 | + if (AI && AI->getName().contains("final_shift")) { |
| 111 | + LLVM_DEBUG(dbgs() << "Found final_shift comparison: " << *CmpI << "\n"); |
| 112 | + return CmpI; |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + LLVM_DEBUG(dbgs() << "No final_shift comparison found\n"); |
| 118 | + return nullptr; |
| 119 | +} |
| 120 | + |
| 121 | +/// Find the branch instruction that uses the given comparison. |
| 122 | +static BranchInst *findConditionalBranch(ICmpInst *CmpInst) { |
| 123 | + for (User *U : CmpInst->users()) { |
| 124 | + if (auto *BI = dyn_cast<BranchInst>(U)) { |
| 125 | + if (BI->isConditional()) { |
| 126 | + LLVM_DEBUG(dbgs() << "Found conditional branch: " << *BI << "\n"); |
| 127 | + return BI; |
| 128 | + } |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + LLVM_DEBUG(dbgs() << "No conditional branch found for comparison\n"); |
| 133 | + return nullptr; |
| 134 | +} |
| 135 | + |
| 136 | +/// Create the k % 8 == 0 alignment check in the given basic block. |
| 137 | +static Value *createAlignmentCheck(IRBuilder<> &Builder, AllocaInst *KAddr) { |
| 138 | + // Load k value |
| 139 | + LoadInst *KLoad = Builder.CreateLoad(Builder.getInt32Ty(), KAddr, "k.val"); |
| 140 | + |
| 141 | + // Calculate k % 8 |
| 142 | + Value *EightConst = ConstantInt::get(Builder.getInt32Ty(), 8); |
| 143 | + Value *RemainderVal = Builder.CreateSRem(KLoad, EightConst, "k.rem8"); |
| 144 | + |
| 145 | + // Check k % 8 == 0 |
| 146 | + Value *ZeroConst = ConstantInt::get(Builder.getInt32Ty(), 0); |
| 147 | + Value *IsAligned = |
| 148 | + Builder.CreateICmpEQ(RemainderVal, ZeroConst, "k.is_aligned"); |
| 149 | + |
| 150 | + LLVM_DEBUG(dbgs() << "Created alignment check: k % 8 == 0\n"); |
| 151 | + return IsAligned; |
| 152 | +} |
| 153 | + |
| 154 | +/// Clone instructions from source block to target block, excluding terminators. |
| 155 | +/// Updates the provided value map for use relationship fixing. |
| 156 | +static void cloneInstructionsToBlock(BasicBlock *SourceBB, BasicBlock *TargetBB, |
| 157 | + ValueToValueMapTy &VMap) { |
| 158 | + IRBuilder<> Builder(TargetBB); |
| 159 | + |
| 160 | + for (Instruction &I : *SourceBB) { |
| 161 | + if (I.isTerminator()) |
| 162 | + continue; |
| 163 | + |
| 164 | + Instruction *ClonedInst = I.clone(); |
| 165 | + Builder.Insert(ClonedInst); |
| 166 | + VMap[&I] = ClonedInst; |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +/// Fix operand references in the cloned instructions using the value map. |
| 171 | +static void fixClonedInstructionOperands(BasicBlock *BB, |
| 172 | + const ValueToValueMapTy &VMap) { |
| 173 | + for (Instruction &I : *BB) { |
| 174 | + for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) { |
| 175 | + Value *Op = I.getOperand(Idx); |
| 176 | + auto It = VMap.find(Op); |
| 177 | + if (It != VMap.end()) |
| 178 | + I.setOperand(Idx, It->second); |
| 179 | + } |
| 180 | + } |
| 181 | +} |
| 182 | + |
| 183 | +/// Update PHI nodes in the successor block to use the new predecessor. |
| 184 | +static void updateSuccessorPhiNodes(BasicBlock *SuccessorBB, |
| 185 | + BasicBlock *OldPred, BasicBlock *NewPred) { |
| 186 | + for (Instruction &I : *SuccessorBB) { |
| 187 | + auto *PHI = dyn_cast<PHINode>(&I); |
| 188 | + if (!PHI) |
| 189 | + break; // PHI nodes are always at the beginning |
| 190 | + |
| 191 | + int Idx = PHI->getBasicBlockIndex(OldPred); |
| 192 | + if (Idx >= 0) { |
| 193 | + PHI->setIncomingBlock(Idx, NewPred); |
| 194 | + LLVM_DEBUG(dbgs() << "Updated PHI node predecessor from " |
| 195 | + << OldPred->getName() << " to " << NewPred->getName() |
| 196 | + << "\n"); |
| 197 | + } |
| 198 | + } |
| 199 | +} |
| 200 | + |
| 201 | +/// Perform the condition splitting transformation. |
| 202 | +/// Splits the false branch of final_shift comparison into SIMD and scalar |
| 203 | +/// paths. |
| 204 | +static bool performConditionSplit(Function &F, ICmpInst *FinalShiftCmp, |
| 205 | + AllocaInst *KAddr) { |
| 206 | + // Find the conditional branch that uses the comparison |
| 207 | + BranchInst *Branch = findConditionalBranch(FinalShiftCmp); |
| 208 | + if (!Branch) |
| 209 | + return false; |
| 210 | + |
| 211 | + BasicBlock *OriginalFalseBB = Branch->getSuccessor(1); // Right shift path |
| 212 | + LLVM_DEBUG(dbgs() << "Splitting false branch: " << OriginalFalseBB->getName() |
| 213 | + << "\n"); |
| 214 | + |
| 215 | + // Create new basic blocks for the transformation |
| 216 | + LLVMContext &Ctx = F.getContext(); |
| 217 | + BasicBlock *CondCheckBB = BasicBlock::Create(Ctx, "k.align.check", &F); |
| 218 | + BasicBlock *SIMDPathBB = BasicBlock::Create(Ctx, "simd.path", &F); |
| 219 | + BasicBlock *ScalarPathBB = BasicBlock::Create(Ctx, "scalar.path", &F); |
| 220 | + BasicBlock *MergeBB = BasicBlock::Create(Ctx, "split.merge", &F); |
| 221 | + |
| 222 | + // Redirect original branch to new condition check |
| 223 | + Branch->setSuccessor(1, CondCheckBB); |
| 224 | + |
| 225 | + // Create k % 8 == 0 check |
| 226 | + IRBuilder<> CondBuilder(CondCheckBB); |
| 227 | + Value *IsAligned = createAlignmentCheck(CondBuilder, KAddr); |
| 228 | + CondBuilder.CreateCondBr(IsAligned, SIMDPathBB, ScalarPathBB); |
| 229 | + |
| 230 | + // Clone original block content to both paths |
| 231 | + ValueToValueMapTy SIMDMap, ScalarMap; |
| 232 | + cloneInstructionsToBlock(OriginalFalseBB, SIMDPathBB, SIMDMap); |
| 233 | + cloneInstructionsToBlock(OriginalFalseBB, ScalarPathBB, ScalarMap); |
| 234 | + |
| 235 | + // Fix operand references in cloned instructions |
| 236 | + fixClonedInstructionOperands(SIMDPathBB, SIMDMap); |
| 237 | + fixClonedInstructionOperands(ScalarPathBB, ScalarMap); |
| 238 | + |
| 239 | + // Add branches to merge block |
| 240 | + IRBuilder<> SIMDBuilder(SIMDPathBB); |
| 241 | + IRBuilder<> ScalarBuilder(ScalarPathBB); |
| 242 | + SIMDBuilder.CreateBr(MergeBB); |
| 243 | + ScalarBuilder.CreateBr(MergeBB); |
| 244 | + |
| 245 | + // Handle original successor |
| 246 | + IRBuilder<> MergeBuilder(MergeBB); |
| 247 | + BasicBlock *OriginalSuccessor = nullptr; |
| 248 | + |
| 249 | + if (auto *Term = OriginalFalseBB->getTerminator()) { |
| 250 | + if (auto *Br = dyn_cast<BranchInst>(Term)) { |
| 251 | + if (!Br->isConditional()) |
| 252 | + OriginalSuccessor = Br->getSuccessor(0); |
| 253 | + } |
| 254 | + } |
| 255 | + |
| 256 | + if (OriginalSuccessor) { |
| 257 | + MergeBuilder.CreateBr(OriginalSuccessor); |
| 258 | + updateSuccessorPhiNodes(OriginalSuccessor, OriginalFalseBB, MergeBB); |
| 259 | + } else { |
| 260 | + MergeBuilder.CreateUnreachable(); |
| 261 | + } |
| 262 | + |
| 263 | + // Clean up: remove the original block |
| 264 | + OriginalFalseBB->eraseFromParent(); |
| 265 | + |
| 266 | + LLVM_DEBUG( |
| 267 | + dbgs() << "Successfully performed condition split transformation\n"); |
| 268 | + return true; |
| 269 | +} |
| 270 | + |
| 271 | +} // anonymous namespace |
| 272 | + |
| 273 | +namespace { |
| 274 | + |
| 275 | +/// Check if the function matches the expected matrix multiplication pattern. |
| 276 | +/// This validates that the function has the necessary components for |
| 277 | +/// transformation. |
| 278 | +static bool isMatrixMultiplicationCandidate(Function &F, LoopInfo &LI, |
| 279 | + AllocaInst **KAddr, |
| 280 | + ICmpInst **FinalShiftCmp) { |
| 281 | + // Find k variable alloca |
| 282 | + *KAddr = findKVariableAlloca(F); |
| 283 | + if (!*KAddr) { |
| 284 | + LLVM_DEBUG( |
| 285 | + dbgs() |
| 286 | + << "Cannot find k variable - not a matrix multiplication function\n"); |
| 287 | + return false; |
| 288 | + } |
| 289 | + |
| 290 | + // Find final_shift comparison |
| 291 | + *FinalShiftCmp = findFinalShiftComparison(F); |
| 292 | + if (!*FinalShiftCmp) { |
| 293 | + LLVM_DEBUG( |
| 294 | + dbgs() |
| 295 | + << "Cannot find final_shift comparison - not the target pattern\n"); |
| 296 | + return false; |
| 297 | + } |
| 298 | + |
| 299 | + // Check loop structure (optional validation) |
| 300 | + bool HasTripleNestedLoops = hasTripleNestedLoopStructure(F, LI); |
| 301 | + LLVM_DEBUG(dbgs() << "Triple nested loops check: " |
| 302 | + << (HasTripleNestedLoops ? "PASS" : "FAIL") << "\n"); |
| 303 | + |
| 304 | + if (!HasTripleNestedLoops) { |
| 305 | + LLVM_DEBUG(dbgs() << "Warning: Loop structure validation failed, " |
| 306 | + << "but proceeding based on pattern match\n"); |
| 307 | + } |
| 308 | + |
| 309 | + return true; |
| 310 | +} |
| 311 | + |
| 312 | +} // anonymous namespace |
| 313 | + |
| 314 | +PreservedAnalyses |
| 315 | +RISCVESP32P4ConditionSplitPass::run(Function &F, FunctionAnalysisManager &AM) { |
| 316 | + LLVM_DEBUG(dbgs() << "Running RISCVESP32P4ConditionSplitPass on function: " |
| 317 | + << F.getName() << "\n"); |
| 318 | + |
| 319 | + // Early exit if pass is disabled |
| 320 | + if (!EnableRISCVESP32P4ConditionSplit) { |
| 321 | + LLVM_DEBUG(dbgs() << "Pass is disabled via command line option\n"); |
| 322 | + return PreservedAnalyses::all(); |
| 323 | + } |
| 324 | + |
| 325 | + // Get required analysis |
| 326 | + LoopInfo &LI = AM.getResult<LoopAnalysis>(F); |
| 327 | + |
| 328 | + // Validate function matches target pattern |
| 329 | + AllocaInst *KAddr = nullptr; |
| 330 | + ICmpInst *FinalShiftCmp = nullptr; |
| 331 | + |
| 332 | + if (!isMatrixMultiplicationCandidate(F, LI, &KAddr, &FinalShiftCmp)) { |
| 333 | + LLVM_DEBUG( |
| 334 | + dbgs() << "Function does not match matrix multiplication pattern\n"); |
| 335 | + return PreservedAnalyses::all(); |
| 336 | + } |
| 337 | + |
| 338 | + // Perform the transformation |
| 339 | + bool Changed = performConditionSplit(F, FinalShiftCmp, KAddr); |
| 340 | + |
| 341 | + if (!Changed) { |
| 342 | + LLVM_DEBUG(dbgs() << "Failed to apply condition split transformation\n"); |
| 343 | + return PreservedAnalyses::all(); |
| 344 | + } |
| 345 | + |
| 346 | + LLVM_DEBUG(dbgs() << "Successfully applied condition split transformation to " |
| 347 | + << F.getName() << "\n"); |
| 348 | + |
| 349 | + // Return preserved analyses - we preserve LoopInfo as we don't modify loop |
| 350 | + // structure |
| 351 | + PreservedAnalyses PA; |
| 352 | + PA.preserve<LoopAnalysis>(); |
| 353 | + return PA; |
| 354 | +} |
0 commit comments