Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ class VOPProfileMAI<VOPProfile P, RegisterOperand _SrcRC, RegisterOperand _DstRC
// and with the earlyclobber flag on the dst. This is stricter than the
// actual HW restriction. In particular earlyclobber also affects src0 and
// src1 allocation which is not required.
bit NoDstOverlap = !gt(DstVT.Size, 128);
bit NoDstOverlap = 1; //!gt(DstVT.Size, 128);
}

class VOPProfileSMFMAC<VOPProfile P, RegisterOperand _DstRC,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ set(NVPTXCodeGen_sources
NVPTXLowerArgs.cpp
NVPTXLowerAlloca.cpp
NVPTXLowerUnreachable.cpp
NVPTXRegCount.cpp
NVPTXPeephole.cpp
NVPTXMCExpr.cpp
NVPTXPrologEpilogPass.cpp
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ enum CondCodes {
GE
};
}

FunctionPass *createNVPTXRegCountPass();
FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
llvm::CodeGenOptLevel OptLevel);
ModulePass *createNVPTXAssignValidGlobalNamesPass();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case 1: {
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
SDValue Imm = Ops[0];
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
// Convert immediate to target constant
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
Expand Down
64 changes: 64 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXRegCount.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===-- NVPTXISelDAGToDAG.cpp - A dag to dag inst selector for NVPTX ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an instruction selector for the NVPTX target.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelDAGToDAG.h"
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
using namespace llvm;

namespace {
class NVPTXRegCountPass : public MachineFunctionPass {
public:
static char ID;
NVPTXRegCountPass() : MachineFunctionPass(ID) {}

bool runOnMachineFunction(MachineFunction &MF) override {
unsigned maxRegs = 0;
for (const MachineBasicBlock &MBB : MF) {
unsigned liveRegs = 0;
for (const MachineInstr &MI : MBB) {
// Count unique virtual and physical registers
for (const MachineOperand &MO : MI.operands()) {
if (MO.isReg() && MO.getReg())
liveRegs++;
}
}
maxRegs = std::max(maxRegs, liveRegs);
}
errs() << "Function " << MF.getName() << " uses maximum of "
<< maxRegs << " registers\n";
return false;
}
};
} // namespace

char NVPTXRegCountPass::ID = 0;
// INITIALIZE_PASS(NVPTXRegCountPass, "nvptx-count-reg",
// "NVPTX count reg", false, false)

FunctionPass *llvm::createNVPTXRegCountPass() {
return new NVPTXRegCountPass();
}
15 changes: 15 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,19 @@ void NVPTXTargetMachine::registerDefaultAliasAnalyses(AAManager &AAM) {
AAM.registerFunctionAnalysis<NVPTXAA>();
}

struct NVPTXModulePrinter : public PassInfoMixin<NVPTXModulePrinter> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
std::error_code EC;
raw_fd_ostream OutFile("/home/ubuntu/modular/delete-me-test_batch_kv_cache_flash_attention_causal_mask_ragged_paged.ll", EC);
if (!EC) {
M.print(OutFile, nullptr);
}
return PreservedAnalyses::all();
}

static bool isRequired() { return true; }
};

void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
#define GET_PASS_REGISTRY "NVPTXPassRegistry.def"
#include "llvm/Passes/TargetPassRegistry.inc"
Expand All @@ -250,6 +263,7 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
FPM.addPass(NVVMIntrRangePass());
if (EarlyByValArgsCopy)
FPM.addPass(NVPTXCopyByValArgsPass());
//PM.addPass(NVPTXModulePrinter());
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
});

Expand Down Expand Up @@ -418,6 +432,7 @@ void NVPTXPassConfig::addPreRegAlloc() {

void NVPTXPassConfig::addPostRegAlloc() {
addPass(createNVPTXPrologEpilogPass());
addPass(createNVPTXRegCountPass());
if (getOptLevel() != CodeGenOptLevel::None) {
// NVPTXPrologEpilogPass calculates frame object offset and replace frame
// index with VRFrame register. NVPTXPeephole need to be run after that and
Expand Down
Loading