Skip to content

Commit 0844235

Browse files
Revert "[BACKEND] simpler codegen for linear layouts (#7201)"
This reverts commit 9d11c09.
1 parent cc1a80c commit 0844235

File tree

3 files changed

+0
-67
lines changed

3 files changed

+0
-67
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -95,53 +95,6 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
9595
StringAttr::get(op->getContext(), libpath));
9696
return ret;
9797
}
98-
99-
// Emits LLVM-dialect ops computing the GF(2) matrix-vector product A * x,
// where A is a linear layout with exactly one input and one output dimension
// and x holds the input bits. Returns the resulting value.
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
  assert(A.getNumInDims() == 1);
  assert(A.getNumOutDims() == 1);
  const auto numCols = A.getTotalInDimSizeLog2();
  const auto numRows = A.getTotalOutDimSizeLog2();

  // Collect the single output component of every basis vector of the first
  // (and only) input dimension: columns[j] is the j-th column of the matrix.
  const auto &bases = A.getBases();
  SmallVector<int32_t> columns;
  for (const auto &basis : bases.begin()->second)
    columns.push_back(basis[0]);
  assert(columns.size() == numCols);

  // The matrix is walked diagonal by diagonal so that the generated code has
  // the shape
  //   \xor_i (x & mask_i) << s_i
  // where s_i may be positive or negative (i.e. a left or a right shift).
  // The hope (confirmed in codegen) is that LLVM rewrites the xor as a sum,
  // and that the sum together with its LHS/RHS is then fused into a mad.lo.

  // Builds the mask selecting the input bits that lie on diagonal `diag`.
  // diag >= 0 starts at (row 0, col diag); diag < 0 starts at (row -diag,
  // col 0).
  auto diagonalMask = [&](int diag) {
    uint32_t bits = 0;
    int row = diag < 0 ? -diag : 0;
    int col = diag < 0 ? 0 : diag;
    for (; row < numRows && col < numCols; ++row, ++col)
      bits |= ((columns[col] >> row) & 1u) << col;
    return bits;
  };

  Value acc = b.i32_val(0);
  for (int diag = 1 - numRows; diag < numCols; ++diag) {
    const uint32_t bits = diagonalMask(diag);
    // An empty diagonal contributes nothing; emit no ops for it.
    if (bits == 0)
      continue;
    Value selected = b.and_(x, b.i32_val(bits));
    // Move the selected bits from their column position to their row
    // position: right shift for non-negative diagonals, left shift otherwise.
    Value moved;
    if (diag >= 0)
      moved = b.lshr(selected, b.i32_val(diag));
    else
      moved = b.shl(selected, b.i32_val(-diag));
    acc = b.xor_(acc, moved);
  }
  return acc;
}
144-
14598
} // namespace triton::gpu
14699

147100
SmallVector<std::pair<StringAttr, Value>>
@@ -162,14 +115,12 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
162115

163116
// Manually constant-fold the layout where possible.
164117
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
165-
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
166118
for (auto [inDimName, idx] : indices) {
167119
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
168120
constantIns.push_back(
169121
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
170122
} else {
171123
constantIns.push_back({inDimName, 0});
172-
nonConstantIns.push_back({inDimName, idx});
173124
}
174125
}
175126
SmallVector<int32_t> constantComponent =
@@ -183,24 +134,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
183134
else
184135
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
185136
}
186-
// Happy path: Only one output.
187-
if (outIndices.size() == 1) {
188-
SmallVector<StringAttr> inDimNames;
189-
// Concatenate input
190-
Value x = b.i32_val(0);
191-
int shift = 0;
192-
for (auto [inDimName, idx] : nonConstantIns) {
193-
inDimNames.push_back(inDimName);
194-
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
195-
shift += layout.getInDimSizeLog2(inDimName);
196-
}
197-
// Flatten ins
198-
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
199-
matrix = matrix.flattenIns();
200-
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
201-
outIndices[0].second = b.xor_(outIndices[0].second, out);
202-
return outIndices;
203-
}
204137

205138
for (auto [inDimName, idx] : indices) {
206139
if (idx.getDefiningOp<LLVM::ConstantOp>()) {

0 commit comments

Comments
 (0)