Skip to content

Commit a6d11f7

Browse files
authored
[Backend][NFC] Expose DecomposeScaledBlocked Pass (#8078)
Exposed and Refactored DecomposeScaledBlocked pass so that it can be used by other backends. Modifications include: - Moved the pass definition to a header file. - Made `cvtDotOperand` a member method to be reused later. - Made `getTransposeOrder` a static method. This is one of a series of PRs to decompose scaled dot on AMD backend.
1 parent abce3c8 commit a6d11f7

File tree

2 files changed

+238
-216
lines changed

2 files changed

+238
-216
lines changed

include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,40 @@
11
#include "mlir/IR/PatternMatch.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
23

34
namespace mlir::triton::gpu {
45

6+
class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
7+
public:
8+
DecomposeScaledBlocked(MLIRContext *context, PatternBenefit benefit)
9+
: OpRewritePattern<DotScaledOp>(context, benefit) {}
10+
11+
LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
12+
PatternRewriter &rewriter) const override;
13+
14+
protected:
15+
FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType,
16+
PatternRewriter &rewriter) const;
17+
TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter,
18+
TypedValue<RankedTensorType> scale,
19+
FloatType computeType) const;
20+
TypedValue<RankedTensorType>
21+
broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
22+
ModuleOp mod, TypedValue<RankedTensorType> scale,
23+
int dim) const;
24+
TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
25+
DotScaledOp scaledDotOp, ModuleOp mod,
26+
TypedValue<RankedTensorType> mxfp,
27+
TypedValue<RankedTensorType> scale,
28+
int dim) const;
29+
TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
30+
DotScaledOp scaledDotOp, int opIdx,
31+
FloatType computeType) const;
32+
TypedValue<RankedTensorType>
33+
cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx,
34+
TypedValue<RankedTensorType> v) const;
35+
static SmallVector<int, 2> getTransposeOrder(int rank);
36+
};
37+
538
void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns,
639
int benefit);
740

0 commit comments

Comments
 (0)