Skip to content

Commit 7fc0d2b

Browse files
authored
Improve axis analysis to handle tt.make_tensor_ptr (#2448)
The upstream axis analysis doesn't handle blocked pointers. This PR creates an Intel version of the analysis and adds support for the `tt.make_tensor_ptr` and `tt.advance` operations, along with an additional unit test. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent ca7e8ba commit 7fc0d2b

File tree

3 files changed

+80
-8
lines changed

3 files changed

+80
-8
lines changed

test/Analysis/intel/test-alignment.mlir renamed to test/Analysis/intel/test-axis-info.mlir

Lines changed: 15 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -876,3 +876,18 @@ module {
876876
tt.return %int_min : i64
877877
}
878878
}
879+
880+
// -----
881+
882+
// CHECK-LABEL: @make_tensor_ptr
// Verify axis-info propagation through tt.make_tensor_ptr:
//  - %0: both strides are statically 1, so both dims are contiguous up to the
//    shape operands (128, 32); divisibility collapses to [1, 1] because the
//    unit strides bound it.
//  - %1: only dim 0 has a unit stride, so only dim 0 is contiguous (32);
//    its divisibility is min(base-ptr divisibility = 32, %arg2 divisibility = 16).
tt.func public @make_tensor_ptr(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 32 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i64 = arith.constant 1 : i64
  %c32_i64 = arith.constant 32 : i64
  %c128_i64 = arith.constant 128 : i64
  // CHECK: %0 = tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
  %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
  // CHECK: %1 = tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
  %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c32_i64], [%c1_i64, %arg2], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf8E5M2>>
  tt.return
}

test/lib/Analysis/intel/TestAxisInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -13,7 +13,7 @@ struct TestAxisInfoPass
1313

1414
StringRef getArgument() const final { return "test-print-axis-info"; }
1515
StringRef getDescription() const final {
16-
return "print the result of the alignment analysis pass";
16+
return "print the result of the axis analysis pass";
1717
}
1818

1919
void runOnOperation() override {

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 64 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -50,6 +50,12 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5050
return lhs * rhs;
5151
}
5252

53+
/// Return the ranked tensor type underlying \p ptrTy.
/// For a tensor-pointer type (!tt.ptr<tensor<...>>) this is the pointee
/// tensor type; otherwise \p ptrTy itself is dyn_cast'ed, yielding null when
/// it is not a ranked tensor.
RankedTensorType getRankedTensorType(Type ptrTy) {
  if (isTensorPointerType(ptrTy)) {
    auto pointee = cast<PointerType>(ptrTy).getPointeeType();
    return cast<RankedTensorType>(pointee);
  }
  return dyn_cast<RankedTensorType>(ptrTy);
}
58+
5359
class AxisInfoVisitor {
5460
public:
5561
AxisInfoVisitor() = default;
@@ -409,7 +415,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
409415

410416
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
411417
int dim) override {
412-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
418+
auto resTy = getRankedTensorType(op.getType());
413419
if (!resTy)
414420
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
415421
auto shape = resTy.getShape();
@@ -464,7 +470,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
464470
private:
465471
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
466472
int dim) override {
467-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
473+
auto resTy = getRankedTensorType(op.getType());
468474
if (!resTy)
469475
return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
470476
auto shape = resTy.getShape();
@@ -498,7 +504,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
498504

499505
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
500506
int dim) override {
501-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
507+
auto resTy = getRankedTensorType(op.getType());
502508
if (!resTy)
503509
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
504510
auto shape = resTy.getShape();
@@ -647,7 +653,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
647653
AxisInfo
648654
getAxisInfo(OpTy op,
649655
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
650-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
656+
auto resTy = getRankedTensorType(op.getType());
651657
if (!resTy)
652658
return AxisInfo();
653659
auto shape = resTy.getShape();
@@ -995,6 +1001,55 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
9951001
}
9961002
};
9971003

