|
1 | 1 | #include "mlir/IR/PatternMatch.h" |
| 2 | +#include "triton/Dialect/Triton/IR/Dialect.h" |
2 | 3 |
|
3 | 4 | namespace mlir::triton::gpu { |
4 | 5 |
|
| 6 | +class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> { |
| 7 | +public: |
| 8 | + DecomposeScaledBlocked(MLIRContext *context, PatternBenefit benefit) |
| 9 | + : OpRewritePattern<DotScaledOp>(context, benefit) {} |
| 10 | + |
| 11 | + LogicalResult matchAndRewrite(DotScaledOp scaledDotOp, |
| 12 | + PatternRewriter &rewriter) const override; |
| 13 | + |
| 14 | +protected: |
| 15 | + FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType, |
| 16 | + PatternRewriter &rewriter) const; |
| 17 | + TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter, |
| 18 | + TypedValue<RankedTensorType> scale, |
| 19 | + FloatType computeType) const; |
| 20 | + TypedValue<RankedTensorType> |
| 21 | + broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp, |
| 22 | + ModuleOp mod, TypedValue<RankedTensorType> scale, |
| 23 | + int dim) const; |
| 24 | + TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter, |
| 25 | + DotScaledOp scaledDotOp, ModuleOp mod, |
| 26 | + TypedValue<RankedTensorType> mxfp, |
| 27 | + TypedValue<RankedTensorType> scale, |
| 28 | + int dim) const; |
| 29 | + TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter, |
| 30 | + DotScaledOp scaledDotOp, int opIdx, |
| 31 | + FloatType computeType) const; |
| 32 | + TypedValue<RankedTensorType> |
| 33 | + cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx, |
| 34 | + TypedValue<RankedTensorType> v) const; |
| 35 | + static SmallVector<int, 2> getTransposeOrder(int rank); |
| 36 | +}; |
| 37 | + |
5 | 38 | void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns, |
6 | 39 | int benefit); |
7 | 40 |
|
|
0 commit comments