diff --git a/third_party/iluvatar/CMakeLists.txt b/third_party/iluvatar/CMakeLists.txt index 189faac8e..75528071b 100644 --- a/third_party/iluvatar/CMakeLists.txt +++ b/third_party/iluvatar/CMakeLists.txt @@ -1,3 +1,5 @@ +include_directories(backend/flagtree_backend_specialization/include) +add_subdirectory(backend/flagtree_backend_specialization/lib) add_subdirectory(include) add_subdirectory(lib) diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/flagtree_spec.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/flagtree_spec.h new file mode 100644 index 000000000..065fd04f6 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/flagtree_spec.h @@ -0,0 +1,5 @@ +#include "triton/Analysis/iluvatar_AxisInfo.h" +#include "triton/Analysis/iluvatar_Membar.h" +#include "triton/Analysis/iluvatar_Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/iluvatar_TargetInfoBase.h" diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_AxisInfo.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_AxisInfo.h new file mode 100644 index 000000000..92ba755b4 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_AxisInfo.h @@ -0,0 +1,7 @@ +#ifndef ILUVATAR_TRITON_ANALYSIS_AXISINFO_H +#define ILUVATAR_TRITON_ANALYSIS_AXISINFO_H + +#define FLAGTREE_SPEC_AxisInfo_CorexFlag +#define FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG AxisInfo::DimVectorT * + +#endif // ILUVATAR_TRITON_ANALYSIS_AXISINFO_H \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Membar.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Membar.h new file mode 100644 index 000000000..2b83b34ed --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Membar.h @@ -0,0 +1,6 @@ +#ifndef ILUVATAR_TRITON_ANALYSIS_MEMBAR_H +#define ILUVATAR_TRITON_ANALYSIS_MEMBAR_H + +#define FLAGTREE_SPEC_BlockInfo_Function + +#endif // ILUVATAR_TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Utility.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Utility.h new file mode 100644 index 000000000..721671cae --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Analysis/iluvatar_Utility.h @@ -0,0 +1,7 @@ +#ifndef ILUVATAR_TRITON_ANALYSIS_UTILITY_H +#define ILUVATAR_TRITON_ANALYSIS_UTILITY_H + +#define FLAGTREE_SPEC_Utility_Function +#define FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG bool + +#endif // ILUVATAR_TRITON_ANALYSIS_UTILITY_H \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h new file mode 100644 index 000000000..89715a919 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h @@ -0,0 +1,7 @@ +#ifndef ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#define FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate +#define FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite + +#endif // ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_TargetInfoBase.h b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_TargetInfoBase.h new file mode 100644 index 000000000..bfffdfbf6 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/include/triton/Conversion/TritonGPUToLLVM/iluvatar_TargetInfoBase.h @@ -0,0 +1,6 @@ +#ifndef ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#define FLAGTREE_SPEC_TargetInfoBase_function + +#endif // ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/AxisInfo.cpp b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/AxisInfo.cpp new file mode 100644 index 000000000..277124e2c --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,43 @@ +#include "triton/Analysis/AxisInfo.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton { + +template +void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy, + DimVectorT *corexFlag) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + retVecs.push_back({corexFlag, "tt.corex_stride"}); + + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +template void AxisInfo::initPessimisticStateFromFunc( + int argNumber, mlir::FunctionOpInterface funcOp, AxisInfo::DimVectorT *contiguity, + AxisInfo::DimVectorT *divisibility, AxisInfo::DimVectorT *constancy, + FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg); + +template void AxisInfo::initPessimisticStateFromFunc( + int argNumber, mlir::LLVM::LLVMFuncOp funcOp, AxisInfo::DimVectorT *contiguity, + AxisInfo::DimVectorT *divisibility, AxisInfo::DimVectorT *constancy, + FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg); + +} + diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/CMakeLists.txt b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..3e4c88d5a --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/CMakeLists.txt @@ -0,0 +1,9 @@ +add_triton_library(FlagTree_iluvatar_TritonAnalysis + AxisInfo.cpp + Membar.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen +) \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Membar.cpp b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..1ff993208 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Membar.cpp @@ -0,0 +1,35 @@ +#include "triton/Analysis/Membar.h" + +namespace mlir { + +// type: 0 all | 1 del W from other R |2 del R from other W +void BlockInfo::erase(BlockInfo &other, int type) { + if (type == 0) { + for (auto &sri : other.syncReadIntervals) + syncReadIntervals.erase(sri); + for (auto &swi : other.syncWriteIntervals) + syncWriteIntervals.erase(swi); + } else if (type == 1) { + for (auto &sri : other.syncReadIntervals) + syncWriteIntervals.erase(sri); + } else if (type == 2) { + for (auto &swi : other.syncWriteIntervals) + syncReadIntervals.erase(swi); + } +} + +// for debug +void BlockInfo::printIntervals() { + if (syncReadIntervals.size() > 0 || syncWriteIntervals.size() > 0) { + std::cout << " syncReadIntervals"; + for (auto &lhs : syncReadIntervals) + std::cout << " [" << lhs.start() << ", " << lhs.end() << "] "; + std::cout << "" << std::endl; + std::cout << " syncWriteIntervals"; + for (auto &lhs : syncWriteIntervals) + std::cout << " [" << lhs.start() << ", " << lhs.end() << "] "; + std::cout << "" << std::endl; + } +} + +} \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Utility.cpp b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..27f84e1b7 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/Analysis/Utility.cpp @@ -0,0 +1,108 @@ +#include "triton/Analysis/Utility.h" + +namespace mlir { + +bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + if (!srcLayout.isa()) + return false; + auto mmaLayout = srcLayout.cast(); + if (!dstLayout.isa()) + return false; + auto dotOperandLayout = dstLayout.cast(); + auto dstParLayout = dotOperandLayout.getParent(); + if (!dstParLayout.isa()) + return false; + auto dstMmaLayout = + dstParLayout.dyn_cast(); + return !isMmaToDotShortcut(srcTy, dstTy) && + mmaLayout.getVersionMajor() == 1 && + dstMmaLayout.getVersionMajor() == 1 && + mmaLayout.getWarpsPerCTA()[0] == dstMmaLayout.getWarpsPerCTA()[0] && + dotOperandLayout.getOpIdx() == 0 && !srcTy.getElementType().isF32(); +} + +void getBackwardSliceImplCorex(Operation *op, + SetVector *backwardSlice, + TransitiveFilter filter, + bool omitBlockArguments) { + if (!op || op->hasTrait()) + return; + + // Evaluate whether we should keep this def. + // This is useful in particular to implement scoping; i.e. return the + // transitive backwardSlice in the current scope. + if (filter && !filter(op)) + return; + + for (const auto &en : llvm::enumerate(op->getOperands())) { + auto operand = en.value(); + if (auto *definingOp = operand.getDefiningOp()) { + if (backwardSlice->count(definingOp) == 0) + getBackwardSliceImplCorex(definingOp, backwardSlice, filter, + omitBlockArguments); + } else if (auto blockArg = operand.dyn_cast()) { + if (omitBlockArguments) + continue; + + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + // TODO: determine whether we want to recurse backward into the other + // blocks of parentOp, which are not technically backward unless they flow + // into us. For now, just bail. + if (parentOp && backwardSlice->count(parentOp) == 0) { + // assert(parentOp->getNumRegions() == 1 && + // parentOp->getRegion(0).getBlocks().size() == 1); + getBackwardSliceImplCorex(parentOp, backwardSlice, filter, + omitBlockArguments); + } + } else { + llvm_unreachable("No definingOp and not a block argument."); + } + } + + backwardSlice->insert(op); +} + +void getBackwardSliceCorex(Operation *op, SetVector *backwardSlice, + TransitiveFilter filter, bool omitBlockArguments) { + getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments); + + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. + backwardSlice->remove(op); +} + +SetVector multiRootGetSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter, + bool omitBlockArguments) { + SetVector slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = backwardFilter; + getBackwardSliceCorex(currentOp, &backwardSlice, opt.filter, + opt.omitBlockArguments); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return multiRootTopologicalSort(slice); +} + +} \ No newline at end of file diff --git a/third_party/iluvatar/backend/flagtree_backend_specialization/lib/CMakeLists.txt b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/CMakeLists.txt new file mode 100644 index 000000000..5c6d3ffe1 --- /dev/null +++ b/third_party/iluvatar/backend/flagtree_backend_specialization/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Analysis) \ No newline at end of file diff --git a/third_party/iluvatar/include/CMakeLists.txt b/third_party/iluvatar/include/CMakeLists.txt index 109c292fe..233e998d3 100644 --- a/third_party/iluvatar/include/CMakeLists.txt +++ b/third_party/iluvatar/include/CMakeLists.txt @@ -1 +1,4 @@ add_subdirectory(triton) + +set(ILUVATAR_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../backend/flagtree_backend_specialization/include/triton") +include_directories("${ILUVATAR_INCLUDE_DIR}") diff --git a/third_party/iluvatar/include/triton/Analysis/AxisInfo.h b/third_party/iluvatar/include/triton/Analysis/AxisInfo.h index 5167f37d9..d09f84bab 100644 --- a/third_party/iluvatar/include/triton/Analysis/AxisInfo.h +++ b/third_party/iluvatar/include/triton/Analysis/AxisInfo.h @@ -13,6 +13,8 @@ #include #include +#include "flagtree_spec.h" + namespace mlir::triton { //===----------------------------------------------------------------------===// @@ -25,6 +27,20 @@ class AxisInfo { typedef SmallVector DimVectorT; public: +#ifndef FLAGTREE_SPEC_AxisInfo_CorexFlag + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, + std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } +#else AxisInfo() : AxisInfo({}, {}, {}, {}) {} AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, @@ -40,6 +56,7 @@ class AxisInfo { assert(divisibility.size() == contiguity.size()); assert(constancy.size() == contiguity.size()); } +#endif // contiguity[d] is the length of the shortest sequence of contiguous integers // along dimension d. @@ -110,25 +127,37 @@ class AxisInfo { int64_t getConstancy(size_t dim) const { return constancy[dim]; } const DimVectorT &getConstancy() const { return constancy; } +#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag // corexFlag is used to determine whether special instructions can be used to // accelerate data loading. int64_t getCorexFlag(size_t dim) const { return corexFlag[dim]; } const DimVectorT &getCorexFlag() const { return corexFlag; } +#endif int getRank() const { return contiguity.size(); } std::optional getConstantValue() const { return constantValue; } +#ifdef FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG template static void initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, DimVectorT *divisibility, DimVectorT *constancy, - DimVectorT *corex_stride); + FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG spec_arg); +#else + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); +#endif bool operator==(const AxisInfo &other) const { return contiguity == other.contiguity && divisibility == other.divisibility && constancy == other.constancy && - corexFlag == other.corexFlag && constantValue == other.constantValue; +#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag + corexFlag == other.corexFlag && +#endif + constantValue == other.constantValue; } static AxisInfo getPessimisticValueState(Value value); @@ -145,7 +174,9 @@ class AxisInfo { print("contiguity", contiguity); print(", divisibility", divisibility); print(", constancy", constancy); +#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag print(", corexflag", corexFlag); +#endif os << ", constant_value = "; if (constantValue) os << *constantValue; @@ -157,9 +188,12 @@ class AxisInfo { DimVectorT contiguity; DimVectorT divisibility; DimVectorT constancy; + // The constant value of the lattice if we can infer it. std::optional constantValue; +#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag DimVectorT corexFlag; +#endif }; // Module level axis info analysis based on the call graph, assuming that we do diff --git a/third_party/iluvatar/include/triton/Analysis/Membar.h b/third_party/iluvatar/include/triton/Analysis/Membar.h index 9a9cf25f3..f02bba135 100644 --- a/third_party/iluvatar/include/triton/Analysis/Membar.h +++ b/third_party/iluvatar/include/triton/Analysis/Membar.h @@ -6,6 +6,8 @@ #include +#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h" + namespace mlir { class OpBuilder; @@ -43,36 +45,12 @@ struct BlockInfo { syncWriteIntervals.clear(); } -#ifdef __ILUVATAR__ +#ifdef FLAGTREE_SPEC_BlockInfo_Function // type: 0 all | 1 del W from other R |2 del R from other W - void erase(BlockInfo &other, int type = 0) { - if (type == 0) { - for (auto &sri : other.syncReadIntervals) - syncReadIntervals.erase(sri); - for (auto &swi : other.syncWriteIntervals) - syncWriteIntervals.erase(swi); - } else if (type == 1) { - for (auto &sri : other.syncReadIntervals) - syncWriteIntervals.erase(sri); - } else if (type == 2) { - for (auto &swi : other.syncWriteIntervals) - syncReadIntervals.erase(swi); - } - } + void erase(BlockInfo &other, int type = 0); // for debug - void printIntervals() { - if (syncReadIntervals.size() > 0 || syncWriteIntervals.size() > 0) { - std::cout << " syncReadIntervals"; - for (auto &lhs : syncReadIntervals) - std::cout << " [" << lhs.start() << ", " << lhs.end() << "] "; - std::cout << "" << std::endl; - std::cout << " syncWriteIntervals"; - for (auto &lhs : syncWriteIntervals) - std::cout << " [" << lhs.start() << ", " << lhs.end() << "] "; - std::cout << "" << std::endl; - } - } + void printIntervals(); #endif /// Compares two BlockInfo objects. diff --git a/third_party/iluvatar/include/triton/Analysis/Utility.h b/third_party/iluvatar/include/triton/Analysis/Utility.h index b89b90215..86ab7fd5c 100644 --- a/third_party/iluvatar/include/triton/Analysis/Utility.h +++ b/third_party/iluvatar/include/triton/Analysis/Utility.h @@ -7,6 +7,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h" + namespace mlir { inline bool isZeroConst(Value v) { @@ -192,8 +194,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); -bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); - bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. @@ -212,7 +212,9 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); SetVector multiRootTopologicalSort(const SetVector &toSort); -#ifdef __ILUVATAR__ +#ifdef FLAGTREE_SPEC_Utility_Function +bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + /// This function dones't use assertion check. void getBackwardSliceCorex(Operation *op, SetVector *backwardSlice, TransitiveFilter filter = nullptr, @@ -226,10 +228,16 @@ void getBackwardSliceImplCorex(Operation *op, #endif /// This uses the toplogicalSort above +#ifdef FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG SetVector multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, TransitiveFilter forwardFilter = nullptr, - bool omitBlockArguments = true); + FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG spec_arg = true); +#else +SetVector +multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); +#endif /// Create a basic DataFlowSolver with constant and dead code analysis included. std::unique_ptr createDataFlowSolver(); diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 1286b4e56..ddea5283d 100644 --- a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -8,6 +8,8 @@ #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h" + using namespace mlir; using namespace mlir::triton; @@ -102,12 +104,14 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { // test_core::test_fp8_dot_acc return resultVals; } +#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_maybeDeduplicate if (isa(baseEncoding)) { // TODO: this logic seems incorrect for mma layout. Skip for now. // The following test crashes and some other miscompile: // test_core::test_fp8_dot_acc return resultVals; } +#endif SmallVector elemsPerThread = getElemsPerThread(rtType); int rank = elemsPerThread.size(); @@ -188,7 +192,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { // element type auto resultElementTy = getElementTypeOrSelf(resultTy); Type elemTy = this->getTypeConverter()->convertType(resultElementTy); -#ifdef __ILUVATAR__ +#ifdef FLAGTREE_SPEC_ElementwiseOpConversionBase_matchAndRewrite auto srcType = this->getTypeConverter()->convertType(resultTy); if (auto structTy = dyn_cast(srcType)) elemTy = structTy.getBody()[0]; diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 380c8cc1d..1da64cb8e 100644 --- a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -3,6 +3,8 @@ #include "triton/Conversion/MLIRTypes.h" +#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h" + namespace mlir::triton { class TargetInfoBase { public: @@ -13,10 +15,18 @@ class TargetInfoBase { virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const = 0; +#ifndef FLAGTREE_SPEC_TargetInfoBase_function + virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const = 0; + virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, + Type elemTy, Value pred) const = 0; +#else virtual Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred) const = 0; virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, Value pred) const = 0; +#endif virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, int i) const = 0; diff --git a/third_party/iluvatar/lib/Analysis/AxisInfo.cpp b/third_party/iluvatar/lib/Analysis/AxisInfo.cpp index d0a39ba83..c008fbc0f 100644 --- a/third_party/iluvatar/lib/Analysis/AxisInfo.cpp +++ b/third_party/iluvatar/lib/Analysis/AxisInfo.cpp @@ -1299,19 +1299,17 @@ void AxisInfoAnalysis::visitForOpInductionVar( } // anonymous namespace +#ifndef FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG template void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, DimVectorT *divisibility, - DimVectorT *constancy, - DimVectorT *corexFlag) { + DimVectorT *constancy) { // liast of attributes that we care about SmallVector> retVecs; retVecs.push_back({contiguity, "tt.contiguity"}); retVecs.push_back({divisibility, "tt.divisibility"}); retVecs.push_back({constancy, "tt.constancy"}); - retVecs.push_back({corexFlag, "tt.corex_stride"}); - // initialize attributes one by one for (auto [vec, attrName] : retVecs) { Attribute attr = funcOp.getArgAttr(argNumber, attrName); @@ -1323,6 +1321,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, } } } +#endif /*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { auto rank = 1; diff --git a/third_party/iluvatar/lib/Analysis/CMakeLists.txt b/third_party/iluvatar/lib/Analysis/CMakeLists.txt index 12deb6143..d41b2aed5 100644 --- a/third_party/iluvatar/lib/Analysis/CMakeLists.txt +++ b/third_party/iluvatar/lib/Analysis/CMakeLists.txt @@ -15,4 +15,5 @@ add_triton_library(TritonAnalysis TritonIR TritonGPUIR #TritonNvidiaGPUIR + FlagTree_${FLAGTREE_BACKEND}_TritonAnalysis ) diff --git a/third_party/iluvatar/lib/Analysis/Utility.cpp b/third_party/iluvatar/lib/Analysis/Utility.cpp index 232edc6e2..251274231 100644 --- a/third_party/iluvatar/lib/Analysis/Utility.cpp +++ b/third_party/iluvatar/lib/Analysis/Utility.cpp @@ -676,28 +676,6 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { #endif } -bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { - - auto srcLayout = srcTy.getEncoding(); - auto dstLayout = dstTy.getEncoding(); - if (!srcLayout.isa()) - return false; - auto mmaLayout = srcLayout.cast(); - if (!dstLayout.isa()) - return false; - auto dotOperandLayout = dstLayout.cast(); - auto dstParLayout = dotOperandLayout.getParent(); - if (!dstParLayout.isa()) - return false; - auto dstMmaLayout = - dstParLayout.dyn_cast(); - return !isMmaToDotShortcut(srcTy, dstTy) && - mmaLayout.getVersionMajor() == 1 && - dstMmaLayout.getVersionMajor() == 1 && - mmaLayout.getWarpsPerCTA()[0] == dstMmaLayout.getWarpsPerCTA()[0] && - dotOperandLayout.getOpIdx() == 0 && !srcTy.getElementType().isF32(); -} - namespace { /// A data structure similar to SetVector but maintains @@ -830,63 +808,10 @@ multiRootTopologicalSort(const SetVector &toSort) { return res; } -#ifdef __ILUVATAR__ -void getBackwardSliceImplCorex(Operation *op, - SetVector *backwardSlice, - TransitiveFilter filter, - bool omitBlockArguments) { - if (!op || op->hasTrait()) - return; - - // Evaluate whether we should keep this def. - // This is useful in particular to implement scoping; i.e. return the - // transitive backwardSlice in the current scope. - if (filter && !filter(op)) - return; - - for (const auto &en : llvm::enumerate(op->getOperands())) { - auto operand = en.value(); - if (auto *definingOp = operand.getDefiningOp()) { - if (backwardSlice->count(definingOp) == 0) - getBackwardSliceImplCorex(definingOp, backwardSlice, filter, - omitBlockArguments); - } else if (auto blockArg = operand.dyn_cast()) { - if (omitBlockArguments) - continue; - - Block *block = blockArg.getOwner(); - Operation *parentOp = block->getParentOp(); - // TODO: determine whether we want to recurse backward into the other - // blocks of parentOp, which are not technically backward unless they flow - // into us. For now, just bail. - if (parentOp && backwardSlice->count(parentOp) == 0) { - // assert(parentOp->getNumRegions() == 1 && - // parentOp->getRegion(0).getBlocks().size() == 1); - getBackwardSliceImplCorex(parentOp, backwardSlice, filter, - omitBlockArguments); - } - } else { - llvm_unreachable("No definingOp and not a block argument."); - } - } - - backwardSlice->insert(op); -} - -void getBackwardSliceCorex(Operation *op, SetVector *backwardSlice, - TransitiveFilter filter, bool omitBlockArguments) { - getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments); - - // Don't insert the top level operation, we just queried on it and don't - // want it in the results. - backwardSlice->remove(op); -} -#endif - +#ifndef FLAGTREE_SPEC_Utility_multiRootGetSlice_ARG SetVector multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter, - bool omitBlockArguments) { + TransitiveFilter forwardFilter) { SetVector slice; slice.insert(op); @@ -900,12 +825,7 @@ SetVector multiRootGetSlice(Operation *op, BackwardSliceOptions opt; opt.omitBlockArguments = true; opt.filter = backwardFilter; -#ifdef __ILUVATAR__ - getBackwardSliceCorex(currentOp, &backwardSlice, opt.filter, - opt.omitBlockArguments); -#elif getBackwardSlice(currentOp, &backwardSlice, opt); -#endif slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. @@ -916,6 +836,7 @@ SetVector multiRootGetSlice(Operation *op, } return multiRootTopologicalSort(slice); } +#endif namespace { // Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis