@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
9595 StringAttr::get (op->getContext (), libpath));
9696 return ret;
9797}
98+
99+ Value matrixVectorProd (TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
100+ assert (A.getNumInDims () == 1 );
101+ assert (A.getNumOutDims () == 1 );
102+ auto flatten = [](const std::vector<std::vector<int32_t >> &matrix) {
103+ SmallVector<int32_t > ret;
104+ for (const auto &row : matrix) {
105+ ret.push_back (row[0 ]);
106+ }
107+ return ret;
108+ };
109+ auto nCol = A.getTotalInDimSizeLog2 ();
110+ auto nRow = A.getTotalOutDimSizeLog2 ();
111+ SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
112+ assert (matrix.size () == nCol);
113+ // We iterate the matrix following the diagonals
114+ // The idea here is that we want to generate code of the form:
115+ // \xor_i (x & mask_i) << s_i
116+ // where s_i may by positive or negative (left or right shift)
117+ // The hope here (and we see it in codegen) is that LLVM can turn
118+ // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
119+ // Get the i-th diagonal
120+ auto getMask = [&](int i) {
121+ uint32_t mask = 0 ;
122+ int row = i < 0 ? -i : 0 ;
123+ int col = i < 0 ? 0 : i;
124+ while (row < nRow && col < nCol) {
125+ uint32_t bitValue = (matrix[col] >> row) & 1u ;
126+ mask |= bitValue << col;
127+ ++row;
128+ ++col;
129+ }
130+ return mask;
131+ };
132+
133+ Value ret = b.i32_val (0 );
134+ for (int i = -nRow + 1 ; i < nCol; i++) {
135+ auto mask = getMask (i);
136+ if (mask == 0 )
137+ continue ;
138+ auto masked = b.and_ (x, b.i32_val (mask));
139+ ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
140+ : Value (b.shl (masked, b.i32_val (-i))));
141+ }
142+ return ret;
143+ }
144+
98145} // namespace triton::gpu
99146
100147SmallVector<std::pair<StringAttr, Value>>
@@ -115,12 +162,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
115162
116163 // Manually constant-fold the layout where possible.
117164 SmallVector<std::pair<StringAttr, int32_t >> constantIns;
165+ SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
118166 for (auto [inDimName, idx] : indices) {
119167 if (auto constant = idx.getDefiningOp <LLVM::ConstantOp>()) {
120168 constantIns.push_back (
121169 {inDimName, cast<IntegerAttr>(constant.getValue ()).getInt ()});
122170 } else {
123171 constantIns.push_back ({inDimName, 0 });
172+ nonConstantIns.push_back ({inDimName, idx});
124173 }
125174 }
126175 SmallVector<int32_t > constantComponent =
@@ -134,6 +183,24 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
134183 else
135184 outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
136185 }
186+ // Happy path: Only one output.
187+ if (outIndices.size () == 1 ) {
188+ SmallVector<StringAttr> inDimNames;
189+ // Concatenate input
190+ Value x = b.i32_val (0 );
191+ int shift = 0 ;
192+ for (auto [inDimName, idx] : nonConstantIns) {
193+ inDimNames.push_back (inDimName);
194+ x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
195+ shift += layout.getInDimSizeLog2 (inDimName);
196+ }
197+ // Flatten ins
198+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199+ matrix = matrix.flattenIns ();
200+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202+ return outIndices;
203+ }
137204
138205 for (auto [inDimName, idx] : indices) {
139206 if (idx.getDefiningOp <LLVM::ConstantOp>()) {
0 commit comments