Skip to content

Commit 488474f

Browse files
Merge OpenAI Triton commit 94643b2 (#3376)
This PR changes the Triton base from ac9574c to 94643b2 (Feb 5). Pass rate: 97.97%. Please do not squash and merge this PR.
2 parents c3482a5 + 217515e commit 488474f

File tree

54 files changed

+1869
-666
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1869
-666
lines changed

include/triton/Dialect/Triton/IR/OpInterfaces.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ namespace impl {
1111

1212
LogicalResult verifyTransposeOpInterface(Operation *op);
1313

14+
LogicalResult verifyDotOpInterface(Operation *op);
15+
1416
} // namespace impl
1517

1618
} // namespace triton

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -58,53 +58,6 @@ class VerifyTensorLayoutsTrait
5858
}
5959
};
6060

61-
// Verify if the op is a dot-like operation.
62-
// A dot-like operation should have three operands.
63-
// The first two operands should share a common dimension, and the result
64-
// should have the dimensions of the two operands that are not shared.
65-
// A dot-like operation can be either 2d or 3d.
66-
// In the 3d case, the first dimension of operands is the batch dimension.
67-
template <class ConcreteType>
68-
class DotLike : public TraitBase<ConcreteType, DotLike> {
69-
public:
70-
static LogicalResult verifyTrait(Operation *op) {
71-
if (op->getNumOperands() < 3)
72-
return op->emitOpError("expected at least 3 operands");
73-
auto aTy = cast<ShapedType>(op->getOperand(0).getType());
74-
auto bTy = cast<ShapedType>(op->getOperand(1).getType());
75-
auto cTy = cast<ShapedType>(op->getOperand(2).getType());
76-
auto aShape = aTy.getShape();
77-
auto bShape = bTy.getShape();
78-
auto cShape = cTy.getShape();
79-
// Check if all 3d or all 2d
80-
if (aShape.size() != 2 && aShape.size() != 3)
81-
return op->emitOpError("expected operands to be 2d or 3d");
82-
if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
83-
return op->emitOpError("expected all operands to have the same rank");
84-
// Check if the first two operands share a common dimension
85-
// TODO: enable back with an interface to support scaled dot.
86-
// if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
87-
// return op->emitOpError("expected the last dimension of the first
88-
// operand "
89-
// "to be equal to the second-to-last dimension of
90-
// " "the second operand");
91-
// Check the batch dimension
92-
if (aShape.size() == 3 &&
93-
(aShape[0] != cShape[0] || bShape[0] != cShape[0]))
94-
return op->emitOpError("expected the first dimension of the first "
95-
"operand to be equal to the first dimension of "
96-
"the result");
97-
// Check the output shape
98-
if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
99-
cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
100-
return op->emitOpError(
101-
"expected the output shape to be the concatenation of the last "
102-
"dimension of the first operand and the last dimension of the "
103-
"second ");
104-
return success();
105-
}
106-
};
107-
10861
template <typename ConcreteType>
10962
class SameOperandsAndResultEncoding
11063
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
66

77
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
88
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
9-
def DotLike : NativeOpTrait<"DotLike">;
109
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
1110
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
1211
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,27 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
2929
/*args=*/(ins)>
3030
];
3131

32-
let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
32+
let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
33+
}
34+
35+
def DotOpInterface : OpInterface<"DotOpInterface"> {
36+
let description = [{
37+
This interface is implemented by operations that perform a dot product.
38+
}];
39+
40+
let cppNamespace = "::mlir::triton";
41+
42+
let methods = [
43+
InterfaceMethod<
44+
/*desc=*/[{
45+
Verifies the dimensions of the A and B DotOp operands.
46+
}],
47+
/*retType=*/"bool",
48+
/*methodName=*/"verifyDims",
49+
/*args=*/(ins)>
50+
];
51+
52+
let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
3353
}
3454

