Skip to content

Commit fe79064

Browse files
Merge commit '336cc1d530fe9df8db610e880330b9fa4de82925'
2 parents dc3c13d + 336cc1d commit fe79064

File tree

5 files changed: +64 additions, −52 deletions

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
309309
auto totalStoreCvt = srcLayout.invertAndCompose(smem);
310310
auto totalLoadCvt = dstLayout.invertAndCompose(smem);
311311

312-
// FIXME(Lezcano): The legacy path also creates PRMT, so we should revisit
313-
314312
// The permutation exists by construction of the reps dimension in
315313
// optimalSwizzling
316314
auto permStore =

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151151
auto b = TritonLLVMOpBuilder(loc, rewriter);
152152
assert(layout.getNumInDims() == indices.size());
153153
assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices)));
154+
// Trivial layout
155+
if (layout.getNumOutDims() == 0) {
156+
return {};
157+
}
154158

155159
// This function can emit a lot of MLIR code, which ultimately makes
156160
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -164,62 +168,65 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
164168
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
165169
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
166170
for (auto [inDimName, idx] : indices) {
167-
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
168-
constantIns.push_back(
169-
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
171+
APInt constant;
172+
if (matchPattern(idx, m_ConstantInt(&constant))) {
173+
constantIns.push_back({inDimName, constant.getSExtValue()});
170174
} else {
171175
constantIns.push_back({inDimName, 0});
172176
nonConstantIns.push_back({inDimName, idx});
173177
}
174178
}
175-
SmallVector<int32_t> constantComponent =
176-
llvm::to_vector(llvm::make_second_range(layout.apply(constantIns)));
177179

180+
// Compute constant part of the output and wrap it as values
178181
Value zero = b.i32_val(0);
179182
SmallVector<std::pair<StringAttr, Value>> outIndices;
180-
for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) {
181-
if (constantComponent[i] == 0)
183+
for (auto [outDimName, constant] : layout.apply(constantIns)) {
184+
if (constant == 0)
182185
outIndices.push_back({outDimName, zero});
183186
else
184-
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
187+
outIndices.push_back({outDimName, b.i32_val(constant)});
188+
}
189+
190+
if (nonConstantIns.size() == 0) {
191+
return outIndices;
185192
}
186-
// Happy path: Only one output.
187-
if (outIndices.size() == 1) {
188-
SmallVector<StringAttr> inDimNames;
189-
// Concatenate input
190-
Value x = b.i32_val(0);
193+
194+
// Concatenate input
195+
Value x = b.i32_val(0);
196+
if (nonConstantIns.size() == 1) {
197+
x = nonConstantIns[0].second;
198+
} else {
191199
int shift = 0;
192200
for (auto [inDimName, idx] : nonConstantIns) {
193-
inDimNames.push_back(inDimName);
194201
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
195202
shift += layout.getInDimSizeLog2(inDimName);
196203
}
197-
// Flatten ins
198-
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
199-
matrix = matrix.flattenIns();
200-
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
201-
outIndices[0].second = b.xor_(outIndices[0].second, out);
202-
return outIndices;
203204
}
204205

205-
for (auto [inDimName, idx] : indices) {
206-
if (idx.getDefiningOp<LLVM::ConstantOp>()) {
207-
continue;
208-
}
206+
// Remove constant input dims from the layout and flatten it
207+
auto inDimNames = llvm::to_vector(llvm::make_first_range(nonConstantIns));
208+
auto matrix = layout.sublayout(
209+
inDimNames, llvm::to_vector(llvm::make_first_range(outIndices)));
210+
auto flatMatrix = matrix.flattenIns().flattenOuts();
211+
212+
// Lower the matrix-vector product
213+
auto out = triton::gpu::matrixVectorProd(b, flatMatrix, x);
209214

210-
int nBits = layout.getInDimSizeLog2(inDimName);
211-
for (int i = 0; i < nBits; i++) {
212-
Value bit = b.and_(idx, b.i32_val(1 << i));
213-
Value bit_is_zero = b.icmp_eq(bit, zero);
214-
for (auto &[outDimName, outIdx] : outIndices) {
215-
int32_t basis = layout.getBasis(inDimName, i, outDimName);
216-
if (basis == 0)
217-
continue;
218-
outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis)));
219-
}
215+
// Unpack the output
216+
if (matrix.getNumOutDims() == 1) {
217+
outIndices[0].second = b.xor_(outIndices[0].second, out);
218+
} else {
219+
assert(llvm::equal(matrix.getOutDimNames(),
220+
llvm::make_first_range(outIndices)));
221+
int shift = 0;
222+
for (auto &[dimName, outIdx] : outIndices) {
223+
auto outDimSizeLog2 = layout.getOutDimSizeLog2(dimName);
224+
auto mask = (1 << outDimSizeLog2) - 1;
225+
outIdx = b.xor_(outIdx,
226+
b.and_(b.lshr(out, b.i32_val(shift)), b.i32_val(mask)));
227+
shift += outDimSizeLog2;
220228
}
221229
}
222-
223230
return outIndices;
224231
}
225232

