diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 1d7c15b7e2..70b0f9acf2 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -18,6 +18,11 @@ class ConversionPatternRewriter; namespace mlir::triton::gpu::intel { +// If the given type is a pointer of tensors, return the pointee type. +// Otherwise, attempt to cast the given type to a ranked tensor and return the +// dynamic cast result. +RankedTensorType getRankedTensorType(Type type); + // Check if given value is divisible by the divisor. bool isDivisible(Value value, unsigned divisor); diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 567c01a24d..addf8c0b82 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -4,6 +4,7 @@ #include "llvm/Support/raw_ostream.h" #include "intel/include/Analysis/AxisInfo.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #define DEBUG_TYPE "intel-axis-info" @@ -11,6 +12,8 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir::triton::intel { + +namespace ttgi = mlir::triton::gpu::intel; namespace { int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { @@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { return lhs * rhs; } -RankedTensorType getRankedTensorType(Type ptrTy) { - return isTensorPointerType(ptrTy) - ? 
cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType()) - : dyn_cast<RankedTensorType>(ptrTy); -} - class AxisInfoVisitor { public: AxisInfoVisitor() = default; @@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - auto resTy = getRankedTensorType(op.getType()); + auto resTy = ttgi::getRankedTensorType(op.getType()); if (!resTy) return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim); auto shape = resTy.getShape(); @@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> { private: int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - auto resTy = getRankedTensorType(op.getType()); + auto resTy = ttgi::getRankedTensorType(op.getType()); if (!resTy) return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim); auto shape = resTy.getShape(); @@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - auto resTy = getRankedTensorType(op.getType()); + auto resTy = ttgi::getRankedTensorType(op.getType()); if (!resTy) return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim); auto shape = resTy.getShape(); @@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { AxisInfo getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { - auto resTy = getRankedTensorType(op.getType()); + auto resTy = ttgi::getRankedTensorType(op.getType()); if (!resTy) return AxisInfo(); auto shape = resTy.getShape(); @@ -1265,7 +1262,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, } unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { - auto tensorTy = getRankedTensorType(ptr.getType()); + auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; auto layout = tensorTy.getEncoding(); @@ -1287,7 +1284,7 @@ unsigned
ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { } unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { - auto tensorTy = getRankedTensorType(ptr.getType()); + auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; auto *axisInfo = getAxisInfo(ptr); @@ -1315,7 +1312,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { } unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { - auto tensorTy = getRankedTensorType(mask.getType()); + auto tensorTy = ttgi::getRankedTensorType(mask.getType()); if (!tensorTy) return 1; auto *axisInfo = getAxisInfo(mask); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 4e921ca7fc..759fc1782d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -25,9 +25,16 @@ namespace ttgi = mlir::triton::gpu::intel; namespace mlir::triton::gpu::intel { +RankedTensorType getRankedTensorType(Type ptrTy) { + return tt::isTensorPointerType(ptrTy) + ? cast<RankedTensorType>( + cast<tt::PointerType>(ptrTy).getPointeeType()) + : dyn_cast<RankedTensorType>(ptrTy); +} + static bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. - if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType())) + if (auto tensorTy = getRankedTensorType(value.getType())) return tensorTy.getNumElements() == 1; // TODO: Handle other cases. // For example, when ptr is a tensor of single value.