Skip to content

Commit e797907

Browse files
Annotate tensor of pointer LoadOp with row_major attribute (#3864)
This PR annotates `LoadOp` with `row_major` attribute using `AxisInfo`. --------- Signed-off-by: Whitney Tsang <[email protected]> Co-authored-by: Lu,Chengjun <[email protected]>
1 parent 306051e commit e797907

File tree

5 files changed

+259
-29
lines changed

5 files changed

+259
-29
lines changed

include/triton/Analysis/AxisInfo.h

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,73 @@ class AxisInfo {
2525
typedef SmallVector<int64_t> DimVectorT;
2626

2727
public:
28-
AxisInfo() : AxisInfo({}, {}, {}) {}
28+
AxisInfo() : AxisInfo({}, {}, {}, {}, std::nullopt) {}
2929

3030
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
3131
ArrayRef<int64_t> constancy)
3232
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}
3333

3434
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
3535
ArrayRef<int64_t> constancy, std::optional<int64_t> constantValue)
36-
: contiguity(contiguity), divisibility(divisibility),
36+
: AxisInfo(AxisInfo::DimVectorT(contiguity.size(), -1), contiguity,
37+
divisibility, constancy, constantValue) {
38+
for (size_t i = 0; i < contiguity.size(); ++i) {
39+
if (contiguity[i] > 1) {
40+
stride[i] = 1;
41+
}
42+
}
43+
}
44+
45+
AxisInfo(ArrayRef<int64_t> stride, ArrayRef<int64_t> contiguity,
46+
ArrayRef<int64_t> divisibility, ArrayRef<int64_t> constancy,
47+
std::optional<int64_t> constantValue)
48+
: stride(stride), contiguity(contiguity), divisibility(divisibility),
3749
constancy(constancy), constantValue(constantValue) {
50+
assert(stride.size() == contiguity.size());
3851
assert(divisibility.size() == contiguity.size());
3952
assert(constancy.size() == contiguity.size());
4053
}
4154

55+
// TODO: Support non compile time constant strides.
56+
// stride[d] is the stride of contiguityWithStride[d] elements along dimension
57+
// d. Value -1 is used to represent the unknown stride.
58+
// For example, the 2D array
59+
//
60+
// [[10, 11, 12, 13, 18, 19, 20, 21],
61+
// [20, 21, 22, 23, 28, 29, 30, 31]]
62+
//
63+
// has stride [10, 1], and
64+
//
65+
// [[12, 16, 20, 24],
66+
// [13, 17, 21, 25],
67+
// [14, 18, 22, 26],
68+
// [15, 19, 23, 27],
69+
// [18, 22, 26, 30],
70+
// [19, 23, 27, 31]]
71+
//
72+
// has stride [1, 4].
73+
int64_t getStride(size_t dim) const { return stride[dim]; }
74+
const DimVectorT &getStride() const { return stride; }
75+
76+
// TODO: Add contiguity with stride.
77+
// contiguityWithStride[d] is the length of the shortest sequence of
78+
// contiguous integers with the same stride along dimension d. For example,
79+
// the 2D array
80+
//
81+
// [[10, 11, 12, 13, 18, 19, 20, 21],
82+
// [20, 21, 22, 23, 28, 29, 30, 31]]
83+
//
84+
// has contiguityWithStride [2, 4], and
85+
//
86+
// [[12, 16, 20, 24],
87+
// [13, 17, 21, 25],
88+
// [14, 18, 22, 26],
89+
// [15, 19, 23, 27],
90+
// [18, 22, 26, 30],
91+
// [19, 23, 27, 31]]
92+
//
93+
// has contiguityWithStride [2, 4].
94+
4295
// contiguity[d] is the length of the shortest sequence of contiguous integers
4396
// along dimension d.
4497
//
@@ -134,7 +187,8 @@ class AxisInfo {
134187
llvm::interleaveComma(vec, os);
135188
os << "]";
136189
};
137-
print("contiguity", contiguity);
190+
print("stride", stride);
191+
print(", contiguity", contiguity);
138192
print(", divisibility", divisibility);
139193
print(", constancy", constancy);
140194
os << ", constant_value = ";
@@ -145,6 +199,7 @@ class AxisInfo {
145199
}
146200

147201
private:
202+
DimVectorT stride;
148203
DimVectorT contiguity;
149204
DimVectorT divisibility;
150205
DimVectorT constancy;