3555

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
631631
//
632632
def TT_DotOp : TT_Op<"dot", [Pure,
633633
DeclareOpInterfaceMethods<InferTypeOpInterface>,
634-
DotLike,
634+
DeclareOpInterfaceMethods<DotOpInterface>,
635635
TypesMatchWith<"result's type matches accumulator's type",
636636
"d", "c", "$_self">]> {
637637
let summary = "dot";
@@ -671,7 +671,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
671671
//
672672
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
673673
AttrSizedOperandSegments,
674-
DotLike,
674+
DeclareOpInterfaceMethods<DotOpInterface>,
675675
TypesMatchWith<"result's type matches accumulator's type",
676676
"d", "c", "$_self">]> {
677677
let summary = "dot_scaled";

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,7 @@ StringRef getAMDArch(Operation *module);
200200
std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
201201
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
202202

203-
enum class MMALoadType {
204-
SharedV3,
205-
Registers, // may be v2 or v3
206-
DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
207-
// pipelining
208-
};
209-
MMALoadType getMMALoadType(Operation *loadOp);
203+
bool canUseMMAv3Pipelining(Operation *loadOp);
210204

211205
// Convert \param op operands and results to layout \param encoding.
212206
void convertOpEncoding(Attribute encoding, Operation *op);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
2727
include "triton/Dialect/Triton/IR/TritonTypes.td"
2828
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
2929
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
30+
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
3031
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
3132
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
3233
include "mlir/IR/OpBase.td"
@@ -71,7 +72,7 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
7172
//
7273
def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
7374
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
74-
DotLike,
75+
DeclareOpInterfaceMethods<DotOpInterface>,
7576
TypesMatchWith<"result's type matches accumulator's type",
7677
"d", "c", "$_self">]> {
7778
let summary = "warp group dot";
@@ -325,7 +326,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
325326
let assemblyFormat = "attr-dict";
326327
}
327328

328-
def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
329+
def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
329330
let summary = "block level op mapping to tensorcore gen5 mma";
330331

331332
let description = [{
@@ -343,11 +344,12 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
343344
I1:$pred,
344345
Optional<TTG_MemDescType>:$barrier,
345346
OptionalAttr<UnitAttr>:$two_ctas);
347+
346348
// TODO: improve printing format.
347349
let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
348350
}
349351

350-
def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
352+
def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
351353
let summary = "block level op mapping to tensorcore gen5 mma";
352354

353355
let description = [{
@@ -366,6 +368,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
366368
I1:$useD,
367369
I1:$pred,
368370
Optional<TTG_MemDescType>:$barrier);
371+
369372
// TODO: improve printing format.
370373
let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
371374
}

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,18 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
5050

5151
Value elemSizeVal = builder.template create<arith::ConstantOp>(
5252
loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
53-
Value globalStride = builder.template create<arith::MulIOp>(
54-
loc, op.getStrides()[0], elemSizeVal);
53+
54+
SmallVector<Value> globalDim(llvm::reverse(op.getShape()));
55+
SmallVector<Value> globalStride;
56+
for (int k = op.getStrides().size() - 2; k >= 0; --k) {
57+
globalStride.push_back(op.getStrides()[k]);
58+
}
59+
60+
SmallVector<Value> elementStride(globalDim.size(), mkI32Constant(1));
61+
62+
for (int i = 0; i < globalStride.size(); ++i)
63+
globalStride[i] = builder.template create<arith::MulIOp>(
64+
loc, globalStride[i], elemSizeVal);
5565

5666
int elemTypeEnum;
5767
switch (elemSize) {
@@ -75,15 +85,14 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
7585
}
7686
}
7787

78-
auto one = mkI32Constant(1);
7988
builder.template create<triton::ExperimentalTensormapCreateOp>(
8089
loc,
8190
/*desc_ptr=*/tmaPtr,
8291
/*global_address=*/op.getBase(),
8392
/*box_dim=*/boxDim,
84-
/*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
85-
/*global_stride=*/ValueRange{globalStride},
86-
/*element_strides=*/ValueRange{one, one},
93+
/*global_dim=*/globalDim,
94+
/*global_stride=*/globalStride,
95+
/*element_strides=*/elementStride,
8796
/*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
8897
/*interleave_layout*/ builder.getI32IntegerAttr(0),
8998
/*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),

lib/Analysis/AxisInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
935935
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
936936
lhsDivisibility = 1;
937937
}
938-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
938+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
939939
}
940940

941941
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,17 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
4747
Type TritonGPUToLLVMTypeConverter::convertMemDescType(
4848
MemDescType type, const TargetInfoBase &targetInfo) {
4949
auto ctx = type.getContext();
50-
SmallVector<Type, 4> types;
5150
// base ptr
5251
auto ptrType =
5352
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
53+
54+
if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
55+
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
56+
type.getEncoding())) {
57+
return ptrType;
58+
}
59+
60+
SmallVector<Type, 4> types;
5461
types.push_back(ptrType);
5562
auto rank = type.getRank();
5663
// offsets

0 commit comments

Comments
 (0)