Skip to content
Draft
2 changes: 2 additions & 0 deletions third_party/iluvatar/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
include_directories(backend/flagtree_backend_specialization/include)
add_subdirectory(backend/flagtree_backend_specialization/lib)
add_subdirectory(include)
add_subdirectory(lib)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "triton/Analysis/iluvatar_AxisInfo.h"
#include "triton/Analysis/iluvatar_Membar.h"
#include "triton/Analysis/iluvatar_Utility.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef ILUVATAR_TRITON_ANALYSIS_AXISINFO_H
#define ILUVATAR_TRITON_ANALYSIS_AXISINFO_H

#define FLAGTREE_SPEC_AxisInfo_CorexFlag
#define FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG AxisInfo::DimVectorT *

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef ILUVATAR_TRITON_ANALYSIS_MEMBAR_H
#define ILUVATAR_TRITON_ANALYSIS_MEMBAR_H

#define FLAGTREE_SPEC_BlockInfo_Function

#endif // ILUVATAR_TRITON_ANALYSIS_MEMBAR_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef ILUVATAR_TRITON_ANALYSIS_UTILITY_H
#define ILUVATAR_TRITON_ANALYSIS_UTILITY_H

#define FLAGTREE_SPEC_Utility_Function
#define FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG bool

#endif // ILUVATAR_TRITON_ANALYSIS_UTILITY_H
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// ILUVATAR_TRITON_ANALYSIS_UTILITY_H

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "triton/Analysis/AxisInfo.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

namespace mlir::triton {

template <class T>
void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
DimVectorT *contiguity,
DimVectorT *divisibility,
DimVectorT *constancy,
DimVectorT *corexFlag) {
// liast of attributes that we care about
SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
retVecs.push_back({contiguity, "tt.contiguity"});
retVecs.push_back({divisibility, "tt.divisibility"});
retVecs.push_back({constancy, "tt.constancy"});
retVecs.push_back({corexFlag, "tt.corex_stride"});

// initialize attributes one by one
for (auto [vec, attrName] : retVecs) {
Attribute attr = funcOp.getArgAttr(argNumber, attrName);
if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
*vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
auto vals = dense_attr.getValues<int>();
*vec = DimVectorT(vals.begin(), vals.end());
}
}
}

// Explicit instantiation for function ops accessed through
// mlir::FunctionOpInterface. The type of the trailing parameter is supplied by
// the FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG macro
// (presumably the backend's spec header; it must name the same type as the
// last parameter of the template definition for this to match).
template void AxisInfo::initPessimisticStateFromFunc<mlir::FunctionOpInterface>(
int argNumber, mlir::FunctionOpInterface funcOp, AxisInfo::DimVectorT *contiguity,
AxisInfo::DimVectorT *divisibility, AxisInfo::DimVectorT *constancy,
FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg);

