Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions llvm/lib/Target/X86/X86.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,16 @@ FunctionPass *createX86InsertX87waitPass();
/// This pass optimizes arithmetic based on knowledge that is only used by
/// a reduction sequence and is therefore safe to reassociate in interesting
/// ways.
FunctionPass *createX86PartialReductionPass();
class X86PartialReductionPass : public PassInfoMixin<X86PartialReductionPass> {
  /// Target machine handed to the constructor; only the pointer is kept.
  const TargetMachine *TM;

public:
  /// Builds the pass around the given target machine.
  X86PartialReductionPass(const TargetMachine *Machine) : TM(Machine) {}

  /// Runs the pass on \p F and returns the set of analyses it preserves.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
};

FunctionPass *createX86PartialReductionLegacyPass();

/// Analyzes and emits pseudos to support Win x64 Unwind V2.
FunctionPass *createX86WinEHUnwindV2Pass();
Expand All @@ -179,7 +188,18 @@ FunctionPass *createX86LowerAMXTypeLegacyPass();

/// This pass transforms AMX intrinsics into scalar operations if the function
/// has the optnone attribute or is compiled at O0.
FunctionPass *createX86LowerAMXIntrinsicsPass();
class X86LowerAMXIntrinsicsPass
    : public PassInfoMixin<X86LowerAMXIntrinsicsPass> {
  /// Target machine handed to the constructor; only the pointer is kept.
  const TargetMachine *TM;

public:
  /// Builds the pass around the given target machine.
  X86LowerAMXIntrinsicsPass(const TargetMachine *Machine) : TM(Machine) {}

  /// Runs the pass on \p F and returns the set of analyses it preserves.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);

  /// Always reports this pass as required.
  /// NOTE(review): presumably so the pass manager does not skip it on optnone
  /// functions, which is where this lowering applies -- confirm against the
  /// new pass manager documentation.
  static bool isRequired() { return true; }
};

FunctionPass *createX86LowerAMXIntrinsicsLegacyPass();

InstructionSelector *createX86InstructionSelector(const X86TargetMachine &TM,
const X86Subtarget &,
Expand Down Expand Up @@ -220,7 +240,7 @@ void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);
void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
void initializeX86LowerTileCopyPass(PassRegistry &);
void initializeX86OptimizeLEAPassPass(PassRegistry &);
void initializeX86PartialReductionPass(PassRegistry &);
void initializeX86PartialReductionLegacyPass(PassRegistry &);
void initializeX86PreTileConfigPass(PassRegistry &);
void initializeX86ReturnThunksPass(PassRegistry &);
void initializeX86SpeculativeExecutionSideEffectSuppressionPass(PassRegistry &);
Expand Down
48 changes: 38 additions & 10 deletions llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
Expand All @@ -40,7 +43,7 @@
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"
#define DEBUG_TYPE "x86-lower-amx-intrinsics"

#ifndef NDEBUG
static bool isV256I32Ty(Type *Ty) {
Expand Down Expand Up @@ -626,6 +629,37 @@ bool X86LowerAMXIntrinsics::visit() {
return C;
}

namespace {
/// Decides whether AMX intrinsics in \p F should be lowered.
///
/// Lowering is attempted only when the X86ScalarizeAMX option is set and
/// either \p F carries the optnone attribute or \p TM is at
/// CodeGenOptLevel::None. \p TM is dereferenced only when the option is set
/// and \p F is not optnone.
bool shouldRunLowerAMXIntrinsics(const Function &F, const TargetMachine *TM) {
  if (!X86ScalarizeAMX)
    return false;
  if (F.hasFnAttribute(Attribute::OptimizeNone))
    return true;
  return TM->getOptLevel() == CodeGenOptLevel::None;
}

/// Lowers the AMX intrinsics in \p F via X86LowerAMXIntrinsics.
///
/// \p DT is wrapped in a DomTreeUpdater using the Lazy update strategy and
/// \p LI is forwarded unchanged. The legacy pass passes nullptr for either one
/// when the corresponding analysis is not available.
///
/// \returns the result of X86LowerAMXIntrinsics::visit().
bool runLowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI) {
  DomTreeUpdater Updater(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  X86LowerAMXIntrinsics Lowering(F, Updater, LI);

  const bool Changed = Lowering.visit();
  return Changed;
}
} // namespace

/// New pass manager entry point for the AMX intrinsics lowering.
///
/// Returns PreservedAnalyses::all() when the lowering is not enabled for \p F
/// or when it made no change. Otherwise only the dominator tree and loop info
/// analyses are marked as preserved.
PreservedAnalyses X86LowerAMXIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &FAM) {
  if (shouldRunLowerAMXIntrinsics(F, TM)) {
    // Analyses are requested only once we know the lowering will be attempted.
    auto &DomTree = FAM.getResult<DominatorTreeAnalysis>(F);
    auto &Loops = FAM.getResult<LoopAnalysis>(F);

    if (runLowerAMXIntrinsics(F, &DomTree, &Loops)) {
      PreservedAnalyses Preserved = PreservedAnalyses::none();
      Preserved.preserve<DominatorTreeAnalysis>();
      Preserved.preserve<LoopAnalysis>();
      return Preserved;
    }
  }

  // Either the pass was skipped or it left the function untouched.
  return PreservedAnalyses::all();
}

