@@ -18,6 +18,11 @@ class ConversionPatternRewriter;

namespace mlir::triton::gpu::intel {

// If the given type is a tensor pointer (i.e., a pointer to a tensor), return
// the pointee tensor type. Otherwise, return the result of dyn_cast'ing the
// type to a ranked tensor (null if it is not one).
RankedTensorType getRankedTensorType(Type type);

// Check if given value is divisible by the divisor.
bool isDivisible(Value value, unsigned divisor);

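As a quick illustration of the shared helper declared in the header hunk above (a minimal caller sketch, not part of the PR; the function and variable names are hypothetical): a block pointer such as !tt.ptr<tensor<128x64xf16>> yields its pointee tensor type, a plain ranked tensor yields itself, and anything else yields a null type.

// Hypothetical caller sketch; assumes the usual MLIR headers and the
// mlir::triton::gpu::intel namespace introduced above.
void exampleUse(mlir::Value v) {
  mlir::RankedTensorType tensorTy =
      mlir::triton::gpu::intel::getRankedTensorType(v.getType());
  if (!tensorTy)
    return; // scalar (or otherwise non-tensor) value: nothing to inspect
  llvm::ArrayRef<int64_t> shape = tensorTy.getShape(); // works for both plain
                                                       // tensors and block
                                                       // pointers to tensors
  (void)shape;
}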
23 changes: 10 additions & 13 deletions third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -4,13 +4,16 @@
#include "llvm/Support/raw_ostream.h"

#include "intel/include/Analysis/AxisInfo.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#define DEBUG_TYPE "intel-axis-info"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::triton::intel {

namespace ttgi = mlir::triton::gpu::intel;
namespace {

int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
return lhs * rhs;
}

RankedTensorType getRankedTensorType(Type ptrTy) {
return isTensorPointerType(ptrTy)
? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType())
: dyn_cast<RankedTensorType>(ptrTy);
}

class AxisInfoVisitor {
public:
AxisInfoVisitor() = default;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto resTy = getRankedTensorType(op.getType());
auto resTy = ttgi::getRankedTensorType(op.getType());
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto shape = resTy.getShape();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
private:
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto resTy = getRankedTensorType(op.getType());
auto resTy = ttgi::getRankedTensorType(op.getType());
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
auto shape = resTy.getShape();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto resTy = getRankedTensorType(op.getType());
auto resTy = ttgi::getRankedTensorType(op.getType());
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto shape = resTy.getShape();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
AxisInfo
getAxisInfo(OpTy op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto resTy = getRankedTensorType(op.getType());
auto resTy = ttgi::getRankedTensorType(op.getType());
if (!resTy)
return AxisInfo();
auto shape = resTy.getShape();
@@ -1265,7 +1262,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
}

unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = getRankedTensorType(ptr.getType());
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
@@ -1287,7 +1284,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
}

unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = getRankedTensorType(ptr.getType());
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
if (!tensorTy)
return 1;
auto *axisInfo = getAxisInfo(ptr);
@@ -1315,7 +1312,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
}

unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = getRankedTensorType(mask.getType());
auto tensorTy = ttgi::getRankedTensorType(mask.getType());
if (!tensorTy)
return 1;
auto *axisInfo = getAxisInfo(mask);
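The AxisInfo.cpp hunks are a refactoring rather than a behavior change: the file-local getRankedTensorType removed above already handled tensor pointers the same way, and every call site keeps the same guard pattern, now spelled ttgi::getRankedTensorType. A condensed sketch of that pattern (function and parameter names are illustrative only, not code from the PR):

// Illustrative guard pattern shared by the visitors and alignment queries above.
int64_t guardedShapeQuery(mlir::Type resultType, int dim, int64_t fallback) {
  auto resTy = mlir::triton::gpu::intel::getRankedTensorType(resultType);
  if (!resTy)
    return fallback;            // non-tensor result: keep the generic estimate
  return resTy.getShape()[dim]; // tensor (or block-pointer) result: feed the
                                // per-dimension extent into the estimate
}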
9 changes: 8 additions & 1 deletion third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -25,9 +25,16 @@ namespace ttgi = mlir::triton::gpu::intel;

namespace mlir::triton::gpu::intel {

RankedTensorType getRankedTensorType(Type ptrTy) {
return tt::isTensorPointerType(ptrTy)
? cast<RankedTensorType>(
cast<tt::PointerType>(ptrTy).getPointeeType())
: dyn_cast<RankedTensorType>(ptrTy);
}

static bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
if (auto tensorTy = getRankedTensorType(value.getType()))
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
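The one behavioral change is in isSingleValue: by going through getRankedTensorType, a block pointer to a one-element tensor is now also recognized as a single value. A standalone sketch of the updated check (illustrative only; the function's remaining cases are elided in the diff, so the trailing return is a placeholder):

// Standalone sketch of the updated check; names are hypothetical.
static bool isSingleValueSketch(mlir::Value value) {
  // getRankedTensorType also unwraps block pointers (!tt.ptr<tensor<...>>),
  // so e.g. !tt.ptr<tensor<1xf32>> is now treated like tensor<1xf32>.
  if (auto tensorTy =
          mlir::triton::gpu::intel::getRankedTensorType(value.getType()))
    return tensorTy.getNumElements() == 1;
  return false; // placeholder: the real function handles further cases
                // that are not shown in this diff
}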