1
+ #include " mlir/IR/Attributes.h"
1
2
#include " mlir/IR/BuiltinAttributes.h"
2
3
#include " mlir/IR/Visitors.h"
3
4
#include " mlir/Support/LLVM.h"
9
10
#include " llvm/ADT/PriorityWorklist.h"
10
11
#include " llvm/Support/Debug.h"
11
12
#include " llvm/Support/LogicalResult.h"
13
+ #include " llvm/Support/raw_ostream.h"
14
+ #include " llvm/Support/xxhash.h"
12
15
13
16
namespace ttg = mlir::triton::gpu;
14
17
@@ -28,6 +31,65 @@ bool isAutoEncodingTensorType(Type ty) {
28
31
return tensorTy && isa<gluon::AutoEncodingAttr>(tensorTy.getEncoding ());
29
32
}
30
33
34
+ struct LayoutInfo {
35
+ Attribute encoding;
36
+ // Some operations can infer one of many encodings,
37
+ // we model this by setting the mayVary flag on encodings
38
+ // derived from these ops.
39
+ // If "may vary" is set then we allow conflicts, and when
40
+ // resolving conflicts we prefer encodings that are not allowed to vary.
41
+ bool mayVary = false ;
42
+
43
+ operator bool () { return bool (encoding); }
44
+ };
45
+
46
+ uint64_t hashWithMemo (Attribute attr,
47
+ llvm::MapVector<Attribute, uint64_t > &hashMemo) {
48
+ auto it = hashMemo.find (attr);
49
+ if (it != hashMemo.end ()) {
50
+ return it->second ;
51
+ }
52
+
53
+ // llvm::hash_value is not stable, so instead we hash the string repr of the
54
+ // attribute
55
+ std::string str;
56
+ llvm::raw_string_ostream os (str);
57
+ attr.print (os);
58
+ auto hash = llvm::xxh3_64bits (str);
59
+ hashMemo.try_emplace (attr, hash);
60
+ return hash;
61
+ }
62
+
63
+ bool compare (Attribute a, Attribute b,
64
+ llvm::MapVector<Attribute, uint64_t > &hashMemo) {
65
+ if (a == b)
66
+ return false ;
67
+
68
+ return hashWithMemo (a, hashMemo) > hashWithMemo (b, hashMemo);
69
+ }
70
+
71
+ LayoutInfo combineInfo (LayoutInfo lhs, LayoutInfo rhs, Operation *op,
72
+ llvm::MapVector<Attribute, uint64_t > &hashMemo) {
73
+ // Sort inputs so this operation is commutative
74
+ if (compare (lhs.encoding , rhs.encoding , hashMemo)) {
75
+ std::swap (lhs, rhs);
76
+ }
77
+ if (lhs.mayVary )
78
+ return rhs;
79
+ if (rhs.mayVary )
80
+ return lhs;
81
+ if (lhs.encoding == rhs.encoding )
82
+ return lhs;
83
+ op->emitOpError (" found conflicting encodings for value:\n " )
84
+ << lhs.encoding << " \n and\n " << rhs.encoding ;
85
+ return {};
86
+ }
87
+
88
+ bool encodingsMayVary (Operation *op) {
89
+ return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp, triton::CatOp,
90
+ triton::TransOp>(op);
91
+ }
92
+
31
93
LogicalResult inferAutoLayouts (FuncOp func) {
32
94
// Disallow auto encoding accross function call boundaries
33
95
for (auto argTy : func.getArgumentTypes ()) {
@@ -42,33 +104,37 @@ LogicalResult inferAutoLayouts(FuncOp func) {
42
104
" Functions returning auto encoding must be fully inlined" );
43
105
}
44
106
45
- llvm::MapVector<Value, Attribute > valueToEncoding;
107
+ llvm::MapVector<Value, LayoutInfo > valueToEncoding;
46
108
llvm::PriorityWorklist<Value> worklist;
109
+ llvm::MapVector<Attribute, uint64_t > hashMemo;
47
110
48
111
auto updateEncoding = [&](ArrayRef<Value> values,
49
- Attribute enc ) -> LogicalResult {
112
+ LayoutInfo info ) -> LogicalResult {
50
113
for (auto value : values) {
51
- auto [it, inserted] = valueToEncoding.insert ({value, enc });
114
+ auto [it, inserted] = valueToEncoding.insert ({value, info });
52
115
if (!inserted) {
53
- if (it->second != enc) {
54
- auto defOp = value.getDefiningOp ();
55
- auto op = defOp ? defOp : func;
56
- return op->emitOpError (" Found conflicting encodings for value" );
57
- }
58
- } else {
59
- LLVM_DEBUG ({
60
- DBGS () << " Setting value:\n\t " << value << " \n to encoding:\n\t " << enc
61
- << " \n " ;
62
- });
63
- worklist.insert (value);
116
+ auto defOp = value.getDefiningOp ();
117
+ auto op = defOp ? defOp : func;
118
+ auto combine = combineInfo (it->second , info, op, hashMemo);
119
+ if (!combine)
120
+ return failure ();
121
+ if (combine == it->second )
122
+ continue ;
123
+ it->second = combine;
64
124
}
125
+ LLVM_DEBUG ({
126
+ DBGS () << " Setting value:\n\t " << value << " \n to encoding:\n\t "
127
+ << it->second << " \n " ;
128
+ });
129
+ worklist.insert (value);
65
130
}
66
131
return success ();
67
132
};
68
133
69
134
// 1. Set seed values from set_auto_layout ops
70
135
auto res = func.walk ([&](gluon::SetAutoLayoutOp op) -> WalkResult {
71
- return updateEncoding ({op.getSrc ()}, op.getType ().getEncoding ());
136
+ return updateEncoding ({op.getSrc ()},
137
+ LayoutInfo{op.getType ().getEncoding ()});
72
138
});
73
139
74
140
if (res.wasInterrupted ())
@@ -77,26 +143,28 @@ LogicalResult inferAutoLayouts(FuncOp func) {
77
143
// 2. Propagate encodings through the graph until fixed point, or conflict
78
144
while (!worklist.empty ()) {
79
145
auto val = worklist.pop_back_val ();
80
- auto enc = valueToEncoding[val];
81
- assert (enc );
146
+ auto info = valueToEncoding[val];
147
+ assert (info );
82
148
83
149
// Propagate to users
84
150
for (OpOperand &use : val.getUses ()) {
85
151
auto op = use.getOwner ();
86
152
if (isa<scf::ForOp, scf::WhileOp>(op)) {
87
153
auto offset = 3 * isa<scf::ForOp>(op);
88
154
auto tiedArgs = getTiedArgs (op, use.getOperandNumber () - offset);
89
- if (failed (updateEncoding (tiedArgs, enc )))
155
+ if (failed (updateEncoding (tiedArgs, info )))
90
156
return failure ();
91
157
} else if (isa<scf::YieldOp>(op)) {
92
158
auto tiedArgs = getTiedArgs (op, use.getOperandNumber ());
93
- if (failed (updateEncoding (tiedArgs, enc )))
159
+ if (failed (updateEncoding (tiedArgs, info )))
94
160
return failure ();
95
161
} else {
96
- auto dstEnc = inferDstEncoding (op, enc );
162
+ auto dstEnc = inferDstEncoding (op, info. encoding );
97
163
if (dstEnc) {
164
+ bool mayVary = info.mayVary || encodingsMayVary (op);
165
+ LayoutInfo dstInfo{dstEnc, mayVary};
98
166
if (failed (updateEncoding (llvm::to_vector_of<Value>(op->getResults ()),
99
- dstEnc )))
167
+ dstInfo )))
100
168
return failure ();
101
169
}
102
170
}
@@ -107,17 +175,19 @@ LogicalResult inferAutoLayouts(FuncOp func) {
107
175
auto definingOp = opResult.getOwner ();
108
176
if (isa<scf::ForOp, scf::WhileOp, scf::IfOp>(definingOp)) {
109
177
auto tiedArgs = getTiedArgs (definingOp, opResult.getResultNumber ());
110
- if (failed (updateEncoding (tiedArgs, enc )))
178
+ if (failed (updateEncoding (tiedArgs, info )))
111
179
return failure ();
112
180
} else {
113
- auto srcEncoding = inferSrcEncoding (definingOp, enc );
181
+ auto srcEncoding = inferSrcEncoding (definingOp, info. encoding );
114
182
if (srcEncoding) {
183
+ bool mayVary = info.mayVary || encodingsMayVary (definingOp);
184
+ LayoutInfo srcInfo{srcEncoding, mayVary};
115
185
llvm::SmallVector<Value> tensorOperands;
116
186
for (auto operand : definingOp->getOperands ())
117
187
if (isa<RankedTensorType>(operand.getType ()))
118
188
tensorOperands.push_back (operand);
119
189
120
- if (failed (updateEncoding (tensorOperands, srcEncoding )))
190
+ if (failed (updateEncoding (tensorOperands, srcInfo )))
121
191
return failure ();
122
192
}
123
193
}
@@ -126,18 +196,18 @@ LogicalResult inferAutoLayouts(FuncOp func) {
126
196
if (isa<scf::ForOp, scf::WhileOp>(parentOp)) {
127
197
auto offset = isa<scf::ForOp>(parentOp);
128
198
auto tiedArgs = getTiedArgs (parentOp, blockArg.getArgNumber () - offset);
129
- if (failed (updateEncoding (tiedArgs, enc )))
199
+ if (failed (updateEncoding (tiedArgs, info )))
130
200
return failure ();
131
201
}
132
202
}
133
203
}
134
204
135
205
// 3. Transfer propagated encodings into the graph
136
206
auto ctx = func.getContext ();
137
- for (auto &[val, enc ] : valueToEncoding) {
207
+ for (auto &[val, info ] : valueToEncoding) {
138
208
auto existingTy = cast<RankedTensorType>(val.getType ());
139
209
assert (isa<gluon::AutoEncodingAttr>(existingTy.getEncoding ()));
140
- auto ty = existingTy.cloneWithEncoding (enc );
210
+ auto ty = existingTy.cloneWithEncoding (info. encoding );
141
211
val.setType (ty);
142
212
143
213
if (auto opResult = dyn_cast<OpResult>(val)) {
0 commit comments