namespace {
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
Expand All @@ -634,21 +668,15 @@ class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {}

bool runOnFunction(Function &F) override {
if (!X86ScalarizeAMX)
return false;
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
TM->getOptLevel() != CodeGenOptLevel::None)
if (!shouldRunLowerAMXIntrinsics(F, TM))
return false;

auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

X86LowerAMXIntrinsics LAT(F, DTU, LI);
return LAT.visit();
return runLowerAMXIntrinsics(F, DT, LI);
}
StringRef getPassName() const override { return "Lower AMX intrinsics"; }

Expand All @@ -668,6 +696,6 @@ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
FunctionPass *llvm::createX86LowerAMXIntrinsicsLegacyPass() {
return new X86LowerAMXIntrinsicsLegacyPass();
}
73 changes: 47 additions & 26 deletions llvm/lib/Target/X86/X86PartialReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
#include "X86TargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/KnownBits.h"
Expand All @@ -30,39 +32,44 @@ using namespace llvm;

namespace {

class X86PartialReduction : public FunctionPass {
class X86PartialReduction {
const X86TargetMachine *TM;
const DataLayout *DL = nullptr;
const X86Subtarget *ST = nullptr;

public:
X86PartialReduction(const X86TargetMachine *TM) : TM(TM) {}
bool run(Function &F);

private:
bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
bool trySADReplacement(Instruction *Op);
};

class X86PartialReductionLegacy : public FunctionPass {
public:
static char ID; // Pass identification, replacement for typeid.

X86PartialReduction() : FunctionPass(ID) { }
X86PartialReductionLegacy() : FunctionPass(ID) {}

bool runOnFunction(Function &Fn) override;
bool runOnFunction(Function &F) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
}

StringRef getPassName() const override {
return "X86 Partial Reduction";
}

private:
bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
bool trySADReplacement(Instruction *Op);
StringRef getPassName() const override { return "X86 Partial Reduction"; }
};
}

FunctionPass *llvm::createX86PartialReductionPass() {
return new X86PartialReduction();
FunctionPass *llvm::createX86PartialReductionLegacyPass() {
return new X86PartialReductionLegacy();
}

char X86PartialReduction::ID = 0;
char X86PartialReductionLegacy::ID = 0;

INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
"X86 Partial Reduction", false, false)
INITIALIZE_PASS(X86PartialReductionLegacy, DEBUG_TYPE, "X86 Partial Reduction",
false, false)

// This function should be aligned with detectExtMul() in X86ISelLowering.cpp.
static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
Expand Down Expand Up @@ -494,17 +501,8 @@ static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
}
}

