Commit 9a3f308

Merge commit 'ade4d3ac30dbe4a8de9d3da1441160544beb6d79'
2 parents 402d57c + ade4d3a

43 files changed: +3233 −1875 lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 10 additions & 17 deletions

@@ -165,6 +165,16 @@ def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps",
   }];
 }
 
+def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir::ModuleOp"> {
+  let summary = "warp specialization partitioning pass";
+
+  let description = [{
+    The `tritongpu-partition-scheduling` pass analyzes the loads, MMAs, and
+    other operations in a loop that is meant to be warp specialized and
+    determines which partition to assign to each operation.
+  }];
+}
+
 def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "mlir::ModuleOp"> {
   let summary = "load MMA specialization";
 
@@ -219,23 +229,6 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
                            "mlir::arith::ArithDialect"];
 }
 
-def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
-  let summary = "prefetch for wgmma mixed precision";
-
-  let description = [{
-    This pass attempts to prefetch from shared memory for mixed-precision
-    wgmma when operand A is in the shared memory and needs to be loaded
-    to the local registers.
-  }];
-
-  let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
-                            "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
-                            "mlir::scf::SCFDialect",
-                            "mlir::arith::ArithDialect"];
-}
-
-
 def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
   let summary = "accelerate matmul";
 
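For orientation, a hedged sketch of how the new pass might be invoked. With MLIR's TableGen pass machinery, a `Pass` definition like the one added above typically generates a registration hook and a `create*` factory; the factory names, namespaces, and pipeline position below are assumptions, not something this diff confirms:

    // Hedged sketch: run partition scheduling before the passes that
    // consume its partition assignments. The factory names follow MLIR's
    // usual TableGen-generated conventions and are assumptions.
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::triton::gpu::createTritonGPUPartitionScheduling());
    pm.addPass(mlir::triton::gpu::createTritonGPULoadMMASpecialization());
    if (mlir::failed(pm.run(module)))
      return mlir::failure();
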
include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions

@@ -54,6 +54,10 @@ getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
 // Returns whether the op is a "view op", i.e. doesn't move any data
 bool isView(Operation *op);
 
+// Returns whether the op is a "noop op", i.e. has one input and one output
+// and lowers to LLVM as the identity function (returns its input)
+bool isNoop(Operation *op);
+
 /* Dump Triton IR in graphviz dot format.
  *
  * You can override `onValue` and `onOperation` in a subclass to mark
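
One place a predicate like this is useful is when tracing a value back to the operation that actually produces its data. A minimal sketch, assuming `isNoop` is visible via this header; the helper name `skipNoops` is illustrative, not part of this commit:

    // Hedged sketch: walk up the def chain past identity-lowering ops.
    static mlir::Operation *skipNoops(mlir::Value v) {
      while (mlir::Operation *def = v.getDefiningOp()) {
        if (!isNoop(def))
          return def;               // first op that does real work
        v = def->getOperand(0);     // noop ops have exactly one input
      }
      return nullptr;               // reached a block argument
    }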

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 39 deletions

@@ -441,45 +441,7 @@ struct MemDescSubviewOpConversion
     for (int i = rankReduced; i < opOffsetVals.size(); i++) {
       offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
     }
-    Value offset;
-    auto allocShape = srcTy.getAllocShape();
-    auto nvmmaEnc = dyn_cast<NVMMASharedEncodingAttr>(enc);
-    bool isSimpleSubview =
-        (!nvmmaEnc || allocShape.take_back(destRank) == destTy.getShape() ||
-         nvmmaEnc.getSwizzlingByteWidth() == 0);
-    if (!isSimpleSubview) {
-      assert(destRank >= 2 &&
-             "Shape size should be >= 2 when using NVMMAShared encoding");
-      auto swizzleStride = b.i32_val((nvmmaEnc.getSwizzlingByteWidth() * 8) /
-                                     llvmElemTy.getIntOrFloatBitWidth());
-      offset = b.i32_val(0);
-      for (auto i = 0; i < opOffsetVals.size() - 2; ++i) {
-        offset = b.add(offset, b.mul(opOffsetVals[i], opSmemStrides[i]));
-      }
-      // newOffset = offset - (stridedOff * swizzledStride + contigOff /
-      // swizzledStride * tileSize + contigOff % swizzledStride)
-      // + stridedInc * swizzledStride + contigInc / swizzledStride *
-      // tileSize + contigInc % swizzledStride
-      auto stridedDim = destRank - 1 - layoutOrder[0];
-      auto contigDim = destRank - 1 - layoutOrder[1];
-      auto stridedOff = smemObj.getOffsets()[stridedDim];
-      auto contigOff = smemObj.getOffsets()[contigDim];
-      auto stridedInc = offsetVals[stridedDim];
-      auto contigInc = offsetVals[contigDim];
-      int allocStridedDim = allocShape.size() - 1 - layoutOrder[0];
-      auto tileSize =
-          b.mul(b.i32_val(allocShape[allocStridedDim]), swizzleStride);
-      offset = b.sub(offset, b.mul(stridedOff, swizzleStride));
-      offset = b.sub(offset, b.mul(b.udiv(contigOff, swizzleStride), tileSize));
-      offset = b.sub(offset, b.urem(contigOff, swizzleStride));
-      offset = b.add(offset, b.mul(stridedInc, swizzleStride));
-      offset = b.add(offset, b.mul(b.udiv(contigInc, swizzleStride), tileSize));
-      offset = b.add(offset, b.urem(contigInc, swizzleStride));
-    } else {
-      // Compute the offset based on the original strides of the shared memory
-      // object
-      offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
-    }
+    Value offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
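
The surviving code path computes a plain linearized offset: the dot product of per-dimension offsets and strides, as the deleted `else` branch's comment already described. A standalone C++ analogue of that arithmetic (the real `dot` helper emits the equivalent LLVM IR through the rewriter rather than computing a constant; `linearOffset` is an illustrative name, not from this commit):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hedged sketch of the linear-offset formula the simplified subview
    // lowering relies on: offset = sum over i of offsets[i] * strides[i].
    int64_t linearOffset(const std::vector<int64_t> &offsets,
                         const std::vector<int64_t> &strides) {
      assert(offsets.size() == strides.size());
      int64_t off = 0;
      for (size_t i = 0; i < offsets.size(); ++i)
        off += offsets[i] * strides[i];
      return off;
    }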

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,6 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/PipeliningUtility.cpp
   Pipeliner/Schedule.cpp
   Prefetch.cpp
-  WGMMAPrefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
   CoalesceAsyncCopy.cpp
@@ -35,6 +34,7 @@ add_triton_library(TritonGPUTransforms
   WarpSpecialization/Partition.cpp
   WarpSpecialization/OptimizePartitionWarps.cpp
   WarpSpecialization/PartitionLoops.cpp
+  WarpSpecialization/PartitionScheduling.cpp
   WarpSpecialization/RewritePartitionDependencies.cpp
 
   DEPENDS
