Skip to content

Commit e4035e9

Browse files
Merge OpenAI Triton commit 63cecbd (#3550)
This PR changes the Triton base from 72193bb to 63cecbd (Feb 24). Pass rate: 97.65%. Please do not squash and merge this PR.
2 parents b7604a9 + db7a20d commit e4035e9

File tree

9 files changed

+127
-121
lines changed

9 files changed

+127
-121
lines changed

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,47 +7,54 @@ include "mlir/IR/OpBase.td"
77
def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
88
let description = [{
99
This interface is implemented by operations that perform a transpose.
10-
It provides methods to access common properties such as the order attribute and the source operand.
10+
It provides methods to access common properties such as the order attribute
11+
and the source operand.
1112
}];
1213

1314
let cppNamespace = "::mlir::triton";
1415

1516
let methods = [
1617
InterfaceMethod<
17-
/*desc=*/[{
18-
Get the source operand of the transposition.
19-
}],
20-
/*retType=*/"::mlir::Value",
21-
/*methodName=*/"getSrc",
22-
/*args=*/(ins)>,
18+
/*desc=*/"Get the source operand of the transposition.",
19+
/*retType=*/"::mlir::Value",
20+
/*methodName=*/"getSrc",
21+
/*args=*/(ins)>,
2322
InterfaceMethod<
24-
/*desc=*/[{
25-
Get the order of the transposition.
26-
}],
27-
/*retType=*/"::mlir::ArrayRef<int32_t>",
28-
/*methodName=*/"getOrder",
29-
/*args=*/(ins)>
23+
/*desc=*/"Get the order of the transposition.",
24+
/*retType=*/"::mlir::ArrayRef<int32_t>",
25+
/*methodName=*/"getOrder",
26+
/*args=*/(ins)>
3027
];
3128

32-
let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
29+
let verify = [{
30+
return ::mlir::triton::impl::verifyTransposeOpInterface($_op);
31+
}];
3332
}
3433

3534
def DotOpInterface : OpInterface<"DotOpInterface"> {
3635
let description = [{
37-
This interface is implemented by operations that perform a dot product.
36+
This interface is implemented by operations that perform a dot product.
3837
}];
3938

4039
let cppNamespace = "::mlir::triton";
4140

4241
let methods = [
43-
InterfaceMethod<
44-
/*desc=*/[{
45-
Verifies the dimensions of the A and B DotOp operands.
46-
}],
47-
/*retType=*/"bool",
48-
/*methodName=*/"verifyDims",
49-
/*args=*/(ins)>
50-
];
42+
InterfaceMethod<
43+
/*desc=*/"Get the LHS A tensor",
44+
/*retType=*/"::mlir::Value",
45+
/*methodName=*/"getA",
46+
/*args=*/(ins)>,
47+
InterfaceMethod<
48+
/*desc=*/"Get the RHS B tensor",
49+
/*retType=*/"::mlir::Value",
50+
/*methodName=*/"getB",
51+
/*args=*/(ins)>,
52+
InterfaceMethod<
53+
/*desc=*/"Verify the dimensions of the A and B DotOp operands.",
54+
/*retType=*/"bool",
55+
/*methodName=*/"verifyDims",
56+
/*args=*/(ins)>
57+
];
5158

5259
let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
5360
}

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -681,29 +681,30 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
681681
let summary = "dot_scaled";
682682

683683
let description = [{
684-
$d = matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale)) + $c.
684+
$d = matrix_multiply(scale($a, $a_scale), scale($b, $b_scale)) + $c.
685685
Where scale(x, s) is a function that applies the scale per block following microscaling spec.
686686
}];
687687

688688
let arguments = (
689689
ins
690690
// inputs are floats if we have a type for them, otherwise (fp4),
691691
// they are packed in pairs in an I8Tensor
692-
RankedTensorOf<[TT_Float,I8]>:$lhs,
693-
RankedTensorOf<[TT_Float,I8]>:$rhs,
692+
RankedTensorOf<[TT_Float,I8]>:$a,
693+
RankedTensorOf<[TT_Float,I8]>:$b,
694694
TT_FloatTensor:$c,
695-
Optional<RankedTensorOf<[TT_Float, I8]>>:$lhs_scale,
696-
Optional<RankedTensorOf<[TT_Float, I8]>>:$rhs_scale,
697-
TT_ScaleDotElemTypeAttr:$lhs_type,
698-
TT_ScaleDotElemTypeAttr:$rhs_type,
695+
Optional<RankedTensorOf<[TT_Float, I8]>>:$a_scale,
696+
Optional<RankedTensorOf<[TT_Float, I8]>>:$b_scale,
697+
TT_ScaleDotElemTypeAttr:$a_elem_type,
698+
TT_ScaleDotElemTypeAttr:$b_elem_type,
699699
BoolAttr:$fastMath
700700
);
701701