test/Analysis/intel/test-axis-info.mlir

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -702,13 +702,13 @@ tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
702702
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
703703
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
704704
%9 = tt.splat %n_elements : i32 -> tensor<64xi32>
705-
// CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
705+
// CHECK: arith.cmpi slt, %{{.*}} => stride = [-1], contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
706706
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
707707
%11 = tt.load %6, %mask : tensor<64x!tt.ptr<f32>>
708708
%12 = tt.load %8, %mask : tensor<64x!tt.ptr<f32>>
709709
%13 = arith.addf %11, %12 : tensor<64xf32>
710710
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
711-
// CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
711+
// CHECK: tt.addptr %{{.*}} => stride = [1], contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
712712
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
713713
tt.store %15, %13, %mask : tensor<64x!tt.ptr<f32>>
714714
tt.return
@@ -731,7 +731,7 @@ tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
731731
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
732732
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
733733
%9 = tt.splat %n_elements : i32 -> tensor<64xi32>
734-
// CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
734+
// CHECK: arith.cmpi slt, %{{.*}} => stride = [-1], contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
735735
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
736736
%11 = tt.load %6, %10 : tensor<64x!tt.ptr<f32>>
737737
%12 = tt.load %8, %10 : tensor<64x!tt.ptr<f32>>
@@ -885,11 +885,32 @@ tt.func public @make_tensor_ptr(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f8E5M2> {tt.
885885
%c1_i64 = arith.constant 1 : i64
886886
%c32_i64 = arith.constant 32 : i64
887887
%c128_i64 = arith.constant 128 : i64
888-
// CHECK: tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
888+
// CHECK: tt.make_tensor_ptr %arg0, {{.*}} => stride = [1, 1], contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
889889
%0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
890-
// CHECK: tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [64, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
890+
// CHECK: tt.make_tensor_ptr %arg1, {{.*}} => stride = [1, -1], contiguity = [64, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
891891
%1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c32_i64], [%c1_i64, %arg2], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf8E5M2>>
892-
// CHECK: tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 64], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
892+
// CHECK: tt.make_tensor_ptr %arg1, {{.*}} => stride = [1, 1], contiguity = [32, 64], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
893893
%2 = tt.make_tensor_ptr %arg1, [%arg2, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf8E5M2>>
894894
tt.return
895895
}
896+
897+
// -----
898+
899+
// CHECK-LABEL: @ptr_offset
900+
tt.func public @ptr_offset(%arg0: i32) {
901+
// CHECK: stride = [0, 0], contiguity = [1, 1], divisibility = [512, 512], constancy = [128, 1], constant_value = 512
902+
%cst = arith.constant dense<512> : tensor<128x1xi32>
903+
// CHECK: stride = [0], contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
904+
%0 = tt.splat %arg0 : i32 -> tensor<128xi32>
905+
// CHECK: stride = [1], contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
906+
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
907+
// CHECK: stride = [1], contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
908+
%2 = arith.addi %0, %1 : tensor<128xi32>
909+
// CHECK: stride = [1, 0], contiguity = [128, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
910+
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
911+
// CHECK: stride = [512, 0], contiguity = [1, 1], divisibility = [512, 512], constancy = [1, 1], constant_value = <none>
912+
%4 = arith.muli %3, %cst : tensor<128x1xi32>
913+
// CHECK: stride = [512, 0], contiguity = [1, 1], divisibility = [512, 512], constancy = [1, 64], constant_value = <none>
914+
%5 = tt.broadcast %4 : tensor<128x1xi32> -> tensor<128x64xi32>
915+
tt.return
916+
}

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,28 +107,37 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
107107
const auto &rhsInfo = operands[1]->getValue();
108108
auto rank = lhsInfo.getRank();
109109
assert(operands.size() == 2 && "Expected two operands");
110+
AxisInfo::DimVectorT stride;
110111
AxisInfo::DimVectorT contiguity;
111112
AxisInfo::DimVectorT divisibility;
112113
AxisInfo::DimVectorT constancy;
113114
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
114115
for (auto d = 0; d < rank; ++d) {
115116
if (constantValue.has_value()) {
117+
stride.push_back(0);
116118
contiguity.push_back(1);
117119
constancy.push_back(
118120
std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
119121
divisibility.push_back(
120122
highestPowOf2Divisor<int64_t>(constantValue.value()));
121123
} else {
124+
stride.push_back(getStride(op, lhsInfo, rhsInfo, d));
122125
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
123126
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
124127
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
125128
}
126129
}
127-
return AxisInfo(std::move(contiguity), std::move(divisibility),
128-
std::move(constancy), constantValue);
130+
return AxisInfo(std::move(stride), std::move(contiguity),
131+
std::move(divisibility), std::move(constancy),
132+
constantValue);
129133
}
130134

131135
protected:
136+
virtual int64_t getStride(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
137+
int dim) {
138+
return -1;
139+
}
140+
132141
virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs,
133142
const AxisInfo &rhs, int dim) {
134143
return 1;
@@ -252,7 +261,7 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
252261
value = intAttr.getValue().getZExtValue();
253262
else
254263
value = boolAttr.getValue() ? 1 : 0;
255-
return AxisInfo(/*contiguity=*/{1},
264+
return AxisInfo(/*stride=*/{0}, /*contiguity=*/{1},
256265
/*divisibility=*/{highestPowOf2Divisor(value)},
257266
/*constancy=*/{1},
258267
/*knownConstantValue=*/{value});
@@ -263,6 +272,7 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
263272
int64_t value = splatAttr.template getSplatValue<APInt>().getZExtValue();
264273
TensorType ty = cast<TensorType>(splatAttr.getType());
265274
return AxisInfo(
275+
/*stride=*/AxisInfo::DimVectorT(ty.getRank(), 0),
266276
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
267277
/*divisibility=*/
268278
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
@@ -302,6 +312,15 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
302312
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
303313

304314
private:
315+
int64_t getStride(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
316+
int dim) override {
317+
if (lhs.getStride(dim) < 0 || rhs.getStride(dim) < 0)
318+
return -1;
319+
if (isa<arith::SubIOp>(op))
320+
return std::max(lhs.getStride(dim) - rhs.getStride(dim), int64_t(-1));
321+
return lhs.getStride(dim) + rhs.getStride(dim);
322+
}
323+
305324
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
306325
int dim) override {
307326
// Contiguity assumes an increasing sequence. So for SubIOp contiguous
@@ -373,6 +392,17 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
373392
using BinaryOpVisitorImpl<arith::MulIOp>::BinaryOpVisitorImpl;
374393

375394
private:
395+
int64_t getStride(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs,
396+
int dim) override {
397+
if (lhs.getStride(dim) > 0 && rhs.getConstantValue().has_value())
398+
return lhs.getStride(dim) * rhs.getConstantValue().value();
399+
if (rhs.getStride(dim) > 0 && lhs.getConstantValue().has_value())
400+
return lhs.getConstantValue().value() * rhs.getStride(dim);
401+
if (lhs.getStride(dim) == 0 || rhs.getStride(dim) == 0)
402+
return 0;
403+
return -1;
404+
}
405+
376406
int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs,
377407
const AxisInfo &rhs, int dim) override {
378408
// lhs * 1 = lhs
@@ -425,6 +455,22 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
425455
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
426456

427457
private:
458+
int64_t getStride(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
459+
int dim) override {
460+
if (getContiguity(op, lhs, rhs, dim) > 1)
461+
return 1;
462+
if (lhs.getStride(dim) > 0 && rhs.getConstantValue().has_value() &&
463+
rhs.getConstantValue().has_value() != 0 &&
464+
lhs.getStride(dim) % rhs.getConstantValue().value() == 0)
465+
return lhs.getStride(dim) / rhs.getConstantValue().value();
466+
if (rhs.getStride(dim) > 0 && lhs.getConstantValue().has_value() &&
467+
lhs.getConstantValue().value() % rhs.getStride(dim) == 0)
468+
return lhs.getConstantValue().value() / rhs.getStride(dim);
469+
if (lhs.getStride(dim) == 0)
470+
return 0;
471+
return -1;
472+
}
473+
428474
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
429475
int dim) override {
430476
// lhs / 1 = lhs
@@ -559,16 +605,18 @@ class SplatOpAxisInfoVisitor final
559605
Type _retTy = *op->result_type_begin();
560606
TensorType retTy = cast<TensorType>(_retTy);
561607
AxisInfo opInfo = operands[0]->getValue();
608+
AxisInfo::DimVectorT stride;
562609
AxisInfo::DimVectorT contiguity;
563610
AxisInfo::DimVectorT divisibility;
564611
AxisInfo::DimVectorT constancy;
565612
for (int d = 0; d < retTy.getRank(); ++d) {
613+
stride.push_back(0);
566614
contiguity.push_back(1);
567615
divisibility.push_back(opInfo.getDivisibility(0));
568616
constancy.push_back(retTy.getShape()[d]);
569617
}
570-
return AxisInfo(std::move(contiguity), std::move(divisibility),
571-
std::move(constancy),
618+
return AxisInfo(std::move(stride), std::move(contiguity),
619+
std::move(divisibility), std::move(constancy),
572620
operands[0]->getValue().getConstantValue());
573621
}
574622
};
@@ -613,6 +661,7 @@ class ExpandDimsOpAxisInfoVisitor final
613661
getAxisInfo(triton::ExpandDimsOp op,
614662
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
615663
AxisInfo opInfo = operands[0]->getValue();
664+
AxisInfo::DimVectorT stride = opInfo.getStride();
616665
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
617666
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
618667
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
@@ -631,11 +680,12 @@ class ExpandDimsOpAxisInfoVisitor final
631680
opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d));
632681
}
633682
}
683+
stride.insert(stride.begin() + op.getAxis(), 0);
634684
contiguity.insert(contiguity.begin() + op.getAxis(), 1);
635685
divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility);
636686
constancy.insert(constancy.begin() + op.getAxis(), 1);
637-
return AxisInfo(std::move(contiguity), std::move(divisibility),
638-
std::move(constancy),
687+
return AxisInfo(std::move(stride), std::move(contiguity),
688+
std::move(divisibility), std::move(constancy),
639689
operands[0]->getValue().getConstantValue());
640690
}
641691
};
@@ -655,17 +705,19 @@ class BroadcastOpAxisInfoVisitor final
655705
ArrayRef<int64_t> retShape = retTy.getShape();
656706
ArrayRef<int64_t> opShape = opTy.getShape();
657707
AxisInfo opInfo = operands[0]->getValue();
708+
AxisInfo::DimVectorT stride;
658709
AxisInfo::DimVectorT contiguity;
659710
AxisInfo::DimVectorT divisibility;
660711
AxisInfo::DimVectorT constancy;
661712
for (int d = 0; d < retTy.getRank(); ++d) {
713+
stride.push_back(opInfo.getStride(d));
662714
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
663715
divisibility.push_back(opInfo.getDivisibility(d));
664716
constancy.push_back(opShape[d] == 1 ? retShape[d]
665717
: opInfo.getConstancy(d));
666718
}
667-
return AxisInfo(std::move(contiguity), std::move(divisibility),
668-
std::move(constancy),
719+
return AxisInfo(std::move(stride), std::move(contiguity),
720+
std::move(divisibility), std::move(constancy),
669721
operands[0]->getValue().getConstantValue());
670722
}
671723
};
@@ -1048,15 +1100,18 @@ class MakeTensorPtrOpAxisInfoVisitor final
10481100
if (rank > 2)
10491101
return AxisInfo();
10501102

