Skip to content

Commit 1b91543

Browse files
committed
add register logging
1 parent 2f41fa3 commit 1b91543

File tree

6 files changed

+83
-3
lines changed

6 files changed

+83
-3
lines changed

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ class VOPProfileMAI<VOPProfile P, RegisterOperand _SrcRC, RegisterOperand _DstRC
626626
// and with the earlyclobber flag on the dst. This is stricter than the
627627
// actual HW restriction. In particular earlyclobber also affects src0 and
628628
// src1 allocation which is not required.
629-
bit NoDstOverlap = !gt(DstVT.Size, 128);
629+
bit NoDstOverlap = 1; //!gt(DstVT.Size, 128);
630630
}
631631

632632
class VOPProfileSMFMAC<VOPProfile P, RegisterOperand _DstRC,

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ set(NVPTXCodeGen_sources
2727
NVPTXLowerArgs.cpp
2828
NVPTXLowerAlloca.cpp
2929
NVPTXLowerUnreachable.cpp
30+
NVPTXRegCount.cpp
3031
NVPTXPeephole.cpp
3132
NVPTXMCExpr.cpp
3233
NVPTXPrologEpilogPass.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ enum CondCodes {
3636
GE
3737
};
3838
}
39-
39+
/// Creates the NVPTX pass that logs a per-function register operand count.
FunctionPass *createNVPTXRegCountPass();
4040
FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
4141
llvm::CodeGenOptLevel OptLevel);
4242
ModulePass *createNVPTXAssignValidGlobalNamesPass();

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
18521852
case 1: {
18531853
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
18541854
SDValue Imm = Ops[0];
1855-
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
1855+
if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
18561856
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
18571857
// Convert immediate to target constant
18581858
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===-- NVPTXRegCount.cpp - Log register usage per machine function -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a pass that prints per-function register operand counts.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "NVPTXISelDAGToDAG.h"
14+
#include "NVPTX.h"
15+
#include "NVPTXUtilities.h"
16+
#include "llvm/ADT/APInt.h"
17+
#include "llvm/Analysis/ValueTracking.h"
18+
#include "llvm/CodeGen/ISDOpcodes.h"
19+
#include "llvm/CodeGen/SelectionDAG.h"
20+
#include "llvm/CodeGen/SelectionDAGNodes.h"
21+
#include "llvm/IR/GlobalValue.h"
22+
#include "llvm/IR/Instructions.h"
23+
#include "llvm/IR/IntrinsicsNVPTX.h"
24+
#include "llvm/IR/NVVMIntrinsicUtils.h"
25+
#include "llvm/Support/AtomicOrdering.h"
26+
#include "llvm/Support/CommandLine.h"
27+
#include "llvm/Support/ErrorHandling.h"
28+
#include "llvm/Support/FormatVariadic.h"
29+
#include <optional>
30+
using namespace llvm;
31+
32+
namespace {
33+
class NVPTXRegCountPass : public MachineFunctionPass {
34+
public:
35+
static char ID;
36+
NVPTXRegCountPass() : MachineFunctionPass(ID) {}
37+
38+
bool runOnMachineFunction(MachineFunction &MF) override {
39+
unsigned maxRegs = 0;
40+
for (const MachineBasicBlock &MBB : MF) {
41+
unsigned liveRegs = 0;
42+
for (const MachineInstr &MI : MBB) {
43+
// Count unique virtual and physical registers
44+
for (const MachineOperand &MO : MI.operands()) {
45+
if (MO.isReg() && MO.getReg())
46+
liveRegs++;
47+
}
48+
}
49+
maxRegs = std::max(maxRegs, liveRegs);
50+
}
51+
errs() << "Function " << MF.getName() << " uses maximum of "
52+
<< maxRegs << " registers\n";
53+
return false;
54+
}
55+
};
56+
} // namespace
57+
58+
// Definition of the static pass identifier that the NVPTXRegCountPass
// constructor hands to MachineFunctionPass.
char NVPTXRegCountPass::ID = 0;
59+
// NOTE(review): the INITIALIZE_PASS registration below is commented out --
// confirm whether the pass is meant to stay unregistered.
// INITIALIZE_PASS(NVPTXRegCountPass, "nvptx-count-reg",
60+
// "NVPTX count reg", false, false)
61+
62+
/// Factory for the register-count logging pass defined in this file.
/// \returns a newly heap-allocated NVPTXRegCountPass.
/// NOTE(review): presumably the pass manager that receives the pointer takes
/// ownership of it -- confirm against the caller.
FunctionPass *llvm::createNVPTXRegCountPass() {
63+
return new NVPTXRegCountPass();
64+
}

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,19 @@ void NVPTXTargetMachine::registerDefaultAliasAnalyses(AAManager &AAM) {
234234
AAM.registerFunctionAnalysis<NVPTXAA>();
235235
}
236236

237+
/// Module pass that dumps the textual IR of the module to a fixed file. The
/// module itself is only read.
///
/// NOTE(review): the destination is a hard-coded absolute path inside one
/// developer's home directory. Confirm whether this debugging aid should be
/// kept at all, and if so make the path configurable.
struct NVPTXModulePrinter : public PassInfoMixin<NVPTXModulePrinter> {
  /// Print \p M to the dump file.
  ///
  /// If the file cannot be opened, the failure is reported on errs() instead
  /// of being silently ignored, and nothing is written.
  ///
  /// \returns PreservedAnalyses::all() in every case.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
    const char *DumpPath = "/home/ubuntu/modular/delete-me-test_batch_kv_cache_flash_attention_causal_mask_ragged_paged.ll";
    std::error_code EC;
    raw_fd_ostream OutFile(DumpPath, EC);
    if (EC) {
      // Surface the failure so a missing dump is not mistaken for success.
      errs() << "NVPTXModulePrinter: cannot open '" << DumpPath
             << "': " << EC.message() << "\n";
      return PreservedAnalyses::all();
    }
    M.print(OutFile, nullptr);
    return PreservedAnalyses::all();
  }

  // Presumably tells the pass manager never to skip this pass -- TODO confirm.
  static bool isRequired() { return true; }
};
249+
237250
void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
238251
#define GET_PASS_REGISTRY "NVPTXPassRegistry.def"
239252
#include "llvm/Passes/TargetPassRegistry.inc"
@@ -250,6 +263,7 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
250263
FPM.addPass(NVVMIntrRangePass());
251264
if (EarlyByValArgsCopy)
252265
FPM.addPass(NVPTXCopyByValArgsPass());
266+
//PM.addPass(NVPTXModulePrinter());
253267
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
254268
});
255269

@@ -418,6 +432,7 @@ void NVPTXPassConfig::addPreRegAlloc() {
418432

419433
void NVPTXPassConfig::addPostRegAlloc() {
420434
addPass(createNVPTXPrologEpilogPass());
435+
addPass(createNVPTXRegCountPass());
421436
if (getOptLevel() != CodeGenOptLevel::None) {
422437
// NVPTXPrologEpilogPass calculates frame object offset and replace frame
423438
// index with VRFrame register. NVPTXPeephole need to be run after that and

0 commit comments

Comments
 (0)