Skip to content

Commit 92e9426

Browse files
[RELAND] Simpler codegen for linear layouts (#4554)
The upstream commit simplifies linear layout codegen for the case when the output of the linear layout is 1 dimensional (post-folding). The generic algorithm loops over the bits in the requested index for each input dimension and selects the corresponding basis value for each non-zero bit. This requires checking to see if the bit is zero (eq) and then grabbing the basis value if the bit is nonzero (select). The simplified code flattens the inputs into one dimension which allows us to evaluate the linear layout using a simple linear function `L(a) = Ba` where `a` is input index and `B` is the matrix of basis vectors from our flattened layout which is what the `matrixVectorProd` code is doing. Anyway, the end result of all this is the processing of the thread ID (the lane ID and warp ID are held constant) changes from a select to make sure the thread ID bit value is non-zero to a series of xors and shifts. However, the number of xors for the print, and even the constants used, are identical. So, the nested layout encoding propagation is still working correctly. When updating the lit test, which I believe is designed to make sure the layout nesting is followed properly, I chose to drop the pre-amble evaluation of the linear layout and keep only the last two xors. Linear layout evaluation is tested basically everywhere (including functional correctness in `test_reduce_layouts`). Here, we focus on making sure the printfs were generated properly based on the nested layouts. This should result in less maintenance burden going forward as I suspect the linear layout evaluation code will be tweaked again in the future. close #4551
2 parents 7206598 + 02ddba3 commit 92e9426

File tree

4 files changed

+85
-24
lines changed

4 files changed

+85
-24
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
9595
StringAttr::get(op->getContext(), libpath));
9696
return ret;
9797
}
98+
99+
/// Emit IR evaluating the linear layout `A` at `x`, i.e. the GF(2)
/// matrix-vector product L(x) = B*x, where B is the matrix whose columns are
/// A's basis vectors. `A` must already be flattened to exactly one input and
/// one output dimension.
///
/// Instead of selecting a basis value per input bit, we walk the diagonals of
/// B and emit one masked shift per non-empty diagonal, generating code of the
/// form:
///   \xor_i (x & mask_i) << s_i
/// where s_i may be positive or negative (left or right shift). The hope here
/// (and we see it in codegen) is that LLVM can turn the xor into a sum and
/// then the sum + LHS/RHS can be fused into a mad.lo.
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
  assert(A.getNumInDims() == 1);
  assert(A.getNumOutDims() == 1);

  const auto nCol = A.getTotalInDimSizeLog2();
  const auto nRow = A.getTotalOutDimSizeLog2();

  // With a single output dimension, each basis vector has exactly one
  // component; gather them as the columns of B.
  SmallVector<int32_t> cols;
  for (const auto &basis : A.getBases().begin()->second)
    cols.push_back(basis[0]);
  assert(cols.size() == nCol);

  // Mask of the input bits lying on the d-th diagonal of B (d >= 0 is on or
  // above the main diagonal, d < 0 below it). Bit `col` of the mask is set
  // iff B[col - d][col] == 1.
  auto diagonalMask = [&](int d) {
    uint32_t mask = 0;
    for (int row = d < 0 ? -d : 0, col = d < 0 ? 0 : d;
         row < nRow && col < nCol; ++row, ++col) {
      uint32_t bitValue = (cols[col] >> row) & 1u;
      mask |= bitValue << col;
    }
    return mask;
  };

  // Accumulate one masked-shift term per non-empty diagonal. A diagonal at
  // offset d moves input bit `col` to output bit `col - d`, hence the right
  // shift for d >= 0 and the left shift for d < 0.
  Value acc = b.i32_val(0);
  for (int d = -nRow + 1; d < nCol; ++d) {
    uint32_t mask = diagonalMask(d);
    if (mask == 0)
      continue;
    Value masked = b.and_(x, b.i32_val(mask));
    Value term = d >= 0 ? Value(b.lshr(masked, b.i32_val(d)))
                        : Value(b.shl(masked, b.i32_val(-d)));
    acc = b.xor_(acc, term);
  }
  return acc;
}
144+
98145
} // namespace triton::gpu
99146

100147
SmallVector<std::pair<StringAttr, Value>>
@@ -115,12 +162,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
115162

