Skip to content

Commit 977d370

Browse files
committed
Merge commit 'a6d11f755d6c01c6ef369af5d4b4f3447522313b'
2 parents 3752705 + a6d11f7 commit 977d370

File tree

8 files changed

+261
-287
lines changed

8 files changed

+261
-287
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ python/triton/backends/*
3535
!python/triton/backends/driver.py
3636

3737
# Language extras
38-
python/triton/language/extra
38+
python/triton/language/extra/*
39+
!python/triton/language/extra/__init__.py
40+
!python/triton/language/extra/libdevice.py
3941

4042
# Tools extras
4143
python/triton/tools/extra

include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,40 @@
11
#include "mlir/IR/PatternMatch.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
23

34
namespace mlir::triton::gpu {
45

6+
// Rewrite pattern anchored on DotScaledOp. This header only declares the
// pattern; every member except the constructor is defined out of line.
// NOTE(review): the helper steps are `protected` rather than `private`,
// presumably so that derived patterns can reuse them -- confirm intent.
class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
7+
public:
8+
// Forwards `context` and `benefit` unchanged to the
// OpRewritePattern<DotScaledOp> base; the class adds no state of its own here.
DecomposeScaledBlocked(MLIRContext *context, PatternBenefit benefit)
9+
: OpRewritePattern<DotScaledOp>(context, benefit) {}
10+
11+
// Pattern entry point: overrides the base-class hook for a matched
// DotScaledOp. The rewrite logic itself is not visible in this header.
LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
12+
PatternRewriter &rewriter) const override;
13+
14+
protected:
15+
// Helper declarations. NOTE(review): the one-line descriptions below are
// inferred from names and signatures only -- verify against the definitions.
//
// Takes the element-type enums of both operands and returns a FloatType;
// presumably the type the decomposed dot is computed in -- TODO confirm.
FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType,
16+
PatternRewriter &rewriter) const;
17+
// Takes a scale tensor plus the compute type and returns a ranked tensor
// value; presumably converts the scale to a 16-bit float -- TODO confirm.
TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter,
18+
TypedValue<RankedTensorType> scale,
19+
FloatType computeType) const;
20+
// Takes a scale tensor and a dimension index `dim`, returns a ranked tensor
// value; presumably the scale broadcast along `dim` -- TODO confirm.
TypedValue<RankedTensorType>
21+
broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
22+
ModuleOp mod, TypedValue<RankedTensorType> scale,
23+
int dim) const;
24+
// Takes an `mxfp` tensor and its `scale`, returns a ranked tensor value;
// presumably applies NaN handling driven by the scale -- TODO confirm.
TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
25+
DotScaledOp scaledDotOp, ModuleOp mod,
26+
TypedValue<RankedTensorType> mxfp,
27+
TypedValue<RankedTensorType> scale,
28+
int dim) const;
29+
// Selects an operand of `scaledDotOp` by `opIdx` and returns a ranked tensor
// value; presumably that operand with its scale applied -- TODO confirm.
TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
30+
DotScaledOp scaledDotOp, int opIdx,
31+
FloatType computeType) const;
32+
// Takes a tensor value `v` tied to operand index `opIdx` and returns a ranked
// tensor value; presumably a layout conversion for use as a dot operand --
// TODO confirm.
TypedValue<RankedTensorType>
33+
cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx,
34+
TypedValue<RankedTensorType> v) const;
35+
// Static helper (needs no pattern instance): maps a rank to a small vector of
// ints; presumably a dimension permutation -- TODO confirm.
static SmallVector<int, 2> getTransposeOrder(int rank);
36+
};
37+
538
// Declaration only; the definition is not in this header.
// NOTE(review): presumably adds the scaled-dot decomposition pattern(s) to
// `patterns` with the given `benefit` -- confirm against the implementation.
// `benefit` is a plain int here, whereas the pattern constructor takes a
// PatternBenefit.
void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns,
639
int benefit);
740

0 commit comments

Comments
 (0)