Skip to content

Commit 4587aef

Browse files
committed
Revert "Revert "[BACKEND] simpler codegen for linear layouts (#7201)""
This reverts commit 0844235.
1 parent 7206598 commit 4587aef

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
9595
StringAttr::get(op->getContext(), libpath));
9696
return ret;
9797
}
98+
99+
// Emits IR that applies the linear layout `A` to the value `x`, treating `A`
// as a bit matrix acting on the bits of `x`.
//
// `A` must have exactly one input dimension and one output dimension (both
// are asserted). Column `c` of the matrix is the first component of the c-th
// basis vector of that single input dimension; bit `r` of that component is
// the matrix entry at (row r, column c).
//
// Rather than handling one bit at a time, the matrix is walked diagonal by
// diagonal so that the generated code has the form
//     xor_i ((x & mask_i) shifted by s_i)
// where s_i may be a right shift or a left shift. The hope (and what is seen
// in codegen) is that LLVM can turn the xor into a sum, and that the sum plus
// its LHS/RHS can then be fused into a mad.lo.
//
// Returns the xor-accumulated result, starting from an i32 zero constant.
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
  assert(A.getNumInDims() == 1);
  assert(A.getNumOutDims() == 1);

  auto nCol = A.getTotalInDimSizeLog2();
  auto nRow = A.getTotalOutDimSizeLog2();

  // Gather one packed integer per matrix column from the bases of the first
  // (and, per the asserts above, only) input dimension.
  const auto &bases = A.getBases();
  SmallVector<int32_t> columns;
  for (const auto &basis : bases.begin()->second)
    columns.push_back(basis[0]);
  assert(columns.size() == nCol);

  // Builds the mask for diagonal `diag`, i.e. the matrix entries with
  // (column - row) == diag. Bit `c` of the result is set exactly when the
  // entry at (c - diag, c) is 1, so the mask picks out the bits of x that
  // contribute along this diagonal.
  auto diagonalMask = [&](int diag) {
    uint32_t bits = 0;
    int r = diag < 0 ? -diag : 0;
    for (int c = diag < 0 ? 0 : diag; r < nRow && c < nCol; ++r, ++c) {
      uint32_t entry = (columns[c] >> r) & 1u;
      bits |= entry << c;
    }
    return bits;
  };

  Value acc = b.i32_val(0);
  for (int diag = 1 - nRow; diag < nCol; ++diag) {
    uint32_t mask = diagonalMask(diag);
    // Diagonals with no set entries contribute nothing; emit no IR for them.
    if (mask != 0) {
      auto selected = b.and_(x, b.i32_val(mask));
      // An input bit at position c lands at output position (c - diag):
      // shift right for non-negative diagonals, left for negative ones.
      Value shifted;
      if (diag >= 0)
        shifted = Value(b.lshr(selected, b.i32_val(diag)));
      else
        shifted = Value(b.shl(selected, b.i32_val(-diag)));
      acc = b.xor_(acc, shifted);
    }
  }
  return acc;
}
}
144+
98145
} // namespace triton::gpu
99146

100147
SmallVector<std::pair<StringAttr, Value>>
@@ -115,12 +162,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
115162

116163
// Manually constant-fold the layout where possible.
117164
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
165+
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
118166
for (auto [inDimName, idx] : indices) {
119167
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
120168
constantIns.push_back(
121169
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
122170
} else {
123171
constantIns.push_back({inDimName, 0});
172+
nonConstantIns.push_back({inDimName, idx});
124173
}
125174
}
126175
SmallVector<int32_t> constantComponent =
@@ -134,6 +183,24 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
134183
else
135184
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
136185
}
186+
// Happy path: Only one output.
187+
if (outIndices.size() == 1) {
188+
SmallVector<StringAttr> inDimNames;
189+
// Concatenate input
190+
Value x = b.i32_val(0);
191+
int shift = 0;
192+
for (auto [inDimName, idx] : nonConstantIns) {
193+
inDimNames.push_back(inDimName);
194+
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
195+
shift += layout.getInDimSizeLog2(inDimName);
196+
}
197+
// Flatten ins
198+
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
199+
matrix = matrix.flattenIns();
200+
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
201+
outIndices[0].second = b.xor_(outIndices[0].second, out);
202+
return outIndices;
203+
}
137204

138205
for (auto [inDimName, idx] : indices) {
139206
if (idx.getDefiningOp<LLVM::ConstantOp>()) {

0 commit comments

Comments
 (0)