168 changes: 12 additions & 156 deletions third_party/intel/include/Analysis/AxisInfo.h
@@ -1,169 +1,24 @@
#ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H
#define TRITON_INTEL_ANALYSIS_AXISINFO_H

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include <optional>
#include "triton/Analysis/AxisInfo.h"

namespace mlir::triton::intel {

//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//

/// This lattice value represents known information on the axes of a lattice.
class AxisInfo {
public:
typedef SmallVector<int64_t> DimVectorT;

public:
AxisInfo() : AxisInfo({}, {}, {}) {}

AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
const DimVectorT &constancy)
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}

AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
const DimVectorT &constancy, std::optional<int64_t> constantValue)
: contiguity(contiguity), divisibility(divisibility),
constancy(constancy), constantValue(constantValue) {
assert(divisibility.size() == contiguity.size());
assert(constancy.size() == contiguity.size());
}

// contiguity[d] is the length of the shortest sequence of contiguous integers
// along dimension d.
//
// If we have an array of N elements with a contiguity value C, then the array
// can be divided into a list of N/C sequences of C contiguous elements.
// Since we have N = 2^k, C must be a power of two.
//
// For example, the 2D array
//
// [[10, 11, 12, 13, 18, 19, 20, 21],
// [20, 21, 22, 23, 28, 29, 30, 31]]
//
// has contiguity [1, 4], and
//
// [[12, 16, 20, 24],
// [13, 17, 21, 25],
// [14, 18, 22, 26],
// [15, 19, 23, 27],
// [18, 22, 26, 30],
// [19, 23, 27, 31]]
//
// has contiguity [2, 1].
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
const DimVectorT &getContiguity() const { return contiguity; }

// divisibility[d] is the largest power of two that divides the first element
// of all groups of length contiguity[d] along dimension d.
//
// For example,
//
// [[10, 11, 12, 13, 18, 19, 20, 21],
// [20, 21, 22, 23, 28, 29, 30, 31]]
//
// has divisibility [1, 2], and
//
// [[12, 16, 20, 24],
// [13, 17, 21, 25],
// [14, 18, 22, 26],
// [15, 19, 23, 27]]
//
// has divisibility [4, 1].
//
// On the other hand,
//
// [0, 1, 2, 0, 4, 5, 6, 7]
//
// has divisibility 1 because its contiguity is 1.
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
const DimVectorT &getDivisibility() const { return divisibility; }

// constancy[d] is the length of the shortest sequence of repeating integers
// along dimension d.
//
// This is particularly useful to infer the contiguity of operations (e.g.
// add) involving a constant.
//
// If we have an array of N elements, with a constancy value C, then the array
// can be divided into a list of N/C sequences of C elements with the same
// value. Since we have N = 2^k, C must be a power of two.
//
// For example
//
// [[8, 8, 8, 8, 12, 12, 12, 12],
// [16, 16, 16, 16, 20, 20, 20, 20]]
//
// has constancy [1, 4].
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
const DimVectorT &getConstancy() const { return constancy; }

int getRank() const { return contiguity.size(); }

std::optional<int64_t> getConstantValue() const { return constantValue; }

template <class T>
static void
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
DimVectorT *divisibility, DimVectorT *constancy);

bool operator==(const AxisInfo &other) const {
return contiguity == other.contiguity &&
divisibility == other.divisibility && constancy == other.constancy &&
constantValue == other.constantValue;
}

static AxisInfo getPessimisticValueState(Value value);

// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);

void print(raw_ostream &os) const {
auto print = [&](StringRef name, DimVectorT vec) {
os << name << " = [";
llvm::interleaveComma(vec, os);
os << "]";
};
print("contiguity", contiguity);
print(", divisibility", divisibility);
print(", constancy", constancy);
os << ", constant_value = ";
if (constantValue)
os << *constantValue;
else
os << "<none>";
}

private:
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;

// The constant value of the lattice if we can infer it.
std::optional<int64_t> constantValue;
};

// Module level axis info analysis based on the call graph, assuming that we do
// not have recursive functions.
//
// Since each function will be called multiple times, we need to calculate the
// axis info based on the axis info of all the callers. In the future, we can
// perform optimization using function cloning so that each call site will have
// unique axis info.
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
// using AxisInfoMapT = DenseMap<Value, AxisInfo>;
class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
public:
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
: CallGraph<AxisInfoMapT>(moduleOp) {
: triton::ModuleAxisInfoAnalysis(moduleOp) {
funcMap.clear();

SmallVector<FunctionOpInterface> funcs;
for (auto root : getRoots()) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
@@ -187,10 +42,11 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
}
}

AxisInfo *getAxisInfo(Value value) {
AxisInfo *getAxisInfo(Value value) const {
auto funcOp =
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
auto *axisInfoMap = getFuncData(funcOp);
auto *axisInfoMap =
const_cast<ModuleAxisInfoAnalysis *>(this)->getFuncData(funcOp);
if (!axisInfoMap) {
return nullptr;
}
@@ -201,9 +57,9 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
return &(it->second);
}

unsigned getPtrContiguity(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
unsigned getPtrContiguity(Value ptr) const;
unsigned getPtrAlignment(Value ptr) const;
unsigned getMaskAlignment(Value mask) const;
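// Illustrative usage (not part of this header): a lowering pass can build the
// module-level analysis once and then query per-value results, e.g.
//
//   intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
//   if (AxisInfo *info = axisInfoAnalysis.getAxisInfo(ptr))
//     unsigned alignment = axisInfoAnalysis.getPtrAlignment(ptr);
//
// `moduleOp` and `ptr` are placeholder names; the TritonGPUToLLVM pass change
// further down constructs the analysis this way.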

private:
void initialize(FunctionOpInterface funcOp);
@@ -28,7 +28,7 @@ inline unsigned getNumElementsPerThread(
? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
: cast<RankedTensorType>(valTy);
auto shapePerCTA = getShapePerCTA(ty);
mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
mlir::triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);

unsigned elemNumBits = getElementBitWidth(ty);
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
116 changes: 6 additions & 110 deletions third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -1159,113 +1159,7 @@ void AxisInfoAnalysis::visitForOpInductionVar(

} // anonymous namespace

template <class T>
void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
DimVectorT *contiguity,
DimVectorT *divisibility,
DimVectorT *constancy) {
// list of attributes that we care about
SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
retVecs.push_back({contiguity, "tt.contiguity"});
retVecs.push_back({divisibility, "tt.divisibility"});
retVecs.push_back({constancy, "tt.constancy"});
// initialize attributes one by one
for (auto [vec, attrName] : retVecs) {
Attribute attr = funcOp.getArgAttr(argNumber, attrName);
if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
*vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
auto vals = dense_attr.getValues<int>();
*vec = DimVectorT(vals.begin(), vals.end());
}
}
}
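// Illustration (hypothetical attribute values): for a rank-2 tensor argument
// annotated with
//   tt.divisibility = 16 : i32                      -> divisibility = [16, 16]
//   tt.contiguity = dense<[128, 1]> : tensor<2xi32> -> contiguity   = [128, 1]
// an IntegerAttr is broadcast across all dimensions, while a
// DenseElementsAttr is copied element-wise.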

/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
auto rank = 1;
if (TensorType ty = dyn_cast<TensorType>(value.getType()))
rank = ty.getRank();
if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
rank = elemTy.getRank();

DimVectorT knownContiguity(rank, 1);
DimVectorT knownDivisibility(rank, 1);
DimVectorT knownConstancy(rank, 1);

BlockArgument blockArg = dyn_cast<BlockArgument>(value);

if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (auto fun = dyn_cast<FunctionOpInterface>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
// LLVM codegen checks alignment to generate vector loads/stores;
// it would be nice if this wasn't the case
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
else if (isa<RegionBranchOpInterface>(op)) {
// scf::ForOp, scf::IfOp, scf::WhileOp
// Control flow operations are initialized with "unknown" state:
// the maximum possible divisibility, contiguity, and constancy.
knownDivisibility = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
knownConstancy = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
knownContiguity = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
}
} else if (Operation *op = value.getDefiningOp()) {
if (isa<RegionBranchOpInterface>(op)) {
// scf::ForOp, scf::IfOp, scf::WhileOp
// Control flow operations are initialized with "unknown" state:
// the maximum possible divisibility, contiguity, and constancy.
knownDivisibility = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
knownConstancy = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
knownContiguity = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
}
// Other operations are conservatively initialized with the lowest possible
// divisibility, contiguity, and constancy unless these are explicitly
// specified via attributes.
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
knownDivisibility = DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
knownContiguity = DimVectorT(vals.begin(), vals.end());
}
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
knownConstancy = DimVectorT(vals.begin(), vals.end());
}
}

return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
}
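// Note on the "unknown" state above: assuming highestPowOf2Divisor<int64_t>(0)
// returns the largest representable power of two, that value acts as the
// identity of the gcd-based join below for power-of-two lattice entries, so
// block arguments of control-flow ops take their axis info entirely from the
// incoming values.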

/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
// If one argument is not initialized, return the other.
if (lhs.getRank() == 0)
return rhs;
if (rhs.getRank() == 0)
return lhs;
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
for (auto d = 0; d < lhs.getRank(); ++d) {
contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
std::optional<int64_t> constantValue;
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value() &&
lhs.getConstantValue() == rhs.getConstantValue())
constantValue = lhs.getConstantValue();
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
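// Worked example (illustrative values): joining
//   lhs = {contiguity = [128, 1], divisibility = [16, 16], constancy = [1, 1]}
//   rhs = {contiguity = [32, 1],  divisibility = [64, 16], constancy = [1, 4]}
// yields, via the per-dimension gcd,
//   {contiguity = [32, 1], divisibility = [16, 16], constancy = [1, 1]},
// and the constant value is propagated only when both sides agree on it.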

unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const {
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
if (!tensorTy)
return 1;
@@ -1287,7 +1181,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
return contiguity;
}

unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const {
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
if (!tensorTy)
return 1;
@@ -1298,7 +1192,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
unsigned elemNumBits = isTensorPointerType(ptr.getType())
? tensorTy.getElementType().getIntOrFloatBitWidth()
: triton::getPointeeBitWidth(tensorTy);
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
unsigned alignment = std::min(maxMultiple, maxContig);
@@ -1315,7 +1211,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
return alignment;
}
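// Worked example (illustrative numbers): for a tensor of pointers to f32 with
// divisibility 128 (bytes) and contiguity 16 along the fastest-varying
// dimension, elemNumBytes = 4, maxMultiple = 128 / 4 = 32, and the computation
// shown above gives alignment = min(32, 16) = 16 elements (before any
// adjustment in the elided part of this function).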

unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) const {
auto tensorTy = ttgi::getRankedTensorType(mask.getType());
if (!tensorTy)
return 1;
@@ -3,8 +3,8 @@

#include "TargetInfo.h"
#include "TritonGPUToLLVMBase.h"
#include "intel/include/Analysis/AxisInfo.h"
#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h"
#include "triton/Analysis/AxisInfo.h"

namespace mlir::triton::intel {

@@ -7,14 +7,14 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"

#include "intel/include/Analysis/AxisInfo.h"
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
#include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h"
#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -114,7 +114,7 @@ struct ConvertTritonGPUToLLVM
return signalPassFailure();
}

ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
OpBuilder::InsertPoint indexInsertPoint;

RewritePatternSet patterns(context);