Commit f631f67

Merge OpenAI Triton commit 625c8cb (#5131)
This PR changes the Triton base from 2ad519c to 625c8cb (Sep 10). Pass rate: 98.8%.
2 parents ac96e7f + 57354f1 commit f631f67

File tree

21 files changed (+362, -81 lines)


include/triton/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -24,3 +24,8 @@ set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
 mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TritonGPUOpInterfaces.td)
+mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(TritonGPUOpInterfacesIncGen)

include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h

Lines changed: 3 additions & 0 deletions
@@ -1,8 +1,11 @@
 #ifndef TRITON_GPU_DIALECT_INTERFACES_H
 #define TRITON_GPU_DIALECT_INTERFACES_H

+#include "mlir/IR/OpDefinition.h"
+
 // clang-format off
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc"
 #include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc"
 // clang-format on

include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+#ifndef TRITONGPU_OP_INTERFACES
+#define TRITONGPU_OP_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
+  let description = [{
+    This interface is for operations that upcast floating-point numbers.
+  }];
+
+  let cppNamespace = "::mlir::triton::gpu";
+
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/"Infer destination encoding",
+        /*retType=*/"mlir::Attribute",
+        /*methodName=*/"inferDstEncoding",
+        /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$srcEnc)
+        >,
+      InterfaceMethod<
+        /*desc=*/"Infer operand encoding from dst encoding",
+        /*retType=*/"mlir::Attribute",
+        /*methodName=*/"inferSrcEncoding",
+        /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$dstEnc)
+        >
+  ];
+}
+
+#endif // TRITONGPU_OP_INTERFACES
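For orientation: an op normally opts into an interface like this via DeclareOpInterfaceMethods in its own TableGen definition and then defines the methods in C++. The sketch below is hypothetical (MyUpcastOp is an illustrative name, not an op in this PR); the simplest possible policy keeps the encoding unchanged in both directions.

// Hypothetical .td wiring (illustrative op name, not part of this PR):
//   def TTG_MyUpcastOp : TTG_Op<"my_upcast",
//       [DeclareOpInterfaceMethods<UpcastFpOpInterface>]> { ... }
//
// Corresponding C++ definitions of the two interface methods:
Attribute MyUpcastOp::inferDstEncoding(unsigned opIdx, Attribute srcEnc) {
  // Simplest policy: the result keeps operand opIdx's encoding. Real upcast
  // ops (e.g. fp4 -> fp16) may instead remap the layout along the packed dim.
  return srcEnc;
}

Attribute MyUpcastOp::inferSrcEncoding(unsigned opIdx, Attribute dstEnc) {
  // Inverse direction: recover operand opIdx's encoding from the result's.
  return dstEnc;
}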

include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 10 additions & 4 deletions
@@ -22,16 +22,22 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
                ModuleOp mod, TypedValue<RankedTensorType> scale,
                int dim) const;
   TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
-                                       DotScaledOp scaledDotOp, ModuleOp mod,
+                                       DotScaledOp scaledDotOp,
                                        TypedValue<RankedTensorType> mxfp,
                                        TypedValue<RankedTensorType> scale,
                                        int dim) const;
-  TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
-                                        DotScaledOp scaledDotOp, int opIdx,
-                                        FloatType computeType) const;
+  virtual TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
+                                                DotScaledOp scaledDotOp,
+                                                int opIdx,
+                                                FloatType computeType) const;
   TypedValue<RankedTensorType>
   cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx,
                 TypedValue<RankedTensorType> v) const;
+  TypedValue<RankedTensorType>
+  extendAndBroadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
+                          TypedValue<RankedTensorType> &scale,
+                          FloatType computeType, RankedTensorType dstType,
+                          int opIdx) const;
   static SmallVector<int, 2> getTransposeOrder(int rank);
 };

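Making scaleArg virtual opens the decomposition for per-backend specialization. A minimal sketch of such an override, assuming the hook is accessible to subclasses (MyBackendDecompose is an illustrative name, not part of this PR):

class MyBackendDecompose : public DecomposeScaledBlocked {
public:
  using DecomposeScaledBlocked::DecomposeScaledBlocked;

  TypedValue<RankedTensorType>
  scaleArg(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx,
           FloatType computeType) const override {
    // A backend could special-case one operand here (e.g. pick its own
    // layout for the scale) and delegate everything else to the default.
    return DecomposeScaledBlocked::scaleArg(rewriter, scaledDotOp, opIdx,
                                            computeType);
  }
};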

lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_triton_library(TritonAnalysis
   TritonGPUTableGen
   TritonGPUAttrDefsIncGen
   TritonGPUTypeInterfacesIncGen
+  TritonGPUOpInterfacesIncGen

   LINK_LIBS PUBLIC
   MLIRAnalysis

lib/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_triton_library(TritonGPUIR
   TritonGPUAttrDefsIncGen
   TritonGPUTypeInterfacesIncGen
   TritonIntelGPUAttrDefsIncGen
+  TritonGPUOpInterfacesIncGen

   LINK_LIBS PUBLIC
   MLIRGPUDialect

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@

 // Include TableGen'erated code
 #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
+#include "triton/Dialect/TritonGPU/IR/OpInterfaces.cpp.inc"
 #include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc"

 using namespace mlir;

lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp

Lines changed: 33 additions & 17 deletions
@@ -135,11 +135,16 @@ TypedValue<RankedTensorType> DecomposeScaledBlocked::broadcastScale(
 }

 TypedValue<RankedTensorType> DecomposeScaledBlocked::maskNan(
-    PatternRewriter &rewriter, DotScaledOp scaledDotOp, ModuleOp mod,
+    PatternRewriter &rewriter, DotScaledOp scaledDotOp,
     TypedValue<RankedTensorType> mxfp, TypedValue<RankedTensorType> scale,
     int dim) const {
+  // Skip NaN checks if fastMath
+  if (scaledDotOp.getFastMath())
+    return mxfp;
+
   // Implement tl.where(scale == 0xFF, float("nan"), mxfp)
   auto loc = scale.getLoc();
+  auto mod = scaledDotOp->getParentOfType<ModuleOp>();

   // Scale is NaN
   auto scaleTy = scale.getType();
@@ -180,7 +185,6 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   auto fastMath = scaledDotOp.getFastMath();

   auto loc = v.getLoc();
-  auto mod = scaledDotOp->getParentOfType<ModuleOp>();
   auto rank = v.getType().getRank();
   auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

@@ -196,9 +200,33 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   if (!scale)
     return v;

+  // 1) Cast scale to fp16/bf16, broadcast it and convert its layout
+  auto reshapeScale = extendAndBroadcastScale(rewriter, scaledDotOp, scale,
+                                              computeType, v.getType(), opIdx);
+
+  // 2) Multiply
+  auto mxfp = cast<TypedValue<RankedTensorType>>(
+      rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());
+
+  // 3) If the scale is NaN, return NaN, else return the scaled value.
+  return maskNan(rewriter, scaledDotOp, mxfp, scale, kDim);
+}
+
+TypedValue<RankedTensorType> DecomposeScaledBlocked::extendAndBroadcastScale(
+    PatternRewriter &rewriter, DotScaledOp scaledDotOp,
+    TypedValue<RankedTensorType> &scale, FloatType computeType,
+    RankedTensorType dstType, int opIdx) const {
+  auto loc = scale.getLoc();
+  auto mod = scaledDotOp->getParentOfType<ModuleOp>();
+  auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
+  auto rank = v.getType().getRank();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+
   // For some weird reason, we take the scale with shape as if it were coming
   // from the lhs even when it's the rhs. In a normal world, we should accept
-  // this parametre transposed, as we do with the mxfp.
+  // this parameter transposed, as we do with the mxfp.
+  //
+  // Notice: this is an inplace change.
   if (opIdx == 1) {
     auto order = getTransposeOrder(rank);
     scale = rewriter.create<TransOp>(loc, scale, order);
@@ -207,21 +235,9 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   // 1) Cast scale to compute type (fp16/bf16)
   auto scale16 = scaleTo16(rewriter, scale, computeType);

-  // 2) Broadcast scale to the same shape and layout as v
+  // 2) Broadcast scale to the same shape as v and convert the layout
   auto reshapeScale = broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim);
-  reshapeScale =
-      rewriter.create<ConvertLayoutOp>(loc, v.getType(), reshapeScale);
-
-  // 3) Multiply
-  auto mxfp = cast<TypedValue<RankedTensorType>>(
-      rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());
-
-  // Skip NaN checks if fastMath
-  if (fastMath)
-    return mxfp;
-
-  // 4) If the scale is NaN, return NaN, else return the scaled value.
-  return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, kDim);
+  return rewriter.create<ConvertLayoutOp>(loc, dstType, reshapeScale);
 }

 TypedValue<RankedTensorType>
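The tl.where(scale == 0xFF, float("nan"), mxfp) behavior that maskNan implements amounts to a compare-and-select. A minimal builder-level sketch of that pattern, assuming splat constants (variable and helper names are illustrative, not the PR's exact code):

// scaleI8: the i8 scale tensor; mxfp: the scaled value tensor.
static Value selectNanWhereScaleIsNan(PatternRewriter &rewriter, Location loc,
                                      Value scaleI8, Value mxfp) {
  auto scaleTy = cast<RankedTensorType>(scaleI8.getType());
  auto mxfpTy = cast<RankedTensorType>(mxfp.getType());
  // 0xFF is the NaN encoding for the scale.
  Value cFF = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(scaleTy, APInt(8, 0xFF)));
  Value isNan = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                               scaleI8, cFF);
  auto elemTy = cast<FloatType>(mxfpTy.getElementType());
  Value nan = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(
               mxfpTy, APFloat::getNaN(elemTy.getFloatSemantics())));
  // tl.where(scale == 0xFF, NaN, mxfp)
  return rewriter.create<arith::SelectOp>(loc, isNan, nan, mxfp);
}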

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 4 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -1309,7 +1310,9 @@ void LayoutRematerialization::hoistConvertDotOperand(
   // threads We do views and elementwise pure ops for now
   auto noDataMovement = [](Operation *op) {
     return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
-           isa<BroadcastOp, Fp4ToFpOp, ConvertLayoutOp>(op) || isView(op);
+           isa<BroadcastOp, Fp4ToFpOp, ConvertLayoutOp, UpcastFpOpInterface>(
+               op) ||
+           isView(op);
   };
   // Stop the slice as soon as we find an operation that cannot be done without
   // data movement between threads

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 14 additions & 1 deletion
@@ -525,6 +525,10 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding) {
     if (!isa<triton::gpu::BlockedEncodingAttr>(encoding))
       return {};
   }
+
+  if (isa<triton::gpu::UpcastFpOpInterface>(op))
+    return {};
+
   if (op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() ||
       op->hasTrait<mlir::OpTrait::SameLoadStoreOperandsAndResultEncoding>() ||
       op->hasTrait<mlir::OpTrait::Elementwise>() ||
@@ -558,6 +562,9 @@ Attribute inferDstEncoding(Operation *op, Attribute encoding) {
     if (!isa<triton::gpu::BlockedEncodingAttr>(encoding))
       return {};
   }
+  if (isa<triton::gpu::UpcastFpOpInterface>(op))
+    return {};
+
   if (op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() ||
       op->hasTrait<mlir::OpTrait::SameLoadStoreOperandsAndResultEncoding>() ||
       op->hasTrait<mlir::OpTrait::Elementwise>() ||
@@ -938,7 +945,13 @@ LogicalResult getConvertBackwardSlice(
       continue;
     }
     for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
-      auto srcEncoding = inferSrcEncoding(definingOp, encoding);
+      Attribute srcEncoding;
+      if (auto upcast =
+              dyn_cast<triton::gpu::UpcastFpOpInterface>(definingOp)) {
+        srcEncoding = upcast.inferSrcEncoding(i, encoding);
+      } else {
+        srcEncoding = inferSrcEncoding(definingOp, encoding);
+      }
       if (!srcEncoding)
         return failure();
       // If the infered layout matches the original one we don't need to keep
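In effect, ops implementing UpcastFpOpInterface now answer encoding queries themselves, while the generic inferSrcEncoding/inferDstEncoding bail out with {} for them. A minimal sketch of the resulting dispatch, mirroring the new getConvertBackwardSlice logic (inferOperandEncoding is an illustrative helper name, not part of the PR):

static Attribute inferOperandEncoding(Operation *definingOp,
                                      unsigned operandIdx,
                                      Attribute resultEncoding) {
  // Upcast ops know their per-operand encoding relationship, so query the
  // interface directly instead of the trait-based generic inference.
  if (auto upcast = dyn_cast<triton::gpu::UpcastFpOpInterface>(definingOp))
    return upcast.inferSrcEncoding(operandIdx, resultEncoding);
  return inferSrcEncoding(definingOp, resultEncoding);
}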