1004+
class MakeTensorPtrOpAxisInfoVisitor final
1005+
: public AxisInfoVisitorImpl<triton::MakeTensorPtrOp> {
1006+
public:
1007+
using AxisInfoVisitorImpl<triton::MakeTensorPtrOp>::AxisInfoVisitorImpl;
1008+
1009+
AxisInfo
1010+
getAxisInfo(triton::MakeTensorPtrOp op,
1011+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
1012+
LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op);
1013+
assert(op.getShape().size() == 2 && operands.size() == 7 &&
1014+
"MakeTensorPtrOp should have 2D shape");
1015+
1016+
AxisInfo ptrInfo = operands[0]->getValue();
1017+
AxisInfo shapeInfo0 = operands[1]->getValue();
1018+
AxisInfo shapeInfo1 = operands[2]->getValue();
1019+
AxisInfo strideInfo0 = operands[3]->getValue();
1020+
AxisInfo strideInfo1 = operands[4]->getValue();
1021+
1022+
std::optional<int64_t> shape0 = shapeInfo0.getConstantValue();
1023+
std::optional<int64_t> shape1 = shapeInfo1.getConstantValue();
1024+
std::optional<int64_t> stride0 = strideInfo0.getConstantValue();
1025+
std::optional<int64_t> stride1 = strideInfo1.getConstantValue();
1026+
1027+
AxisInfo::DimVectorT contiguity{
1028+
shape0.has_value() && (stride0 == 1) ? shape0.value() : 1,
1029+
shape1.has_value() && (stride1 == 1) ? shape1.value() : 1};
1030+
1031+
int64_t ptrDivisibility = ptrInfo.getDivisibility()[0];
1032+
int64_t strideDivisibility0 = strideInfo0.getDivisibility()[0];
1033+
int64_t strideDivisibility1 = strideInfo1.getDivisibility()[0];
1034+
1035+
LDBG("ptrDivisibility: " << ptrDivisibility);
1036+
LDBG("strideDivisibility0: " << strideDivisibility0);
1037+
LDBG("strideDivisibility1: " << strideDivisibility1);
1038+
1039+
AxisInfo::DimVectorT divisibility{1, 1};
1040+
if (ptrDivisibility > 1) {
1041+
if (contiguity[0] > 1)
1042+
divisibility[0] = std::min(ptrDivisibility, strideDivisibility1);
1043+
if (contiguity[1] > 1)
1044+
divisibility[1] = std::min(ptrDivisibility, strideDivisibility0);
1045+
}
1046+
1047+
AxisInfo::DimVectorT constancy{1, 1};
1048+
1049+
return AxisInfo(contiguity, divisibility, constancy);
1050+
}
1051+
};
1052+
9981053
//===----------------------------------------------------------------------===//
9991054
// AxisInfoAnalysis
10001055
//===----------------------------------------------------------------------===//
@@ -1042,11 +1097,13 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10421097
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10431098
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10441099
visitors.append<LoadOpAxisInfoVisitor>();
1100+
visitors.append<MakeTensorPtrOpAxisInfoVisitor>();
10451101
}
10461102

10471103
LogicalResult AxisInfoAnalysis::visitOperation(
10481104
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10491105
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1106+
LDBG("visitOperation: << " << *op);
10501107
// TODO: For sure not the right way to do this
10511108
// but why is scf.if not initialized otherwise?
10521109
for (auto op : operands)
@@ -1204,7 +1261,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12041261
}
12051262

12061263
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
1207-
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
1264+
auto tensorTy = getRankedTensorType(ptr.getType());
12081265
if (!tensorTy)
12091266
return 1;
12101267
auto layout = tensorTy.getEncoding();
@@ -1226,7 +1283,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12261283
}
12271284

12281285
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
1229-
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
1286+
auto tensorTy = getRankedTensorType(ptr.getType());
12301287
if (!tensorTy)
12311288
return 1;
12321289
auto *axisInfo = getAxisInfo(ptr);
@@ -1254,7 +1311,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12541311
}
12551312

12561313
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
1257-
auto tensorTy = dyn_cast<RankedTensorType>(mask.getType());
1314+
auto tensorTy = getRankedTensorType(mask.getType());
12581315
if (!tensorTy)
12591316
return 1;
12601317
auto *axisInfo = getAxisInfo(mask);

0 commit comments

Comments
 (0)