Skip to content

Commit 9d11c09

Browse files
apgoucher and lezcano authored
[BACKEND] simpler codegen for linear layouts (#7201)
Co-authored-by: lezcano <[email protected]>
1 parent 5d440d3 commit 9d11c09

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
9595
StringAttr::get(op->getContext(), libpath));
9696
return ret;
9797
}
98+
99+
// Emits straight-line bitwise code computing the product of the binary
// matrix described by `A` with the bit-vector `x` (arithmetic over GF(2):
// AND for multiplication, XOR for addition).
//
// `A` must have exactly one input and one output dimension (asserted). Its
// bases for the first input dimension are read as the matrix columns: bit
// `r` of the c-th basis is the entry at row `r`, column `c`.
//
// `x` is combined with i32 constants below, so it is presumably an i32 value
// whose bit `c` multiplies column `c` -- TODO confirm against callers.
//
// Returns the accumulated value; it is the constant i32 zero when every
// diagonal of the matrix is empty.
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
  assert(A.getNumInDims() == 1);
  assert(A.getNumOutDims() == 1);

  // Keep only the first (and, per the assert above, only) output component
  // of every basis vector.
  auto firstComponents = [](const std::vector<std::vector<int32_t>> &bases) {
    SmallVector<int32_t> components;
    components.reserve(bases.size());
    for (const auto &basis : bases)
      components.push_back(basis[0]);
    return components;
  };

  const auto numCols = A.getTotalInDimSizeLog2();
  const auto numRows = A.getTotalOutDimSizeLog2();
  SmallVector<int32_t> columns = firstComponents(A.getBases().begin()->second);
  assert(columns.size() == static_cast<size_t>(numCols));

  // The matrix is walked one diagonal at a time so that the generated code
  // has the form
  //     \xor_d (x & mask_d) << s_d
  // where s_d may be positive or negative (i.e. a left or a right shift).
  // The hope (and this is what is observed in codegen) is that LLVM can turn
  // the xor into a sum, and that the sum together with the shifts can then
  // be fused into a mad.lo.

  // Collects the entries of diagonal `diag` (those with col - row == diag)
  // into a mask indexed by column, i.e. the bits of `x` that feed it.
  auto diagonalMask = [&](int diag) {
    uint32_t mask = 0;
    // Non-negative diagonals start on row 0, negative ones on column 0.
    int row = diag >= 0 ? 0 : -diag;
    int col = diag >= 0 ? diag : 0;
    for (; row < numRows && col < numCols; ++row, ++col) {
      const uint32_t bit = (columns[col] >> row) & 1u;
      mask |= bit << col;
    }
    return mask;
  };

  Value result = b.i32_val(0);
  for (int diag = 1 - numRows; diag < numCols; ++diag) {
    const uint32_t mask = diagonalMask(diag);
    // Empty diagonals contribute nothing, so no code is emitted for them.
    if (mask != 0) {
      auto selected = b.and_(x, b.i32_val(mask));
      // Input bit `col` lands on output bit `col - diag`: shift right for
      // diagonals on or above the main one, left for those below it.
      Value shifted = diag >= 0 ? Value(b.lshr(selected, b.i32_val(diag)))
                                : Value(b.shl(selected, b.i32_val(-diag)));
      result = b.xor_(result, shifted);
    }
  }
  return result;
}
144+
98145
} // namespace triton::gpu
99146

100147
SmallVector<std::pair<StringAttr, Value>>
@@ -117,12 +164,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
117164

118165
// Manually constant-fold the layout where possible.
119166
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
167+
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
120168
for (auto [inDimName, idx] : indices) {
121169
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
122170
constantIns.push_back(
123171
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
124172
} else {
125173
constantIns.push_back({inDimName, 0});
174+
nonConstantIns.push_back({inDimName, idx});
126175
}
127176
}
128177
SmallVector<int32_t> constantComponent =
@@ -136,6 +185,28 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
136185
else
137186
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
138187
}
188+
// Happy path: Only one output.
189+
if (outIndices.size() == 1) {
190+
SmallVector<StringAttr> inDimNames;
191+
// Concatenate input
192+
Value x = b.i32_val(0);
193+
int shift = 0;
194+
for (auto orderedName : layout.getInDimNames()) {
195+
for (auto [inDimName, idx] : nonConstantIns) {
196+
if (orderedName == inDimName) {
197+
inDimNames.push_back(inDimName);
198+
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
199+
shift += layout.getInDimSizeLog2(inDimName);
200+
}
201+
}
202+
}
203+
// Flatten ins
204+
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
205+
matrix = matrix.flattenIns();
206+
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
207+
outIndices[0].second = b.xor_(outIndices[0].second, out);
208+
return outIndices;
209+
}
139210

140211
for (auto [inDimName, idx] : indices) {
141212
if (idx.getDefiningOp<LLVM::ConstantOp>()) {

0 commit comments

Comments
 (0)