Commit c24d86c

[Gluon] Fix auto_encoding for ops which may infer multiple layouts (#7718)
In the failing example we have:

```mlir
%0 = tt.make_range {end = 8192 : i32, start = 0 : i32} : tensor<8192xi32, #gluon.auto_encoding>
%1 = tt.reshape %0 : tensor<8192xi32, #gluon.auto_encoding> -> tensor<64x128xi32, #gluon.auto_encoding>
%2 = gluon.set_auto_layout %1 : tensor<64x128xi32, #gluon.auto_encoding> -> tensor<64x128xi32, #blocked>
```

which currently fails with the error:

```
/root/code/triton/test.py:43:52: error: 'tt.reshape' op Found conflicting encodings for value gl.arange(0, BLOCK_M * BLOCK_N).reshape((BLOCK_M, BLOCK_N)),
```

The issue is that we propagate the blocked layout backwards to get a linear layout for the `make_range` result, and the algorithm then propagates that layout forward to the `reshape` result. There it infers a linear layout, which conflicts with the original blocked layout, so the pass errors out. I fix this by setting a `mayVary` flag when an encoding comes from an inference result that isn't the only possibility, and by adding special rules that resolve conflicts where one or more of the encodings is allowed to vary.
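For context, a Gluon kernel along the following lines would lower to the IR above. This is a hypothetical repro sketch, not the actual failing test: the decorator, layout constructor, and layout-pinning helper names are assumptions for illustration; only the `arange(...).reshape(...)` expression is taken from the error message.

```python
# Hypothetical repro sketch: an arange feeding a reshape whose result is then
# pinned to a concrete blocked layout, so the pass must propagate the layout
# backwards to the arange and forwards again through the reshape. Names other
# than arange(...).reshape(...) are assumptions for illustration.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def reshape_kernel(BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr):
    # Assumed blocked layout matching the #blocked attribute in the IR above.
    blocked: ttgl.constexpr = ttgl.BlockedLayout(
        size_per_thread=[1, 64], threads_per_warp=[16, 2],
        warps_per_cta=[4, 1], order=[0, 1])
    # Both intermediate tensors carry #gluon.auto_encoding until the pass runs.
    idx = ttgl.arange(0, BLOCK_M * BLOCK_N).reshape((BLOCK_M, BLOCK_N))
    # Pinning the reshape result to a concrete layout seeds the propagation
    # that previously tripped the "conflicting encodings" error.
    idx = ttgl.convert_layout(idx, blocked)
```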
1 parent e40c213 commit c24d86c

File tree: 4 files changed, +117 −29 lines changed

lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp

Lines changed: 97 additions & 27 deletions
```diff
@@ -1,3 +1,4 @@
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
@@ -9,6 +10,8 @@
 #include "llvm/ADT/PriorityWorklist.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/xxhash.h"
 
 namespace ttg = mlir::triton::gpu;
 
@@ -28,6 +31,65 @@ bool isAutoEncodingTensorType(Type ty) {
   return tensorTy && isa<gluon::AutoEncodingAttr>(tensorTy.getEncoding());
 }
 
+struct LayoutInfo {
+  Attribute encoding;
+  // Some operations can infer one of many encodings,
+  // we model this by setting the mayVary flag on encodings
+  // derived from these ops.
+  // If "may vary" is set then we allow conflicts, and when
+  // resolving conflicts we prefer encodings that are not allowed to vary.
+  bool mayVary = false;
+
+  operator bool() { return bool(encoding); }
+};
+
+uint64_t hashWithMemo(Attribute attr,
+                      llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  auto it = hashMemo.find(attr);
+  if (it != hashMemo.end()) {
+    return it->second;
+  }
+
+  // llvm::hash_value is not stable, so instead we hash the string repr of the
+  // attribute
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  attr.print(os);
+  auto hash = llvm::xxh3_64bits(str);
+  hashMemo.try_emplace(attr, hash);
+  return hash;
+}
+
+bool compare(Attribute a, Attribute b,
+             llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  if (a == b)
+    return false;
+
+  return hashWithMemo(a, hashMemo) > hashWithMemo(b, hashMemo);
+}
+
+LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op,
+                       llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  // Sort inputs so this operation is commutative
+  if (compare(lhs.encoding, rhs.encoding, hashMemo)) {
+    std::swap(lhs, rhs);
+  }
+  if (lhs.mayVary)
+    return rhs;
+  if (rhs.mayVary)
+    return lhs;
+  if (lhs.encoding == rhs.encoding)
+    return lhs;
+  op->emitOpError("found conflicting encodings for value:\n ")
+      << lhs.encoding << "\nand\n " << rhs.encoding;
+  return {};
+}
+
+bool encodingsMayVary(Operation *op) {
+  return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp, triton::CatOp,
+             triton::TransOp>(op);
+}
+
 LogicalResult inferAutoLayouts(FuncOp func) {
   // Disallow auto encoding accross function call boundaries
   for (auto argTy : func.getArgumentTypes()) {
@@ -42,33 +104,37 @@ LogicalResult inferAutoLayouts(FuncOp func) {
         "Functions returning auto encoding must be fully inlined");
   }
 
-  llvm::MapVector<Value, Attribute> valueToEncoding;
+  llvm::MapVector<Value, LayoutInfo> valueToEncoding;
   llvm::PriorityWorklist<Value> worklist;
+  llvm::MapVector<Attribute, uint64_t> hashMemo;
 
   auto updateEncoding = [&](ArrayRef<Value> values,
-                            Attribute enc) -> LogicalResult {
+                            LayoutInfo info) -> LogicalResult {
     for (auto value : values) {
-      auto [it, inserted] = valueToEncoding.insert({value, enc});
+      auto [it, inserted] = valueToEncoding.insert({value, info});
       if (!inserted) {
-        if (it->second != enc) {
-          auto defOp = value.getDefiningOp();
-          auto op = defOp ? defOp : func;
-          return op->emitOpError("Found conflicting encodings for value");
-        }
-      } else {
-        LLVM_DEBUG({
-          DBGS() << "Setting value:\n\t" << value << "\nto encoding:\n\t" << enc
-                 << "\n";
-        });
-        worklist.insert(value);
+        auto defOp = value.getDefiningOp();
+        auto op = defOp ? defOp : func;
+        auto combine = combineInfo(it->second, info, op, hashMemo);
+        if (!combine)
+          return failure();
+        if (combine == it->second)
+          continue;
+        it->second = combine;
       }
+      LLVM_DEBUG({
+        DBGS() << "Setting value:\n\t" << value << "\nto encoding:\n\t"
+               << it->second << "\n";
+      });
+      worklist.insert(value);
     }
     return success();
   };
 
   // 1. Set seed values from set_auto_layout ops
   auto res = func.walk([&](gluon::SetAutoLayoutOp op) -> WalkResult {
-    return updateEncoding({op.getSrc()}, op.getType().getEncoding());
+    return updateEncoding({op.getSrc()},
+                          LayoutInfo{op.getType().getEncoding()});
   });
 
   if (res.wasInterrupted())
@@ -77,26 +143,28 @@ LogicalResult inferAutoLayouts(FuncOp func) {
   // 2. Propagate encodings through the graph until fixed point, or conflict
   while (!worklist.empty()) {
     auto val = worklist.pop_back_val();
-    auto enc = valueToEncoding[val];
-    assert(enc);
+    auto info = valueToEncoding[val];
+    assert(info);
 
     // Propagate to users
     for (OpOperand &use : val.getUses()) {
       auto op = use.getOwner();
       if (isa<scf::ForOp, scf::WhileOp>(op)) {
         auto offset = 3 * isa<scf::ForOp>(op);
         auto tiedArgs = getTiedArgs(op, use.getOperandNumber() - offset);
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else if (isa<scf::YieldOp>(op)) {
         auto tiedArgs = getTiedArgs(op, use.getOperandNumber());
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else {
-        auto dstEnc = inferDstEncoding(op, enc);
+        auto dstEnc = inferDstEncoding(op, info.encoding);
         if (dstEnc) {
+          bool mayVary = info.mayVary || encodingsMayVary(op);
+          LayoutInfo dstInfo{dstEnc, mayVary};
           if (failed(updateEncoding(llvm::to_vector_of<Value>(op->getResults()),
-                                    dstEnc)))
+                                    dstInfo)))
             return failure();
         }
       }
@@ -107,17 +175,19 @@ LogicalResult inferAutoLayouts(FuncOp func) {
       auto definingOp = opResult.getOwner();
       if (isa<scf::ForOp, scf::WhileOp, scf::IfOp>(definingOp)) {
         auto tiedArgs = getTiedArgs(definingOp, opResult.getResultNumber());
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else {
-        auto srcEncoding = inferSrcEncoding(definingOp, enc);
+        auto srcEncoding = inferSrcEncoding(definingOp, info.encoding);
         if (srcEncoding) {
+          bool mayVary = info.mayVary || encodingsMayVary(definingOp);
+          LayoutInfo srcInfo{srcEncoding, mayVary};
           llvm::SmallVector<Value> tensorOperands;
           for (auto operand : definingOp->getOperands())
             if (isa<RankedTensorType>(operand.getType()))
               tensorOperands.push_back(operand);
 
-          if (failed(updateEncoding(tensorOperands, srcEncoding)))
+          if (failed(updateEncoding(tensorOperands, srcInfo)))
             return failure();
         }
       }
@@ -126,18 +196,18 @@ LogicalResult inferAutoLayouts(FuncOp func) {
       if (isa<scf::ForOp, scf::WhileOp>(parentOp)) {
         auto offset = isa<scf::ForOp>(parentOp);
         auto tiedArgs = getTiedArgs(parentOp, blockArg.getArgNumber() - offset);
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
          return failure();
       }
     }
   }
 
   // 3. Transfer propagated encodings into the graph
   auto ctx = func.getContext();
-  for (auto &[val, enc] : valueToEncoding) {
+  for (auto &[val, info] : valueToEncoding) {
     auto existingTy = cast<RankedTensorType>(val.getType());
     assert(isa<gluon::AutoEncodingAttr>(existingTy.getEncoding()));
-    auto ty = existingTy.cloneWithEncoding(enc);
+    auto ty = existingTy.cloneWithEncoding(info.encoding);
     val.setType(ty);
 
     if (auto opResult = dyn_cast<OpResult>(val)) {
```

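To make the resolution rules easy to check in isolation, here is a small self-contained C++ model of the `combineInfo` logic above. It is a sketch under simplifying assumptions: `std::string` stands in for `mlir::Attribute`, and the stable-hash tie-breaking (`hashWithMemo`/`compare`) that makes the real operation commutative is omitted.

```cpp
// Simplified model of the conflict-resolution rules: an encoding that "may
// vary" always loses to one that may not, equal encodings merge, and two
// distinct fixed encodings are a hard conflict.
#include <cassert>
#include <optional>
#include <string>

struct Info {
  std::string encoding; // stand-in for mlir::Attribute
  bool mayVary = false;
};

std::optional<Info> combine(Info lhs, Info rhs) {
  if (lhs.mayVary)
    return rhs;
  if (rhs.mayVary)
    return lhs;
  if (lhs.encoding == rhs.encoding)
    return lhs;
  return std::nullopt; // conflict: two fixed but different encodings
}

int main() {
  Info blocked{"#blocked", /*mayVary=*/false};
  Info linear{"#linear", /*mayVary=*/true}; // e.g. inferred through tt.reshape
  // The fixed blocked encoding wins over the inferred, may-vary linear one.
  assert(combine(linear, blocked)->encoding == "#blocked");
  // Two fixed encodings that disagree are reported as a conflict.
  assert(!combine(Info{"#blocked"}, Info{"#blocked1"}).has_value());
  return 0;
}
```
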
python/triton/language/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -84,7 +84,7 @@
     join,
     load,
     make_block_ptr,
-    map_elementwise,  # noqa
+    map_elementwise,
     max_constancy,
     max_contiguous,
     maximum,
@@ -209,6 +209,7 @@
     "log",
     "log2",
     "make_block_ptr",
+    "map_elementwise",
     "math",
     "max",
     "max_constancy",
```

test/Gluon/auto_encoding.mlir

Lines changed: 17 additions & 0 deletions
```diff
@@ -131,3 +131,20 @@ module attributes {ttg.maxnreg = 128 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-w
     tt.return %3 : tensor<128x128xi32, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_tmem_col_slice_load(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) -> tensor<64x128xi32, #blocked> {
+    // CHECK-DAG: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+    // CHECK-DAG: [[LINEAR:#.*]] = #ttg.linear
+    // CHECK: [[RANGE:%.*]] = tt.make_range {end = 8192 : i32, start = 0 : i32} : tensor<8192xi32, [[LINEAR]]>
+    // CHECK: [[RESHAPE:%.*]] = tt.reshape [[RANGE]] : tensor<8192xi32, [[LINEAR]]> -> tensor<64x128xi32, [[BLOCKED]]>
+    // CHECK: tt.return [[RESHAPE]] : tensor<64x128xi32, [[BLOCKED]]>
+    %0 = tt.make_range {end = 8192 : i32, start = 0 : i32} : tensor<8192xi32, #gluon.auto_encoding>
+    %1 = tt.reshape %0 : tensor<8192xi32, #gluon.auto_encoding> -> tensor<64x128xi32, #gluon.auto_encoding>
+    %2 = gluon.set_auto_layout %1 : tensor<64x128xi32, #gluon.auto_encoding> -> tensor<64x128xi32, #blocked>
+    tt.return %2 : tensor<64x128xi32, #blocked>
+  }
+}
```

test/Gluon/invalid_auto_encoding.mlir

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@
 
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @infer_conflict() -> (tensor<16xi32, #blocked>, tensor<16xi32, #blocked1>) {
-    // expected-error @+1 {{Found conflicting encodings for value}}
+    // expected-error-re @+1 {{found conflicting encodings for value:{{.*}} #ttg.blocked<{sizePerThread = [1]{{.*}}and{{.*}} #ttg.blocked<{sizePerThread = [2]}}
     %0 = arith.constant dense<7> : tensor<16xi32, #gluon.auto_encoding>
     %cvt1 = gluon.set_auto_layout %0 : tensor<16xi32, #gluon.auto_encoding> -> tensor<16xi32, #blocked>
     %cvt2 = gluon.set_auto_layout %0 : tensor<16xi32, #gluon.auto_encoding> -> tensor<16xi32, #blocked1>
```