702702
let results = (outs TT_FloatTensor:$d);
703703

704704
let assemblyFormat = [{
705-
$lhs (`scale` $lhs_scale^)? `,` $rhs (`scale` $rhs_scale^)? `,` $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
706-
`:` type($lhs) (`,` type($lhs_scale)^)? `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
705+
$a (`scale` $a_scale^)? `,` $b (`scale` $b_scale^)? `,` $c
706+
`lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict
707+
`:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d)
707708
}];
708709
}
709710

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,14 @@ bool DotOp::verifyDims() {
318318

319319
//-- DotScaledOp --
320320
bool DotScaledOp::verifyDims() {
321-
auto aShape = this->getLhs().getType().getShape();
322-
auto bShape = this->getRhs().getType().getShape();
321+
auto aShape = this->getA().getType().getShape();
322+
auto bShape = this->getB().getType().getShape();
323323

324324
auto aKdim = aShape[aShape.size() - 1];
325325
auto bKdim = bShape[aShape.size() - 2];
326-
if (this->getLhsType() == ScaleDotElemType::E2M1)
326+
if (this->getAElemType() == ScaleDotElemType::E2M1)
327327
aKdim *= 2;
328-
if (this->getRhsType() == ScaleDotElemType::E2M1)
328+
if (this->getBElemType() == ScaleDotElemType::E2M1)
329329
bKdim *= 2;
330330

331331
return aKdim == bKdim;

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ class ScaledBlockedToMMAv5
629629
mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
630630
return failure();
631631

632-
if (dotOp.getLhsScale() == nullptr || dotOp.getRhsScale() == nullptr) {
632+
if (dotOp.getAScale() == nullptr || dotOp.getBScale() == nullptr) {
633633
return failure();
634634
}
635635

@@ -643,18 +643,18 @@ class ScaledBlockedToMMAv5
643643
return failure();
644644
Location loc = dotOp.getLoc();
645645
// operands
646-
Value a = dotOp.getLhs();
647-
Value b = dotOp.getRhs();
648-
auto oldAType = dotOp.getLhs().getType();
649-
auto oldBType = dotOp.getRhs().getType();
646+
Value a = dotOp.getA();
647+
Value b = dotOp.getB();
648+
auto oldAType = a.getType();
649+
auto oldBType = b.getType();
650650

651651
bool IsAMixedPrecFp4 = false;
652652
bool IsBMixedPrecFp4 = false;
653653

654-
if (dotOp.getLhsType() != dotOp.getRhsType()) {
655-
if (dotOp.getLhsType() == ScaleDotElemType::E2M1)
654+
if (dotOp.getAElemType() != dotOp.getBElemType()) {
655+
if (dotOp.getAElemType() == ScaleDotElemType::E2M1)
656656
IsAMixedPrecFp4 = true;
657-
else if (dotOp.getRhsType() == ScaleDotElemType::E2M1)
657+
else if (dotOp.getBElemType() == ScaleDotElemType::E2M1)
658658
IsBMixedPrecFp4 = true;
659659
}
660660

@@ -676,8 +676,8 @@ class ScaledBlockedToMMAv5
676676
// descriptor requires options that are unavailable to the .kind=mxf4 mma.
677677
// This is likely preferable over a silent runtime performance degradation
678678
// from running f4xf4 via .kind=mxf8f6f4
679-
if (dotOp.getLhsType() == ScaleDotElemType::E2M1 &&
680-
dotOp.getRhsType() == ScaleDotElemType::E2M1) {
679+
if (dotOp.getAElemType() == ScaleDotElemType::E2M1 &&
680+
dotOp.getBElemType() == ScaleDotElemType::E2M1) {
681681
k = 64;
682682
}
683683
SmallVector<unsigned> instrShape = {m, n, k};
@@ -701,8 +701,8 @@ class ScaledBlockedToMMAv5
701701
auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
702702
loc, accMemDescType, cvtAcc);
703703

704-
RankedTensorType oldScaleAType = dotOp.getLhsScale().getType();
705-
RankedTensorType oldScaleBType = dotOp.getRhsScale().getType();
704+
RankedTensorType oldScaleAType = dotOp.getAScale().getType();
705+
RankedTensorType oldScaleBType = dotOp.getBScale().getType();
706706

707707
Attribute scaleEncoding =
708708
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr::get(
@@ -724,8 +724,8 @@ class ScaledBlockedToMMAv5
724724
RankedTensorType newScaleBType = RankedTensorType::get(
725725
oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleBLayout);
726726

727-
auto lhsScale = addSmemStageToScaleLoad(dotOp.getLhsScale(), rewriter);
728-
auto rhsScale = addSmemStageToScaleLoad(dotOp.getRhsScale(), rewriter);
727+
auto lhsScale = addSmemStageToScaleLoad(dotOp.getAScale(), rewriter);
728+
auto rhsScale = addSmemStageToScaleLoad(dotOp.getBScale(), rewriter);
729729

730730
Value newScaleA =
731731
rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
@@ -737,8 +737,8 @@ class ScaledBlockedToMMAv5
737737
loc, scaleBType, newScaleB);
738738
auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
739739
rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
740-
loc, a, b, acc, scaleA, scaleB, dotOp.getLhsType(), dotOp.getRhsType(),
741-
vTrue, vTrue, Value());
740+
loc, a, b, acc, scaleA, scaleB, dotOp.getAElemType(),
741+
dotOp.getBElemType(), vTrue, vTrue, Value());
742742

743743
auto ld =
744744
rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);
@@ -792,17 +792,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
792792
// Transpose scaled_dot ops that have a scale on lhs.
793793
static void transposeDotOp(DotScaledOp dotOp) {
794794
OpBuilder builder(dotOp);
795-
Value lhs = dotOp.getLhs();
795+
Value lhs = dotOp.getA();
796796
std::array<int, 2> transOrder = {1, 0};
797797
Value lhsTransposed = builder.create<TransOp>(lhs.getLoc(), lhs, transOrder);
798-
Value rhs = dotOp.getRhs();
798+
Value rhs = dotOp.getB();
799799
Value rhsTransposed = builder.create<TransOp>(rhs.getLoc(), rhs, transOrder);
800800
Value c = dotOp.getC();
801801
Value cTransposed = builder.create<TransOp>(c.getLoc(), c, transOrder);
802802
Value result = builder.create<DotScaledOp>(
803803
dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
804-
cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(),
805-
dotOp.getLhsType(), dotOp.getFastMath());
804+
cTransposed, dotOp.getBScale(), dotOp.getAScale(), dotOp.getBElemType(),
805+
dotOp.getAElemType(), dotOp.getFastMath());
806806
Operation *transposedResult =
807807
builder.create<TransOp>(result.getLoc(), result, transOrder);
808808
dotOp.replaceAllUsesWith(transposedResult);
@@ -814,7 +814,7 @@ static void transposeDots(ModuleOp m) {
814814
// want to use rhs from register for mmav3.
815815
SmallVector<DotScaledOp> toTranspose;
816816
m.walk([&](DotScaledOp dotOp) -> void {
817-
if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
817+
if (dotOp.getAScale() == nullptr && dotOp.getBScale() != nullptr)
818818
toTranspose.push_back(dotOp);
819819
});
820820
for (DotScaledOp dotOp : toTranspose) {

lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "mlir/IR/Types.h"
44
#include "mlir/IR/Value.h"
55
#include "mlir/Support/LogicalResult.h"
6-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
76

87
#include "triton/Dialect/Triton/IR/Dialect.h"
98
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
@@ -32,8 +31,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
3231
LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
3332
PatternRewriter &rewriter) const override {
3433
// Types
35-
auto computeType = getComputeType(scaledDotOp.getLhsType(),
36-
scaledDotOp.getRhsType(), rewriter);
34+
auto computeType = getComputeType(scaledDotOp.getAElemType(),
35+
scaledDotOp.getBElemType(), rewriter);
3736
auto loc = scaledDotOp.getLoc();
3837

3938
auto cvtDotOperand = [&](TypedValue<RankedTensorType> v,
@@ -185,12 +184,11 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
185184
TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
186185
DotScaledOp scaledDotOp, int opIdx,
187186
FloatType computeType) const {
188-
auto v = opIdx == 0 ? scaledDotOp.getLhs() : scaledDotOp.getRhs();
189-
auto scale =
190-
opIdx == 0 ? scaledDotOp.getLhsScale() : scaledDotOp.getRhsScale();
187+
auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
188+
auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale();
191189
auto isFp4 =
192-
(opIdx == 0 ? scaledDotOp.getLhsType() : scaledDotOp.getRhsType()) ==
193-
ScaleDotElemType::E2M1;
190+
ScaleDotElemType::E2M1 ==
191+
(opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType());
194192
auto fastMath = scaledDotOp.getFastMath();
195193

196194
auto *ctx = rewriter.getContext();

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,10 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
516516
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
517517
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
518518

519-
Value a = op.getLhs();
520-
Value b = op.getRhs();
521-
Value aScale = op.getLhsScale();
522-
Value bScale = op.getRhsScale();
519+
Value a = op.getA();
520+
Value b = op.getB();
521+
Value aScale = op.getAScale();
522+
Value bScale = op.getBScale();
523523
bool isAScaleConstant = aScale.getDefiningOp<arith::ConstantOp>();
524524
bool isBScaleConstant = bScale.getDefiningOp<arith::ConstantOp>();
525525
Value d = op.getD();
@@ -528,8 +528,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
528528
auto dTensorTy = cast<RankedTensorType>(d.getType());
529529
auto elemTyA = aTensorTy.getElementType();
530530
auto elemTyB = bTensorTy.getElementType();
531-
ScaleDotElemType aElemType = op.getLhsType();
532-
ScaleDotElemType bElemType = op.getRhsType();
531+
ScaleDotElemType aElemType = op.getAElemType();
532+
ScaleDotElemType bElemType = op.getBElemType();
533533

534534
const auto kDimOperandSize = aTensorTy.getShape().back();
535535

@@ -576,10 +576,10 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
576576
constexpr int scaleKWidth = 1;
577577
constexpr int scaleKBase = 1;
578578

579-
Value loadedA = adaptor.getLhs();
580-
Value loadedB = adaptor.getRhs();
581-
Value loadedAScale = adaptor.getLhsScale();
582-
Value loadedBScale = adaptor.getRhsScale();
579+
Value loadedA = adaptor.getA();
580+
Value loadedB = adaptor.getB();
581+
Value loadedAScale = adaptor.getAScale();
582+
Value loadedBScale = adaptor.getBScale();
583583
Value loadedC = adaptor.getC();
584584

585585
auto numRepM = repA[1];
@@ -709,12 +709,12 @@ LogicalResult convertScaledMFMA(triton::DotScaledOp op,
709709
triton::DotScaledOp::Adaptor adaptor,
710710
const LLVMTypeConverter *typeConverter,
711711
ConversionPatternRewriter &rewriter) {
712-
assert(isa<DotOperandEncodingAttr>(op.getLhs().getType().getEncoding()) &&
713-
isa<DotOperandEncodingAttr>(op.getRhs().getType().getEncoding()) &&
712+
assert(isa<DotOperandEncodingAttr>(op.getA().getType().getEncoding()) &&
713+
isa<DotOperandEncodingAttr>(op.getB().getType().getEncoding()) &&
714714
"Both lhs and rhs should be DotOperand layout.");
715715

716-
assert(isa<LinearEncodingAttr>(op.getLhsScale().getType().getEncoding()) &&
717-
isa<LinearEncodingAttr>(op.getRhsScale().getType().getEncoding()) &&
716+
assert(isa<LinearEncodingAttr>(op.getAScale().getType().getEncoding()) &&
717+
isa<LinearEncodingAttr>(op.getBScale().getType().getEncoding()) &&
718718
"Both LhsScale and RhsScale should be linear layout.");
719719

720720
auto cTensorTy = op.getC().getType();

0 commit comments

Comments
 (0)