1051-
SmallVector<AxisInfo> strideInfo;
1103+
SmallVector<AxisInfo, 2> strideInfo;
10521104
for (int i = rank + 1; i <= rank * 2; ++i)
10531105
strideInfo.emplace_back(operands[i]->getValue());
10541106

10551107
AxisInfo ptrInfo = operands[0]->getValue();
10561108
int64_t ptrDivisibility = ptrInfo.getDivisibility(0);
10571109

1058-
AxisInfo::DimVectorT contiguity, constancy, divisibility;
1110+
AxisInfo::DimVectorT stride, contiguity, constancy, divisibility;
10591111
for (int dim = 0; dim < rank; ++dim) {
1112+
stride.push_back(strideInfo[dim].getConstantValue().has_value()
1113+
? strideInfo[dim].getConstantValue().value()
1114+
: -1);
10601115
contiguity.push_back(
10611116
strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1);
10621117
divisibility.push_back(
@@ -1069,8 +1124,9 @@ class MakeTensorPtrOpAxisInfoVisitor final
10691124
constancy.push_back(1);
10701125
}
10711126

1072-
auto axisInfo = AxisInfo(std::move(contiguity), std::move(divisibility),
1073-
std::move(constancy));
1127+
auto axisInfo =
1128+
AxisInfo(std::move(stride), std::move(contiguity),
1129+
std::move(divisibility), std::move(constancy), std::nullopt);
10741130

10751131
LLVM_DEBUG({
10761132
std::string axisStr;
@@ -1176,8 +1232,9 @@ LogicalResult AxisInfoAnalysis::visitOperation(
11761232
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
11771233
newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
11781234
}
1179-
curr = AxisInfo(std::move(newContiguity), std::move(newDivisibility),
1180-
std::move(newConstancy), curr.getConstantValue());
1235+
curr = AxisInfo(curr.getStride(), std::move(newContiguity),
1236+
std::move(newDivisibility), std::move(newConstancy),
1237+
curr.getConstantValue());
11811238
// join all lattice elements
11821239
for (auto *result : results)
11831240
propagateIfChanged(result, result->join(curr));

0 commit comments

Comments
 (0)