Commit 7dc5492

[PIPELINER] Pipeline RS WGMMA (#6804) (#6812)
This PR allows pipelining WGMMAs that take the lhs in registers. The strategy is to wait for the WGMMA from the previous loop iteration to finish before executing the next one, so that its registers are not overwritten too early. Note that this depends on ptxas handling the register allocation correctly.

This PR also includes:
- A fix for WGMMAPrefetch with `swizzlingByteWidth = 128`, which produced wrong results
- A fix in the way we lower `memdesc_subview` (now it's simpler again)
- A fix in the way we lower mmav3 and mmav5 (which renders the complex path in `memdesc_subview` unnecessary)

All of these are tested end-to-end via the improved `test_cast_matmul.py`.

In an 8k x 8k x 8k dense bf16 x mxfp4 matmul we get a speedup of 2.441 -> 2.039.

We might need to split the pointwise computations and interleave them with the WGMMAs, similar to how CUTLASS does it, but we don't do that in this PR. This PR supersedes WGMMAPrefetch, as it drops most of that pass's preconditions.
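For intuition, the resulting schedule looks roughly like the sketch below. The `wgmma_*` helpers are hypothetical stand-ins for the PTX `wgmma.commit_group` / `wgmma.wait_group` primitives, and the whole thing is a scalar model of the ordering only, not the Triton implementation:

```cpp
#include <cstdio>

// Hypothetical stand-ins: commit closes the current group of async WGMMAs;
// wait blocks until at most `pending` groups are still in flight.
void wgmma_commit_group() {}
void wgmma_wait_group(int pending) { (void)pending; }

void upcast_lhs_into_registers(int i) { std::printf("upcast lhs %d\n", i); }
void issue_wgmma(int i) { std::printf("wgmma %d\n", i); }

int main() {
  const int kIters = 4;
  for (int i = 0; i < kIters; ++i) {
    // Work issued before this wait (shared-memory loads, address math)
    // overlaps with iteration i-1's in-flight WGMMA.
    wgmma_wait_group(/*pending=*/0); // previous WGMMA has retired, ...
    upcast_lhs_into_registers(i); // ... so its lhs registers may be overwritten
    issue_wgmma(i);               // reads the freshly written registers async
    wgmma_commit_group();
  }
  wgmma_wait_group(0); // drain before the epilogue reads the accumulator
  return 0;
}
```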
1 parent a3f5ea6 commit 7dc5492

File tree

15 files changed: +455 -1190 lines

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 17 deletions
@@ -219,23 +219,6 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
                             "mlir::arith::ArithDialect"];
 }
 
-def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
-  let summary = "prefetch for wgmma mixed precision";
-
-  let description = [{
-    This pass attempts to prefetch from shared memory for mixed-precision
-    wgmma when operand A is in the shared memory and needs to be loaded
-    to the local registers.
-  }];
-
-  let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
-                            "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
-                            "mlir::scf::SCFDialect",
-                            "mlir::arith::ArithDialect"];
-}
-
-
-
 def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
   let summary = "accelerate matmul";

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
 // Returns whether the op is a "view op", i.e. doesn't move any data
 bool isView(Operation *op);
 
+// Returns whether the op is a "noop op", i.e. has one input and one output
+// and lowers to llvm as the identity function (returns the input)
+bool isNoop(Operation *op);
+
 /* Dump Triton IR in graphviz dot format.
  *
  * You can override `onValue` and `onOperation` in a subclass to mark
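
As a hypothetical usage sketch (the helper below is not part of this PR), such a predicate lets analyses trace a value's producer straight through noop ops, since each one just forwards its single input:

```cpp
#include "mlir/IR/Operation.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

// Hypothetical helper, assumed to live in the same namespace as isNoop:
// walk a value's def chain through ops that lower to the identity.
static mlir::Value skipNoops(mlir::Value v) {
  while (mlir::Operation *def = v.getDefiningOp()) {
    if (!isNoop(def))
      break;
    v = def->getOperand(0); // a noop has one input and simply forwards it
  }
  return v;
}
```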

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 39 deletions
@@ -441,45 +441,7 @@ struct MemDescSubviewOpConversion
     for (int i = rankReduced; i < opOffsetVals.size(); i++) {
       offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
     }
-    Value offset;
-    auto allocShape = srcTy.getAllocShape();
-    auto nvmmaEnc = dyn_cast<NVMMASharedEncodingAttr>(enc);
-    bool isSimpleSubview =
-        (!nvmmaEnc || allocShape.take_back(destRank) == destTy.getShape() ||
-         nvmmaEnc.getSwizzlingByteWidth() == 0);
-    if (!isSimpleSubview) {
-      assert(destRank >= 2 &&
-             "Shape size should be >= 2 when using NVMMAShared encoding");
-      auto swizzleStride = b.i32_val((nvmmaEnc.getSwizzlingByteWidth() * 8) /
-                                     llvmElemTy.getIntOrFloatBitWidth());
-      offset = b.i32_val(0);
-      for (auto i = 0; i < opOffsetVals.size() - 2; ++i) {
-        offset = b.add(offset, b.mul(opOffsetVals[i], opSmemStrides[i]));
-      }
-      // newOffset = offset - (stridedOff * swizzledStride + contigOff /
-      // swizzledStride * tileSize + contigOff % swizzledStride)
-      // + stridedInc * swizzledStride + contigInc / swizzledStride *
-      // tileSize + contigInc % swizzledStride
-      auto stridedDim = destRank - 1 - layoutOrder[0];
-      auto contigDim = destRank - 1 - layoutOrder[1];
-      auto stridedOff = smemObj.getOffsets()[stridedDim];
-      auto contigOff = smemObj.getOffsets()[contigDim];
-      auto stridedInc = offsetVals[stridedDim];
-      auto contigInc = offsetVals[contigDim];
-      int allocStridedDim = allocShape.size() - 1 - layoutOrder[0];
-      auto tileSize =
-          b.mul(b.i32_val(allocShape[allocStridedDim]), swizzleStride);
-      offset = b.sub(offset, b.mul(stridedOff, swizzleStride));
-      offset = b.sub(offset, b.mul(b.udiv(contigOff, swizzleStride), tileSize));
-      offset = b.sub(offset, b.urem(contigOff, swizzleStride));
-      offset = b.add(offset, b.mul(stridedInc, swizzleStride));
-      offset = b.add(offset, b.mul(b.udiv(contigInc, swizzleStride), tileSize));
-      offset = b.add(offset, b.urem(contigInc, swizzleStride));
-    } else {
-      // Compute the offset based on the original strides of the shared memory
-      // object
-      offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
-    }
+    Value offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
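
The surviving branch computes the subview offset as a plain dot product of the per-dimension offsets with the shared-memory strides. A scalar model of that computation in plain C++ (illustrative only, not the actual LLVM-IR emission):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of dot(offsets, strides): the element offset of the subview
// is the sum over dimensions of offset[d] * stride[d].
int64_t subviewOffset(const std::vector<int64_t> &offsets,
                      const std::vector<int64_t> &strides) {
  assert(offsets.size() == strides.size());
  int64_t off = 0;
  for (size_t d = 0; d < offsets.size(); ++d)
    off += offsets[d] * strides[d];
  return off;
}

int main() {
  // Selecting buffer 2 of a [4, 64, 64] multi-buffered allocation with
  // row-major strides {64 * 64, 64, 1} starts at element 2 * 4096 = 8192.
  assert(subviewOffset({2, 0, 0}, {64 * 64, 64, 1}) == 8192);
  return 0;
}
```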

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/PipeliningUtility.cpp
   Pipeliner/Schedule.cpp
   Prefetch.cpp
-  WGMMAPrefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
   CoalesceAsyncCopy.cpp
