| 1 | +#include "mlir/Analysis/SliceAnalysis.h" |
1 | 2 | #include "mlir/IR/TypeUtilities.h" |
2 | 3 | #include "mlir/Pass/PassManager.h" |
3 | 4 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
@@ -181,27 +182,89 @@ class TMemSplitLoadPattern : public OpRewritePattern<tt::SplitOp> { |
181 | 182 | } |
182 | 183 | }; |
183 | 184 |
184 | | -class TritonNvidiaGPUOptimizeTMemSubtilingPass |
185 | | - : public TritonNvidiaGPUOptimizeTMemSubtilingPassBase< |
186 | | - TritonNvidiaGPUOptimizeTMemSubtilingPass> { |
| 185 | +// Pick an optimized tmem load layout based on its users. When there are |
| 186 | +// multiple warpgroups, tmem_load results can be distributed along M or N
| 187 | +// across the warpgroups. By default we distribute along N, but when there is
| 188 | +// a reduction along the N dimension we distribute along M instead to avoid
| 189 | +// having to reduce across warps.
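| | +//
| | +// Illustration (assumed shape): for a 128xN accumulator loaded by 8 warps,
| | +// the default layout gives each warpgroup half of N, so a tt.reduce along
| | +// axis 1 has to combine partial results across warps; splitting along M
| | +// instead gives each warp complete rows, keeping the reduction warp-local.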
| 190 | +class TMemLoadReducePattern : public OpRewritePattern<ttng::TMEMLoadOp> { |
187 | 191 | public: |
188 | | - using BaseT = TritonNvidiaGPUOptimizeTMemSubtilingPassBase< |
189 | | - TritonNvidiaGPUOptimizeTMemSubtilingPass>; |
| 192 | + using OpRewritePattern::OpRewritePattern; |
| 193 | + |
| 194 | + LogicalResult matchAndRewrite(ttng::TMEMLoadOp tmemLoadOp, |
| 195 | + PatternRewriter &rewriter) const override { |
| 196 | + int numWarps = ttg::lookupNumWarps(tmemLoadOp); |
| 197 | +    // If there is only one warpgroup, there is nothing to optimize, as the
| 198 | +    // layout is already reduction friendly.
| 199 | + if (numWarps != 8) |
| 200 | + return failure(); |
| 201 | + auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>( |
| 202 | + tmemLoadOp.getSrc().getType().getEncoding()); |
| 203 | + if (!tmemEnc) |
| 204 | + return failure(); |
| 205 | + int M = tmemEnc.getBlockM(); |
| 206 | + int N = tmemEnc.getBlockN(); |
| 207 | + if (M != 128) |
| 208 | + return failure(); |
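| | +    // Walk the forward slice of the load through layout-preserving ops
| | +    // (convert_layout and elementwise) and record whether the result feeds
| | +    // a reduction along N (axis 1).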
| 209 | + bool foundReductionAlongN = false; |
| 210 | + auto filter = [&](Operation *op) { |
| 211 | + if (isa<ttg::ConvertLayoutOp>(op) || op->hasTrait<OpTrait::Elementwise>()) |
| 212 | + return true; |
| 213 | + if (auto reduce = dyn_cast<triton::ReduceOp>(op)) { |
| 214 | + foundReductionAlongN = reduce.getAxis() == 1; |
| 215 | + } |
| 216 | + return false; |
| 217 | + }; |
| 218 | + ForwardSliceOptions fwdOpt; |
| 219 | + fwdOpt.filter = filter; |
| 220 | + SetVector<mlir::Operation *> fwdSlices; |
| 221 | + getForwardSlice(tmemLoadOp.getResult(), &fwdSlices, fwdOpt); |
| 222 | + if (!foundReductionAlongN) |
| 223 | + return failure(); |
| 224 | +    // Try to split along the M dimension, following the restrictions of TMEM:
| 225 | +    // warp 0 gets M = 0, warp 1 gets M = 32, warp 2 gets M = 64, warp 3 gets
| 226 | +    // M = 96, warp 4 gets M = 16, warp 5 gets M = 48, warp 6 gets M = 80,
| 227 | +    // warp 7 gets M = 112.
| 228 | + RankedTensorType oldType = tmemLoadOp.getType(); |
| 229 | + Attribute newLayout = ttg::LinearEncodingAttr::get( |
| 230 | + tmemLoadOp.getContext(), |
| 231 | + ttg::getTmemLoadLayoutSplitLongM(M, N, oldType, numWarps)); |
| 232 | + if (newLayout == oldType.getEncoding()) |
| 233 | + return failure(); |
| 234 | + |
| 235 | + auto newType = RankedTensorType::get(oldType.getShape(), |
| 236 | + oldType.getElementType(), newLayout); |
| 237 | + tmemLoadOp.getResult().setType(newType); |
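| | +    // Convert back to the original layout right after the load so existing
| | +    // users are left untouched; only the tmem_load itself switches to the
| | +    // reduction-friendly layout.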
| 238 | + OpBuilder builder(tmemLoadOp); |
| 239 | + builder.setInsertionPointAfter(tmemLoadOp); |
| 240 | + auto cvt = builder.create<ttg::ConvertLayoutOp>( |
| 241 | + tmemLoadOp.getLoc(), oldType, tmemLoadOp.getResult()); |
| 242 | + tmemLoadOp.getResult().replaceAllUsesExcept(cvt.getResult(), cvt); |
| 243 | + return success(); |
| 244 | + } |
| 245 | +}; |
| 246 | + |
| 247 | +class TritonNvidiaGPUOptimizeTMemLayoutsPass |
| 248 | + : public TritonNvidiaGPUOptimizeTMemLayoutsPassBase< |
| 249 | + TritonNvidiaGPUOptimizeTMemLayoutsPass> { |
| 250 | +public: |
| 251 | + using BaseT = TritonNvidiaGPUOptimizeTMemLayoutsPassBase< |
| 252 | + TritonNvidiaGPUOptimizeTMemLayoutsPass>; |
190 | 253 | using BaseT::BaseT; |
191 | 254 |
192 | 255 | void runOnOperation() override { |
193 | 256 | MLIRContext *context = &getContext(); |
194 | 257 | ModuleOp m = getOperation(); |
195 | 258 |
196 | 259 | mlir::RewritePatternSet patterns(context); |
197 | | - patterns.add<TMemSplitLoadPattern>(context); |
| 260 | + patterns.add<TMemSplitLoadPattern, TMemLoadReducePattern>(context); |
198 | 261 | if (failed(applyPatternsGreedily(m, std::move(patterns)))) |
199 | 262 | signalPassFailure(); |
200 | 263 | } |
201 | 264 | }; |
202 | 265 |
203 | 266 | } // namespace |
204 | 267 |
205 | | -std::unique_ptr<Pass> mlir::createTritonNvidiaGPUOptimizeTMemSubtilingPass() { |
206 | | - return std::make_unique<TritonNvidiaGPUOptimizeTMemSubtilingPass>(); |
| 268 | +std::unique_ptr<Pass> mlir::createTritonNvidiaGPUOptimizeTMemLayoutsPass() { |
| 269 | + return std::make_unique<TritonNvidiaGPUOptimizeTMemLayoutsPass>(); |
207 | 270 | } |