NOTE(review): the first hunk below (VOP3PInstructions.td) is reproduced exactly
as found. Its header runs straight into unrelated C++ and the isa(...) calls
have no template arguments, so text was evidently lost before it reached this
file; regenerate that hunk from the original commit before applying.

diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index d8088b8c638fd..46d2165e91d16 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -626,7 +626,7 @@ class VOPProfileMAIgetMemoryVT().getSimpleVT().SimpleTy;
   SDValue Imm = Ops[0];
-  if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+  if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
       (isa(Imm) || isa(Imm))) {
     // Convert immediate to target constant
     if (MemTy == MVT::f32 || MemTy == MVT::f64) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp b/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp
new file mode 100644
index 0000000000000..5caf934ed4f2c
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp
@@ -0,0 +1,85 @@
+//===-- NVPTXRegCount.cpp - Report register counts for NVPTX --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a machine function pass that reports, for each function,
+// the largest number of distinct registers referenced by any one basic block.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTXISelDAGToDAG.h"
+#include "NVPTX.h"
+#include "NVPTXUtilities.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+using namespace llvm;
+
+namespace {
+/// Diagnostic pass that reports a per-function register count.
+///
+/// For every basic block it collects the distinct registers (virtual or
+/// physical) that appear as instruction operands, then prints the largest
+/// per-block total for the function on errs(). The function is never modified.
+class NVPTXRegCountPass : public MachineFunctionPass {
+public:
+  static char ID;
+  NVPTXRegCountPass() : MachineFunctionPass(ID) {}
+
+  StringRef getPassName() const override { return "NVPTX Register Count"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // Read-only pass: keep every existing analysis valid.
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  /// Prints the register count for \p MF.
+  /// \returns false, because \p MF is left unchanged.
+  bool runOnMachineFunction(MachineFunction &MF) override {
+    unsigned MaxRegs = 0;
+    for (const MachineBasicBlock &MBB : MF) {
+      // A set, so a register referenced by several operands is counted once.
+      DenseSet<Register> BlockRegs;
+      for (const MachineInstr &MI : MBB) {
+        for (const MachineOperand &MO : MI.operands()) {
+          // Skip non-register operands and operands holding the null register.
+          if (MO.isReg() && MO.getReg())
+            BlockRegs.insert(MO.getReg());
+        }
+      }
+      MaxRegs = std::max<unsigned>(MaxRegs, BlockRegs.size());
+    }
+    errs() << "Function " << MF.getName() << " uses maximum of " << MaxRegs
+           << " registers\n";
+    return false;
+  }
+};
+} // namespace
+
+char NVPTXRegCountPass::ID = 0;
+// The INITIALIZE_PASS registration below is disabled, so the pass is only
+// reachable through createNVPTXRegCountPass().
+// INITIALIZE_PASS(NVPTXRegCountPass, "nvptx-count-reg",
+//                 "NVPTX count reg", false, false)
+
+FunctionPass *llvm::createNVPTXRegCountPass() {
+  return new NVPTXRegCountPass();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 8a25256ea1e4a..b9dd54ca3be60 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -234,6 +234,39 @@ void NVPTXTargetMachine::registerDefaultAliasAnalyses(AAManager &AAM) {
   AAM.registerFunctionAnalysis<NVPTXAA>();
 }
 
+namespace {
+/// Debugging aid: a module pass that writes the module's textual IR to a file.
+///
+/// The module is only printed, never modified.
+struct NVPTXModulePrinter : public PassInfoMixin<NVPTXModulePrinter> {
+  /// \param Path destination file; the default is the location this printer
+  /// was originally hard-coded to write to.
+  explicit NVPTXModulePrinter(
+      std::string Path =
+          "/home/ubuntu/modular/delete-me-test_batch_kv_cache_flash_attention_causal_mask_ragged_paged.ll")
+      : OutputPath(std::move(Path)) {}
+
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    std::error_code EC;
+    raw_fd_ostream OutFile(OutputPath, EC);
+    if (EC) {
+      // Best-effort dump: report the failure and carry on with compilation.
+      errs() << "NVPTXModulePrinter: cannot open '" << OutputPath
+             << "': " << EC.message() << "\n";
+      return PreservedAnalyses::all();
+    }
+    M.print(OutFile, /*AAW=*/nullptr);
+    return PreservedAnalyses::all();
+  }
+
+  /// Marks the pass as required so pass instrumentation does not skip it.
+  static bool isRequired() { return true; }
+
+private:
+  std::string OutputPath;
+};
+} // namespace
+
 void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
 #define GET_PASS_REGISTRY "NVPTXPassRegistry.def"
 #include "llvm/Passes/TargetPassRegistry.inc"
@@ -250,6 +283,8 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         FPM.addPass(NVVMIntrRangePass());
         if (EarlyByValArgsCopy)
           FPM.addPass(NVPTXCopyByValArgsPass());
+        // Disabled debugging hook: enable to dump the module before FPM runs.
+        //PM.addPass(NVPTXModulePrinter());
         PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
       });
 
@@ -418,6 +453,7 @@ void NVPTXPassConfig::addPreRegAlloc() {
 
 void NVPTXPassConfig::addPostRegAlloc() {
   addPass(createNVPTXPrologEpilogPass());
+  addPass(createNVPTXRegCountPass());
   if (getOptLevel() != CodeGenOptLevel::None) {
     // NVPTXPrologEpilogPass calculates frame object offset and replace frame
     // index with VRFrame register. NVPTXPeephole need to be run after that and