lib/Tools/GenericSwizzling.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,9 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
311311
// Bits in a bank segment: 32 banks x 32 bits
312312
constexpr int32_t bankBits = 32 * 32;
313313
// Bases needed to cover a whole bank segment
314-
const int32_t lenBbasis =
315-
llvm::Log2_32(bankBits / ((1 << vbasis.size()) * bitwidth));
314+
const int32_t lenBbasis = std::min<int32_t>(
315+
llvm::Log2_32(bankBits / ((1 << vbasis.size()) * bitwidth)),
316+
dim - vbasis.size());
316317
// Bases to cover all the tensor
317318
const int32_t lenSbasis = dim - lenBbasis - vbasis.size();
318319

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ def _p_matmul_ogs(
384384
block_shape=[BLOCK_M, OUT_BLOCK_N],
385385
)
386386

387+
# bias + scale
388+
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
389+
mask_n = offs_y_n < N
390+
if B is not None:
391+
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
392+
if pid_k1 == 0:
393+
bias = tl.load(BPtrs, mask=mask_n, other=0)
394+
else:
395+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
396+
else:
397+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
387398
if Betas is not None:
388399
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
389400
else:
@@ -399,15 +410,21 @@ def _p_matmul_ogs(
399410
w_scale = load_scale(WScale)
400411

401412
accs = (acc,)
413+
biases = (bias,)
402414

403415
if SUBTILE_FACTOR >= 2:
404416
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
405417
accs = (acc0, acc1)
418+
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
419+
biases = (bias0, bias1)
406420

407421
if SUBTILE_FACTOR >= 4:
408422
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
409423
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
410424
accs = (acc00, acc01, acc10, acc11)
425+
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
426+
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
427+
biases = (bias00, bias01, bias10, bias11)
411428

412429
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
413430
tl.static_assert(len(accs) == SUBTILE_FACTOR)
@@ -419,18 +436,7 @@ def _p_matmul_ogs(
419436
if SWAP_XW:
420437
acc_tile = acc_tile.T
421438

422-
if B is not None:
423-
offs_y_n = off_n1 + EPILOGUE_BLOCK_N * a_i + tl.arange(0, EPILOGUE_BLOCK_N)
424-
mask_n = offs_y_n < N
425-
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
426-
if pid_k1 == 0:
427-
bias = tl.load(BPtrs, mask=mask_n, other=0)
428-
else:
429-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
430-
else:
431-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
432-
433-
acc_tile = acc_tile + bias[None, :] * betas[:, None]
439+
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
434440
if out_alpha is not None:
435441
acc_tile *= out_alpha
436442

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
12931293
// CHECK-LABEL: linear_layout_with_multiple_iterations
12941294
tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) {
12951295
%cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1>
1296-
// CHECK-COUNT-2: llvm.store {{.*}} : vector<2xi16>
1296+
// CHECK-COUNT-1: llvm.store {{.*}} : vector<4xi16>
12971297
// CHECK: nvvm.barrier0
12981298
// CHECK-COUNT: llvm.load{{.*}}->vector<2xi16>
12991299
tt.return

0 commit comments

Comments (0)