bool X86PartialReduction::runOnFunction(Function &F) {
if (skipFunction(F))
return false;

auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
if (!TPC)
return false;

auto &TM = TPC->getTM<X86TargetMachine>();
ST = TM.getSubtargetImpl(F);

bool X86PartialReduction::run(Function &F) {
ST = TM->getSubtargetImpl(F);
DL = &F.getDataLayout();

bool MadeChange = false;
Expand Down Expand Up @@ -540,3 +538,26 @@ bool X86PartialReduction::runOnFunction(Function &F) {

return MadeChange;
}

/// Legacy pass manager entry point for the partial reduction transform.
///
/// Does nothing (returns false) when skipFunction() rejects \p F or when no
/// TargetPassConfig is available. Otherwise runs X86PartialReduction with the
/// X86 target machine taken from the pass config and returns its result.
bool X86PartialReductionLegacy::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  if (auto *PassConfig = getAnalysisIfAvailable<TargetPassConfig>()) {
    X86PartialReduction Impl(&PassConfig->getTM<X86TargetMachine>());
    return Impl.run(F);
  }

  return false;
}

/// New pass manager entry point for the partial reduction transform.
///
/// Runs X86PartialReduction on \p F with the stored target machine cast to
/// X86TargetMachine. Returns PreservedAnalyses::all() when nothing changed;
/// otherwise only the CFGAnalyses set is marked as preserved.
PreservedAnalyses X86PartialReductionPass::run(Function &F,
                                               FunctionAnalysisManager &FAM) {
  const auto *X86TM = static_cast<const X86TargetMachine *>(TM);
  X86PartialReduction Impl(X86TM);

  if (!Impl.run(F))
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved = PreservedAnalyses::none();
  Preserved.preserveSet<CFGAnalyses>();
  return Preserved;
}
4 changes: 2 additions & 2 deletions llvm/lib/Target/X86/X86PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
#ifndef FUNCTION_PASS
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("x86-lower-amx-intrinsics", X86LowerAMXIntrinsicsPass(this))
FUNCTION_PASS("x86-lower-amx-type", X86LowerAMXTypePass(this))
FUNCTION_PASS("x86-partial-reduction", X86PartialReductionPass(this))
#undef FUNCTION_PASS

#ifndef DUMMY_FUNCTION_PASS
#define DUMMY_FUNCTION_PASS(NAME, CREATE_PASS)
#endif
DUMMY_FUNCTION_PASS("lower-amx-intrinsics", X86LowerAMXIntrinsics(*this))
DUMMY_FUNCTION_PASS("x86-partial-reduction", X86PartialReduction())
DUMMY_FUNCTION_PASS("x86-winehstate", WinEHStatePass())
#undef DUMMY_FUNCTION_PASS

Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/X86/X86TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ extern "C" LLVM_C_ABI void LLVMInitializeX86Target() {
initializeX86LoadValueInjectionLoadHardeningPassPass(PR);
initializeX86LoadValueInjectionRetHardeningPassPass(PR);
initializeX86OptimizeLEAPassPass(PR);
initializeX86PartialReductionPass(PR);
initializeX86PartialReductionLegacyPass(PR);
initializePseudoProbeInserterPass(PR);
initializeX86ReturnThunksPass(PR);
initializeX86DAGToDAGISelLegacyPass(PR);
Expand Down Expand Up @@ -422,14 +422,14 @@ void X86PassConfig::addIRPasses() {

// We add both passes unconditionally; when they run, each pass decides
// whether to skip itself based on the optimization level and the optnone
// attribute.
addPass(createX86LowerAMXIntrinsicsPass());
addPass(createX86LowerAMXIntrinsicsLegacyPass());
addPass(createX86LowerAMXTypeLegacyPass());

TargetPassConfig::addIRPasses();

if (TM->getOptLevel() != CodeGenOptLevel::None) {
addPass(createInterleavedAccessPass());
addPass(createX86PartialReductionPass());
addPass(createX86PartialReductionLegacyPass());
}

// Add passes that handle indirect branch removal and insertion of a retpoline
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s
; RUN: opt -mtriple=x86_64 -x86-lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s
; RUN: opt -mtriple=x86_64 -passes=x86-lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s

define dso_local void @test_no_bitcast(ptr %A_mem, ptr %B_mem, ptr %C_mem) local_unnamed_addr #0 {
; CHECK-LABEL: @test_no_bitcast(
Expand Down
3 changes: 2 additions & 1 deletion llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s
; RUN: opt -mtriple=x86_64 -x86-lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s
; RUN: opt -mtriple=x86_64 -passes=x86-lower-amx-intrinsics -enable-x86-scalar-amx=true %s -S | FileCheck %s

define dso_local void @test_amx_load_non_O0(i16 signext %row, i16 signext %col, ptr%ptr, i64 %stride, ptr %vptr) {
; CHECK-LABEL: @test_amx_load_non_O0(
Expand Down