From 0d82630a5c258a75544d4cfeb70142af1bb996fd Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 30 Oct 2024 20:09:02 +0000 Subject: [PATCH 1/3] Make intel AxisInfo analysis derive from upstream implementation Signed-off-by: Tiotto, Ettore --- third_party/intel/include/Analysis/AxisInfo.h | 168 ++---------------- .../include/Dialect/TritonIntelGPU/IR/Utils.h | 2 +- third_party/intel/lib/Analysis/AxisInfo.cpp | 116 +----------- .../PatternTritonGPUOpToLLVM.h | 2 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- 5 files changed, 22 insertions(+), 270 deletions(-) diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h index 1fbaba2e0c..c159db9785 100644 --- a/third_party/intel/include/Analysis/AxisInfo.h +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -1,157 +1,10 @@ #ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H #define TRITON_INTEL_ANALYSIS_AXISINFO_H -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" -#include "llvm/Support/raw_ostream.h" - -#include "mlir/Support/LLVM.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -#include +#include "triton/Analysis/AxisInfo.h" namespace mlir::triton::intel { -//===----------------------------------------------------------------------===// -// AxisInfo -//===----------------------------------------------------------------------===// - -/// This lattice value represents known information on the axes of a lattice. -class AxisInfo { -public: - typedef SmallVector DimVectorT; - -public: - AxisInfo() : AxisInfo({}, {}, {}) {} - - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, - const DimVectorT &constancy) - : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} - - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, - const DimVectorT &constancy, std::optional constantValue) - : contiguity(contiguity), divisibility(divisibility), - constancy(constancy), constantValue(constantValue) { - assert(divisibility.size() == contiguity.size()); - assert(constancy.size() == contiguity.size()); - } - - // contiguity[d] is the length of the shortest sequence of contiguous integers - // along dimension d. - // - // If we have an array of N elements with a contiguity value C, then the array - // can be divided into a list of N/C sequences of C contiguous elements. - // Since we have N = 2^k, C must be a power of two. - // - // For example, the 2D array - // - // [[10, 11, 12, 13, 18, 19, 20, 21], - // [20, 21, 22, 23, 28, 29, 30, 31]] - // - // has contiguity [1, 4], and - // - // [[12, 16, 20, 24], - // [13, 17, 21, 25], - // [14, 18, 22, 26], - // [15, 19, 23, 27], - // [18, 22, 26, 30], - // [19, 23, 27, 31]] - // - // has contiguity [2, 1]. - int64_t getContiguity(size_t dim) const { return contiguity[dim]; } - const DimVectorT &getContiguity() const { return contiguity; } - - // divisibility[d] is the largest power of two that divides the first element - // of all groups of length contiguity[d] along dimension d. - // - // For example, - // - // [[10, 11, 12, 13, 18, 19, 20, 21], - // [20, 21, 22, 23, 28, 29, 30, 31]] - // - // has divisibility [1, 2], and - // - // [[12, 16, 20, 24], - // [13, 17, 21, 25], - // [14, 18, 22, 26], - // [15, 19, 23, 27]] - // - // has divisibility [4, 1]. - // - // On the other hand, - // - // [0, 1, 2, 0, 4, 5, 6, 7] - // - // has divisibility 1 because its contiguity is 1. - int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } - const DimVectorT &getDivisibility() const { return divisibility; } - - // constancy[d] is the length of the shortest sequence of repeating integers - // along dimension d. - // - // This is particularly useful to infer the contiguity of operations (e.g. - // add) involving a constant. - // - // If we have an array of N elements, with a constancy value C, then the array - // can be divided into a list of N/C sequences of C elements with the same - // value. Since we have N = 2^k, C must be a power of two. - // - // For example - // - // [[8, 8, 8, 8, 12, 12, 12, 12], - // [16, 16, 16, 16, 20, 20, 20, 20]] - // - // has constancy [1, 4]. - int64_t getConstancy(size_t dim) const { return constancy[dim]; } - const DimVectorT &getConstancy() const { return constancy; } - - int getRank() const { return contiguity.size(); } - - std::optional getConstantValue() const { return constantValue; } - - template - static void - initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, - DimVectorT *divisibility, DimVectorT *constancy); - - bool operator==(const AxisInfo &other) const { - return contiguity == other.contiguity && - divisibility == other.divisibility && constancy == other.constancy && - constantValue == other.constantValue; - } - - static AxisInfo getPessimisticValueState(Value value); - - // The gcd of both arguments for each dimension - static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); - - void print(raw_ostream &os) const { - auto print = [&](StringRef name, DimVectorT vec) { - os << name << " = ["; - llvm::interleaveComma(vec, os); - os << "]"; - }; - print("contiguity", contiguity); - print(", divisibility", divisibility); - print(", constancy", constancy); - os << ", constant_value = "; - if (constantValue) - os << *constantValue; - else - os << ""; - } - -private: - DimVectorT contiguity; - DimVectorT divisibility; - DimVectorT constancy; - - // The constant value of the lattice if we can infer it. - std::optional constantValue; -}; - // Module level axis info analysis based on the call graph, assuming that we do // not have recursive functions. // @@ -159,11 +12,13 @@ class AxisInfo { // axis info based on the axis info of all the callers. In the future, we can // perform optimization using function cloning so that each call site will have // unique axis info. -using AxisInfoMapT = DenseMap; -class ModuleAxisInfoAnalysis : public CallGraph { +// using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { public: explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) - : CallGraph(moduleOp) { + : triton::ModuleAxisInfoAnalysis(moduleOp) { + funcMap.clear(); + SmallVector funcs; for (auto root : getRoots()) { walk( @@ -187,10 +42,11 @@ class ModuleAxisInfoAnalysis : public CallGraph { } } - AxisInfo *getAxisInfo(Value value) { + AxisInfo *getAxisInfo(Value value) const { auto funcOp = value.getParentRegion()->getParentOfType(); - auto *axisInfoMap = getFuncData(funcOp); + auto *axisInfoMap = + const_cast(this)->getFuncData(funcOp); if (!axisInfoMap) { return nullptr; } @@ -201,9 +57,9 @@ class ModuleAxisInfoAnalysis : public CallGraph { return &(it->second); } - unsigned getPtrContiguity(Value ptr); - unsigned getPtrAlignment(Value ptr); - unsigned getMaskAlignment(Value mask); + unsigned getPtrContiguity(Value ptr) const; + unsigned getPtrAlignment(Value ptr) const; + unsigned getMaskAlignment(Value mask) const; private: void initialize(FunctionOpInterface funcOp); diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h index 6357d4a8c2..7c813a64fa 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h @@ -28,7 +28,7 @@ inline unsigned getNumElementsPerThread( ? cast(cast(valTy).getPointeeType()) : cast(valTy); auto shapePerCTA = getShapePerCTA(ty); - mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + mlir::triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); unsigned elemNumBits = getElementBitWidth(ty); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 378ba01442..6d31af31b7 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1159,113 +1159,7 @@ void AxisInfoAnalysis::visitForOpInductionVar( } // anonymous namespace -template -void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, - DimVectorT *contiguity, - DimVectorT *divisibility, - DimVectorT *constancy) { - // liast of attributes that we care about - SmallVector> retVecs; - retVecs.push_back({contiguity, "tt.contiguity"}); - retVecs.push_back({divisibility, "tt.divisibility"}); - retVecs.push_back({constancy, "tt.constancy"}); - // initialize attributes one by one - for (auto [vec, attrName] : retVecs) { - Attribute attr = funcOp.getArgAttr(argNumber, attrName); - if (auto int_attr = dyn_cast_or_null(attr)) - *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); - if (auto dense_attr = dyn_cast_or_null(attr)) { - auto vals = dense_attr.getValues(); - *vec = DimVectorT(vals.begin(), vals.end()); - } - } -} - -/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { - auto rank = 1; - if (TensorType ty = dyn_cast(value.getType())) - rank = ty.getRank(); - if (triton::PointerType ty = dyn_cast(value.getType())) - if (TensorType elemTy = dyn_cast(ty.getPointeeType())) - rank = elemTy.getRank(); - - DimVectorT knownContiguity(rank, 1); - DimVectorT knownDivisibility(rank, 1); - DimVectorT knownConstancy(rank, 1); - - BlockArgument blockArg = dyn_cast(value); - - if (blockArg && blockArg.getOwner()->isEntryBlock()) { - Operation *op = blockArg.getOwner()->getParentOp(); - if (auto fun = dyn_cast(op)) - initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, - &knownContiguity, &knownDivisibility, - &knownConstancy); - // llvm codegen check alignment to generate vector load/store - // would be nice if this wasn't the case - else if (auto fun = dyn_cast(op)) - initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, - &knownContiguity, &knownDivisibility, - &knownConstancy); - else if (isa(op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. - knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); - knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); - knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); - } - } else if (Operation *op = value.getDefiningOp()) { - if (isa(op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. - knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); - knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); - knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); - } - // Other operations are conservatively initialized with the lowest possible - // divisibility, contiguity, and constancy unless they have specified. - if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { - auto vals = cast(attr).getValues(); - knownDivisibility = DimVectorT(vals.begin(), vals.end()); - } - if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { - auto vals = cast(attr).getValues(); - knownContiguity = DimVectorT(vals.begin(), vals.end()); - } - if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { - auto vals = cast(attr).getValues(); - knownConstancy = DimVectorT(vals.begin(), vals.end()); - } - } - - return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); -} - -/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { - // If one argument is not initialized, return the other. - if (lhs.getRank() == 0) - return rhs; - if (rhs.getRank() == 0) - return lhs; - DimVectorT contiguity; - DimVectorT divisibility; - DimVectorT constancy; - for (auto d = 0; d < lhs.getRank(); ++d) { - contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); - divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); - constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); - } - std::optional constantValue; - if (lhs.getConstantValue().has_value() && - rhs.getConstantValue().has_value() && - lhs.getConstantValue() == rhs.getConstantValue()) - constantValue = lhs.getConstantValue(); - return AxisInfo(contiguity, divisibility, constancy, constantValue); -} - -unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1287,7 +1181,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { return contiguity; } -unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1298,7 +1192,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { auto order = triton::gpu::getOrder(layout); auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); auto maxContig = axisInfo->getContiguity(order[0]); - auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + unsigned elemNumBits = isTensorPointerType(ptr.getType()) + ? tensorTy.getElementType().getIntOrFloatBitWidth() + : triton::getPointeeBitWidth(tensorTy); auto elemNumBytes = std::max(elemNumBits / 8, 1); auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); unsigned alignment = std::min(maxMultiple, maxContig); @@ -1315,7 +1211,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { return alignment; } -unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) const { auto tensorTy = ttgi::getRankedTensorType(mask.getType()); if (!tensorTy) return 1; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index b605776b52..36223e3245 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -3,8 +3,8 @@ #include "TargetInfo.h" #include "TritonGPUToLLVMBase.h" +#include "intel/include/Analysis/AxisInfo.h" #include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" -#include "triton/Analysis/AxisInfo.h" namespace mlir::triton::intel { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 998b7204cf..3d3bbb3015 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -7,6 +7,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "intel/include/Analysis/AxisInfo.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h" @@ -14,7 +15,6 @@ #include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -114,7 +114,7 @@ struct ConvertTritonGPUToLLVM return signalPassFailure(); } - ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod); OpBuilder::InsertPoint indexInsertPoint; RewritePatternSet patterns(context); From f1909a51c6d572d7da51b46c1b515c12c9bfcdf0 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 31 Oct 2024 15:01:50 +0000 Subject: [PATCH 2/3] Address code review comments Signed-off-by: Tiotto, Ettore --- third_party/intel/include/Analysis/AxisInfo.h | 11 +++++------ third_party/intel/lib/Analysis/AxisInfo.cpp | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h index c159db9785..f6dc0c4260 100644 --- a/third_party/intel/include/Analysis/AxisInfo.h +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -42,11 +42,10 @@ class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { } } - AxisInfo *getAxisInfo(Value value) const { + AxisInfo *getAxisInfo(Value value) { auto funcOp = value.getParentRegion()->getParentOfType(); - auto *axisInfoMap = - const_cast(this)->getFuncData(funcOp); + auto *axisInfoMap = getFuncData(funcOp); if (!axisInfoMap) { return nullptr; } @@ -57,9 +56,9 @@ class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { return &(it->second); } - unsigned getPtrContiguity(Value ptr) const; - unsigned getPtrAlignment(Value ptr) const; - unsigned getMaskAlignment(Value mask) const; + unsigned getPtrContiguity(Value ptr); + unsigned getPtrAlignment(Value ptr); + unsigned getMaskAlignment(Value mask); private: void initialize(FunctionOpInterface funcOp); diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 6d31af31b7..09da088e8c 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1159,7 +1159,7 @@ void AxisInfoAnalysis::visitForOpInductionVar( } // anonymous namespace -unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const { +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1181,7 +1181,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const { return contiguity; } -unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const { +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1211,7 +1211,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const { return alignment; } -unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) const { +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { auto tensorTy = ttgi::getRankedTensorType(mask.getType()); if (!tensorTy) return 1; From 66c24032052f165c02e8c2c1e3e3e96d03c669f2 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 1 Nov 2024 14:25:29 +0000 Subject: [PATCH 3/3] Address code review comments Signed-off-by: Tiotto, Ettore --- third_party/intel/include/Analysis/AxisInfo.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h index f6dc0c4260..b0b90f7d10 100644 --- a/third_party/intel/include/Analysis/AxisInfo.h +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -12,7 +12,7 @@ namespace mlir::triton::intel { // axis info based on the axis info of all the callers. In the future, we can // perform optimization using function cloning so that each call site will have // unique axis info. -// using AxisInfoMapT = DenseMap; + class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { public: explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)