116163
// Manually constant-fold the layout where possible.
117164
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
165+
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
118166
for (auto [inDimName, idx] : indices) {
119167
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
120168
constantIns.push_back(
121169
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
122170
} else {
123171
constantIns.push_back({inDimName, 0});
172+
nonConstantIns.push_back({inDimName, idx});
124173
}
125174
}
126175
SmallVector<int32_t> constantComponent =
@@ -134,6 +183,24 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
134183
else
135184
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
136185
}
186+
// Happy path: Only one output.
187+
if (outIndices.size() == 1) {
188+
SmallVector<StringAttr> inDimNames;
189+
// Concatenate input
190+
Value x = b.i32_val(0);
191+
int shift = 0;
192+
for (auto [inDimName, idx] : nonConstantIns) {
193+
inDimNames.push_back(inDimName);
194+
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
195+
shift += layout.getInDimSizeLog2(inDimName);
196+
}
197+
// Flatten ins
198+
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
199+
matrix = matrix.flattenIns();
200+
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
201+
outIndices[0].second = b.xor_(outIndices[0].second, out);
202+
return outIndices;
203+
}
137204

138205
for (auto [inDimName, idx] : indices) {
139206
if (idx.getDefiningOp<LLVM::ConstantOp>()) {

test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
179179
// CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32
180180
// CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
181181
// CHECK-DAG: %[[CST_7:.*]] = llvm.mlir.constant(7 : i32) : i32
182-
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
183182
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
184183
// CHECK-DAG: %[[CST_17:.*]] = llvm.mlir.constant(17 : i32) : i32
185184
// CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32
@@ -188,29 +187,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
188187
// CHECK-DAG: %[[CST_21:.*]] = llvm.mlir.constant(21 : i32) : i32
189188
// CHECK-DAG: %[[CST_22:.*]] = llvm.mlir.constant(22 : i32) : i32
190189
// CHECK-DAG: %[[CST_23:.*]] = llvm.mlir.constant(23 : i32) : i32
191-
// CHECK: %[[THREADS_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
192-
// CHECK: %[[THREADS_ID_32:.*]] = llvm.trunc %[[THREADS_ID]] : i64 to i32
193-
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREADS_ID_32]], %[[CST_16]] : i32
194-
// CHECK: %[[VAL_26:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32
195-
// CHECK: %[[VAL_27:.*]] = llvm.icmp "eq" %[[VAL_26]], %[[CST_0]] : i32
196-
// CHECK: %[[VAL_28:.*]] = llvm.select %[[VAL_27]], %[[CST_0]], %[[CST_8]] : i1, i32
197-
// CHECK: %[[VAL_29:.*]] = llvm.xor %[[CST_0]], %[[VAL_28]] : i32
198-
// CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_29]], %[[CST_0]] : i32
199-
// CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_29]], %[[CST_1]] : i32
200-
// CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_29]], %[[CST_2]] : i32
201-
// CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_29]], %[[CST_3]] : i32
202-
// CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_29]], %[[CST_4]] : i32
203-
// CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_29]], %[[CST_5]] : i32
204-
// CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_29]], %[[CST_6]] : i32
205-
// CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_29]], %[[CST_7]] : i32
206-
// CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_29]], %[[CST_16]] : i32
207-
// CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_29]], %[[CST_17]] : i32
208-
// CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_29]], %[[CST_18]] : i32
209-
// CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_29]], %[[CST_19]] : i32
210-
// CHECK: %[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_29]], %[[CST_20]] : i32
211-
// CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_29]], %[[CST_21]] : i32
212-
// CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_29]], %[[CST_22]] : i32
213-
// CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_29]], %[[CST_23]] : i32
190+
// CHECK: %[[VAL_34:.*]] = llvm.xor {{.*}} : i32
191+
// CHECK: %[[VAL_35:.*]] = llvm.xor %[[CST_0]], %[[VAL_34]] : i32
192+
// CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_35]], %[[CST_0]] : i32
193+
// CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_35]], %[[CST_1]] : i32
194+
// CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_35]], %[[CST_2]] : i32
195+
// CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_35]], %[[CST_3]] : i32
196+
// CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_35]], %[[CST_4]] : i32
197+
// CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_35]], %[[CST_5]] : i32
198+
// CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_35]], %[[CST_6]] : i32
199+
// CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_35]], %[[CST_7]] : i32
200+
// CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_35]], %[[CST_16]] : i32
201+
// CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_35]], %[[CST_17]] : i32
202+
// CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_35]], %[[CST_18]] : i32
203+
// CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_35]], %[[CST_19]] : i32
204+
// CHECK: %[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_35]], %[[CST_20]] : i32
205+
// CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_35]], %[[CST_21]] : i32
206+
// CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_35]], %[[CST_22]] : i32
207+
// CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_35]], %[[CST_23]] : i32
214208
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], {{.*}}, {{.*}})
215209
// CHECK: %[[VAL_57:.*]] = llvm.call spir_funccc @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], {{.*}}, {{.*}})
216210
// CHECK: %[[VAL_58:.*]] = llvm.call spir_funccc @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_2]], {{.*}}, {{.*}})

0 commit comments

Comments
 (0)