// Explicit instantiation for mlir::LLVM::LLVMFuncOp, with the same
// macro-provided trailing parameter type as above.
template void AxisInfo::initPessimisticStateFromFunc<mlir::LLVM::LLVMFuncOp>(
int argNumber, mlir::LLVM::LLVMFuncOp funcOp, AxisInfo::DimVectorT *contiguity,
AxisInfo::DimVectorT *divisibility, AxisInfo::DimVectorT *constancy,
FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg);

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
add_triton_library(FlagTree_iluvatar_TritonAnalysis
AxisInfo.cpp
Membar.cpp
Utility.cpp

DEPENDS
TritonTableGen
TritonGPUAttrDefsIncGen
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "triton/Analysis/Membar.h"

namespace mlir {

// Removes intervals that appear in `other` from this BlockInfo's sync sets.
// `type` selects which sets take part:
//   0 - other's sync reads are removed from our sync reads, and other's sync
//       writes are removed from our sync writes
//   1 - other's sync reads are removed from our sync writes
//   2 - other's sync writes are removed from our sync reads
// Any other value of `type` leaves this object unchanged.
// NOTE(review): presumably `other` is never `*this`; with type 0 that would
// erase from a set while iterating it - confirm against callers.
void BlockInfo::erase(BlockInfo &other, int type) {
  switch (type) {
  case 0:
    for (const auto &readInterval : other.syncReadIntervals)
      syncReadIntervals.erase(readInterval);
    for (const auto &writeInterval : other.syncWriteIntervals)
      syncWriteIntervals.erase(writeInterval);
    break;
  case 1:
    for (const auto &readInterval : other.syncReadIntervals)
      syncWriteIntervals.erase(readInterval);
    break;
  case 2:
    for (const auto &writeInterval : other.syncWriteIntervals)
      syncReadIntervals.erase(writeInterval);
    break;
  default:
    break;
  }
}

// Debug helper: writes the sync read and sync write intervals to std::cout,
// one line per set, each interval shown as " [start, end] ". Nothing is
// printed when both sets are empty.
void BlockInfo::printIntervals() {
  if (syncReadIntervals.size() == 0 && syncWriteIntervals.size() == 0)
    return;

  // Prints a title followed by every interval of one set, then ends the line.
  auto dumpSet = [](const char *title, const auto &intervals) {
    std::cout << title;
    for (const auto &interval : intervals)
      std::cout << " [" << interval.start() << ", " << interval.end() << "] ";
    std::cout << "" << std::endl;
  };

  dumpSet(" syncReadIntervals", syncReadIntervals);
  dumpSet(" syncWriteIntervals", syncWriteIntervals);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include "triton/Analysis/Utility.h"

namespace mlir {

/// Returns true when converting `srcTy` to `dstTy` goes from an Iluvatar MMA
/// layout to a dot-operand layout whose parent is also an Iluvatar MMA layout,
/// the pair is NOT accepted by isMmaToDotShortcut, and all of the following
/// hold: both MMA layouts have major version 1, their first warps-per-CTA
/// entries match, the dot operand index is 0, and the source element type is
/// not f32.
bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  // The source must carry an Iluvatar MMA encoding.
  auto mmaLayout =
      srcTy.getEncoding().dyn_cast<triton::gpu::IluvatarMmaEncodingAttr>();
  if (!mmaLayout)
    return false;

  // The destination must be a dot operand ...
  auto dotOperandLayout =
      dstTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
  if (!dotOperandLayout)
    return false;

  // ... whose parent layout is an Iluvatar MMA encoding as well.
  auto dstMmaLayout =
      dotOperandLayout.getParent()
          .dyn_cast<triton::gpu::IluvatarMmaEncodingAttr>();
  if (!dstMmaLayout)
    return false;

  // Conversions already handled by the regular shortcut are excluded.
  if (isMmaToDotShortcut(srcTy, dstTy))
    return false;
  if (mmaLayout.getVersionMajor() != 1)
    return false;
  if (dstMmaLayout.getVersionMajor() != 1)
    return false;
  if (mmaLayout.getWarpsPerCTA()[0] != dstMmaLayout.getWarpsPerCTA()[0])
    return false;
  if (dotOperandLayout.getOpIdx() != 0)
    return false;
  return !srcTy.getElementType().isF32();
}

/// Recursively collects into `backwardSlice` the operations that `op`
/// transitively depends on through its operands, followed by `op` itself
/// (dependencies are inserted before their user).
///
/// - Null ops and ops with the IsIsolatedFromAbove trait are ignored.
/// - When `filter` is set, ops it rejects are neither visited nor inserted.
/// - For an operand that is a block argument, the op owning the block is
///   visited instead, unless `omitBlockArguments` is true.
void getBackwardSliceImplCorex(Operation *op,
                               SetVector<Operation *> *backwardSlice,
                               TransitiveFilter filter,
                               bool omitBlockArguments) {
  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return;

  // The filter decides whether this def is kept; this is what allows callers
  // to restrict the transitive backward slice to the current scope.
  if (filter && !filter(op))
    return;

  for (auto operand : op->getOperands()) {
    // Operand produced by an operation: recurse into the producer once.
    if (auto *producer = operand.getDefiningOp()) {
      if (!backwardSlice->count(producer))
        getBackwardSliceImplCorex(producer, backwardSlice, filter,
                                  omitBlockArguments);
      continue;
    }

    auto blockArg = operand.dyn_cast<BlockArgument>();
    if (!blockArg)
      llvm_unreachable("No definingOp and not a block argument.");

    if (omitBlockArguments)
      continue;

    // TODO: determine whether we want to recurse backward into the other
    // blocks of the owning op, which are not technically backward unless they
    // flow into us. For now, only the owning op itself is visited. Unlike the
    // code this was derived from, the single-region / single-block shape of
    // the owning op is not asserted here.
    Operation *owner = blockArg.getOwner()->getParentOp();
    if (owner && !backwardSlice->count(owner))
      getBackwardSliceImplCorex(owner, backwardSlice, filter,
                                omitBlockArguments);
  }

  backwardSlice->insert(op);
}

/// Computes the backward slice of `op` into `backwardSlice` using
/// getBackwardSliceImplCorex, then removes `op` itself so that only its
/// transitive dependencies remain in the result.
void getBackwardSliceCorex(Operation *op, SetVector<Operation *> *backwardSlice,
                           TransitiveFilter filter, bool omitBlockArguments) {
  // The implementation always appends the queried op last ...
  getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments);

  // ... but the root of the query is not part of its own slice.
  backwardSlice->remove(op);
}

/// Builds the slice reachable from `op` by repeatedly taking backward slices
/// (via getBackwardSliceCorex, filtered by `backwardFilter`) and forward
/// slices (via getForwardSlice, filtered by `forwardFilter`) of every
/// operation discovered so far, and returns the result ordered by
/// multiRootTopologicalSort.
///
/// NOTE(review): `omitBlockArguments` is currently not consulted; the backward
/// slice is always computed with block arguments omitted (true). Confirm
/// whether callers expect the parameter to be honoured.
SetVector<Operation *> multiRootGetSlice(Operation *op,
                                         TransitiveFilter backwardFilter,
                                         TransitiveFilter forwardFilter,
                                         bool omitBlockArguments) {
  SetVector<Operation *> slice;
  slice.insert(op);

  // Scratch sets reused across iterations.
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;

  // `slice` doubles as the worklist: it grows while being walked, and the
  // walk ends once every inserted operation has been expanded.
  for (unsigned cursor = 0; cursor != slice.size(); ++cursor) {
    Operation *root = slice[cursor];

    // Expand backward from the current operation.
    backwardSlice.clear();
    getBackwardSliceCorex(root, &backwardSlice, backwardFilter,
                          /*omitBlockArguments=*/true);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Expand forward from the current operation.
    forwardSlice.clear();
    getForwardSlice(root, &forwardSlice, forwardFilter);
    slice.insert(forwardSlice.begin(), forwardSlice.end());
  }

  return multiRootTopologicalSort(slice);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Analysis)
3 changes: 3 additions & 0 deletions third_party/iluvatar/include/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
add_subdirectory(triton)

set(ILUVATAR_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../backend/flagtree_backend_specialization/include/triton")
include_directories("${ILUVATAR_INCLUDE_DIR}")
38 changes: 36 additions & 2 deletions third_party/iluvatar/include/triton/Analysis/AxisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <optional>
#include <type_traits>

#include "flagtree_spec.h"

namespace mlir::triton {

//===----------------------------------------------------------------------===//
Expand All @@ -25,6 +27,20 @@ class AxisInfo {
typedef SmallVector<int64_t> DimVectorT;

public:
#ifndef FLAGTREE_SPEC_AxisInfo_CorexFlag
AxisInfo() : AxisInfo({}, {}, {}) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
std::optional<int64_t> constantValue)
: contiguity(contiguity), divisibility(divisibility),
constancy(constancy), constantValue(constantValue) {
assert(divisibility.size() == contiguity.size());
assert(constancy.size() == contiguity.size());
}
#else
AxisInfo() : AxisInfo({}, {}, {}, {}) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
Expand All @@ -40,6 +56,7 @@ class AxisInfo {
assert(divisibility.size() == contiguity.size());
assert(constancy.size() == contiguity.size());
}
#endif

// contiguity[d] is the length of the shortest sequence of contiguous integers
// along dimension d.
Expand Down Expand Up @@ -110,25 +127,37 @@ class AxisInfo {
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
const DimVectorT &getConstancy() const { return constancy; }

#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
// corexFlag is used to determine whether special instructions can be used to
// accelerate data loading.
int64_t getCorexFlag(size_t dim) const { return corexFlag[dim]; }
const DimVectorT &getCorexFlag() const { return corexFlag; }
#endif

int getRank() const { return contiguity.size(); }

std::optional<int64_t> getConstantValue() const { return constantValue; }

#ifdef FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG
template <class T>
static void
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
DimVectorT *divisibility, DimVectorT *constancy,
DimVectorT *corex_stride);
FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg);
#else
template <class T>
static void
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
DimVectorT *divisibility, DimVectorT *constancy);
#endif

bool operator==(const AxisInfo &other) const {
return contiguity == other.contiguity &&
divisibility == other.divisibility && constancy == other.constancy &&
corexFlag == other.corexFlag && constantValue == other.constantValue;
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
corexFlag == other.corexFlag &&
#endif
constantValue == other.constantValue;
}

static AxisInfo getPessimisticValueState(Value value);
Expand All @@ -145,7 +174,9 @@ class AxisInfo {
print("contiguity", contiguity);
print(", divisibility", divisibility);
print(", constancy", constancy);
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
print(", corexflag", corexFlag);
#endif
os << ", constant_value = ";
if (constantValue)
os << *constantValue;
Expand All @@ -157,9 +188,12 @@ class AxisInfo {
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;

// The constant value of the lattice if we can infer it.
std::optional<int64_t> constantValue;
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
DimVectorT corexFlag;
#endif
};

// Module level axis info analysis based on the call graph, assuming that we do
Expand Down
32 changes: 5 additions & 27 deletions third_party/iluvatar/include/triton/Analysis/Membar.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <set>

#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h"

namespace mlir {

class OpBuilder;
Expand Down Expand Up @@ -43,36 +45,12 @@ struct BlockInfo {
syncWriteIntervals.clear();
}

#ifdef __ILUVATAR__
#ifdef FLAGTREE_SPEC_BlockInfo_Function
// type: 0 all | 1 del W from other R |2 del R from other W
void erase(BlockInfo &other, int type = 0) {
if (type == 0) {
for (auto &sri : other.syncReadIntervals)
syncReadIntervals.erase(sri);
for (auto &swi : other.syncWriteIntervals)
syncWriteIntervals.erase(swi);
} else if (type == 1) {
for (auto &sri : other.syncReadIntervals)
syncWriteIntervals.erase(sri);
} else if (type == 2) {
for (auto &swi : other.syncWriteIntervals)
syncReadIntervals.erase(swi);
}
}
void erase(BlockInfo &other, int type = 0);

// for debug
void printIntervals() {
if (syncReadIntervals.size() > 0 || syncWriteIntervals.size() > 0) {
std::cout << " syncReadIntervals";
for (auto &lhs : syncReadIntervals)
std::cout << " [" << lhs.start() << ", " << lhs.end() << "] ";
std::cout << "" << std::endl;
std::cout << " syncWriteIntervals";
for (auto &lhs : syncWriteIntervals)
std::cout << " [" << lhs.start() << ", " << lhs.end() << "] ";
std::cout << "" << std::endl;
}
}
void printIntervals();
#endif

/// Compares two BlockInfo objects.
Expand Down
Loading