Skip to content

Commit 584f273

Browse files
anmyachevmgornyatalmanMogballzhuhan0
authored
Merge OpenAI Triton commit 6fedb78 (#4037)
This PR changes the Triton base from 9451f8f to 6fedb78 (Apr 24). Pass rate: 88.73% Please do not squash and merge this PR. --------- Signed-off-by: Anatoly Myachev <[email protected]> Co-authored-by: Michał Górny <[email protected]> Co-authored-by: Andrey Talman <[email protected]> Co-authored-by: Jeff Niu <[email protected]> Co-authored-by: Han Zhu <[email protected]> Co-authored-by: Lei Zhang <[email protected]> Co-authored-by: peterbell10 <[email protected]> Co-authored-by: Jingning Tang <[email protected]> Co-authored-by: Jingning Tang <[email protected]> Co-authored-by: Dan Zimmerman <[email protected]>
2 parents b06a9fe + a7b48c8 commit 584f273

File tree

31 files changed

+710
-165
lines changed

31 files changed

+710
-165
lines changed

bench/bench/bench_mlp.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
9898
# -- benchmark --
9999
fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
100100
fpath.parent.mkdir(parents=True, exist_ok=True)
101-
x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
101+
x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
102102
# special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
103103
if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
104104
x_dtype = torch.float8_e4m3fnuz
@@ -140,17 +140,29 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
140140
min_time = max(min_time_flops, min_time_bytes)
141141
util = min_time / tot_time
142142
else:
143-
util = "N/A"
143+
util = 0.0
144144
tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
145145
tbps = tot_bytes / tot_time * 1e-3
146+
print(f"Utilization: {util:.0%}; {tflops:>6.1f} TFLOPs, {tbps:.1f} TB/s")
146147

147148
return util, tflops, tbps
148149

149150

150151
if __name__ == "__main__":
151152
has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
152-
qxdtype = "fp8" if has_native_mx4 else "bf16"
153-
print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
154-
print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
155-
print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
156-
print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
153+
if SPECS is None:
154+
print("Current GPU has no specs provided, utilization is N/A")
155+
if has_native_mx4:
156+
bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
157+
bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "mx4", TP=1, EP=1, name="dense")
158+
bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
159+
bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "mx4", TP=4, EP=1, name="llama4")
160+
else:
161+
# bf16/fp16 x fp8 is skipped because matmul_ogs requires x and w have the
162+
# same type when not doing an mxfp operation
163+
bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
164+
bench_mlp(8192, 8192, 8192, 1, 1, "fp16", "mx4", TP=1, EP=1, name="dense")
165+
bench_mlp(8192, 8192, 8192, 1, 1, "bf16", "mx4", TP=1, EP=1, name="dense")
166+
bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
167+
bench_mlp(2048, 5120, 8192, 128, 4, "bf16", "mx4", TP=4, EP=1, name="llama4")
168+
bench_mlp(2048, 5120, 8192, 128, 4, "fp16", "mx4", TP=4, EP=1, name="llama4")

include/triton/Dialect/Triton/IR/OpInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_IR_OP_INTERFACES_H_
33

44
#include "mlir/IR/OpDefinition.h"
5+
#include "triton/Dialect/Triton/IR/Types.h"
56

67
namespace mlir {
78

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,21 @@ def TT_AtomicRMWAttr : I32EnumAttr<
6767
let cppNamespace = "::mlir::triton";
6868
}
6969

70+
def TT_DescriptorReduceKindAttr : I32EnumAttr<
71+
"DescriptorReduceKind", "",
72+
[
73+
I32EnumAttrCase<"ADD", 1, "add">,
74+
I32EnumAttrCase<"MIN", 2, "min">,
75+
I32EnumAttrCase<"MAX", 3, "max">,
76+
I32EnumAttrCase<"INC", 4, "inc">,
77+
I32EnumAttrCase<"DEC", 5, "dec">,
78+
I32EnumAttrCase<"AND", 6, "and">,
79+
I32EnumAttrCase<"OR", 7, "or">,
80+
I32EnumAttrCase<"XOR", 8, "xor">,
81+
]> {
82+
let cppNamespace = "::mlir::triton";
83+
}
84+
7085
def TT_MemSyncScopeAttr : I32EnumAttr<
7186
"MemSyncScope", "",
7287
[

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,31 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
7575
let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
7676
}
7777

78+
def TT_DescriptorOpInterface : OpInterface<"DescriptorOpInterface"> {
79+
let description = [{
80+
Common interface to get the descriptor argument from an operation on tensor descriptors.
81+
}];
82+
83+
let methods = [
84+
InterfaceMethod<
85+
/*desc=*/"Get the descriptor",
86+
/*retType=*/"::mlir::TypedValue<mlir::triton::TensorDescType>",
87+
/*methodName=*/"getDesc",
88+
/*args=*/(ins)>,
89+
];
90+
}
91+
92+
def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterface", [TT_DescriptorOpInterface]> {
93+
let cppNamespace = "::mlir::triton";
94+
95+
let methods = [
96+
InterfaceMethod<
97+
/*desc=*/"Get Source tensor",
98+
/*retType=*/"::mlir::TypedValue<mlir::RankedTensorType>",
99+
/*methodName=*/"getSrc",
100+
/*args=*/(ins)>,
101+
];
102+
}
103+
78104

79105
#endif // TRITON_OP_INTERFACES

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
10191019
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
10201020

10211021
let builders = [
1022-
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
1022+
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
10231023
];
10241024

10251025
let extraClassDeclaration = [{
@@ -1259,7 +1259,7 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
12591259
}
12601260

12611261

1262-
def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
1262+
def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> {
12631263
let summary = "Load from descriptor";
12641264
let description = [{
12651265
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
@@ -1287,7 +1287,7 @@ def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
12871287
let hasVerifier = 1;
12881288
}
12891289

1290-
def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
1290+
def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> {
12911291
let summary = "store value based on descriptor";
12921292
let description = [{
12931293
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
@@ -1304,11 +1304,30 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
13041304
$desc `[` $indices `]` `,` $src
13051305
attr-dict `:` qualified(type($desc)) `,` type($src)
13061306
}];
1307-
13081307
let hasVerifier = 1;
13091308
}
13101309

1311-
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
1310+
def TT_DescriptorReduceOp : TT_Op<"descriptor_reduce", [TT_DescriptorStoreLikeOpInterface]> {
1311+
let summary = "performs a reducing store operation based on a descriptor";
1312+
let description = [{
1313+
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
1314+
`desc` is a tensor descriptor object.
1315+
The shape and types of `src` must match the descriptor, otherwise the result is undefined.
1316+
}];
1317+
let arguments = (ins
1318+
TT_DescriptorReduceKindAttr:$kind,
1319+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
1320+
TT_Tensor:$src,
1321+
Variadic<I32>:$indices
1322+
);
1323+
1324+
let assemblyFormat = [{
1325+
$kind `,` $desc `[` $indices `]` `,` $src
1326+
attr-dict `:` qualified(type($desc)) `,` type($src)
1327+
}];
1328+
}
1329+
1330+
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [TT_DescriptorOpInterface]> {
13121331
let summary = "gather multiple rows from a descriptor into a single tensor";
13131332
let description = [{
13141333
The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
@@ -1341,7 +1360,7 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
13411360
}];
13421361
}
13431362

1344-
def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter"> {
1363+
def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLikeOpInterface]> {
13451364
let summary = "scatter multiple rows to a descriptor from a single tensor";
13461365
let description = [{
13471366
The `tt.descriptor_scatter` op will be lowered to NVIDIA TMA

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,28 @@ def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
102102

103103
let parameters = (ins "RankedTensorType":$blockType);
104104
let assemblyFormat = "`<` $blockType `>`";
105+
106+
let builders = [
107+
TypeBuilder<(ins "RankedTensorType":$blockType, "bool":$isSigned), [{
108+
if (auto intTy = llvm::dyn_cast<IntegerType>(blockType.getElementType())) {
109+
auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned;
110+
auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem);
111+
blockType = RankedTensorType::get(blockType.getShape(), elemTy);
112+
}
113+
return Base::get($_ctxt, blockType);
114+
}]>,
115+
];
116+
let extraClassDeclaration = [{
117+
RankedTensorType getSignlessBlockType() const {
118+
auto resTy = getBlockType();
119+
if (auto intTy = llvm::dyn_cast<IntegerType>(resTy.getElementType())) {
120+
auto width = resTy.getElementTypeBitWidth();
121+
auto signlessTy = IntegerType::get(getContext(), width);
122+
resTy = RankedTensorType::get(resTy.getShape(), signlessTy);
123+
}
124+
return resTy;
125+
}
126+
}];
105127
}
106128

107129
#endif

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global">
316316
}];
317317

318318
let arguments = (ins
319-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc_ptr,
319+
Arg<TT_PtrType, "", [MemWrite<GlobalMemory>]>:$desc_ptr,
320320
Variadic<I32>:$coord,
321321
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
322322
);
@@ -327,6 +327,29 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global">
327327
}];
328328
}
329329

330+
def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
331+
let summary = "reduce result in gmem based on a TMA descriptor";
332+
333+
let description = [{
334+
This operation copies data from local memory to global memory
335+
asynchronously, and atomically performs the specified reduction kind.
336+
Atomicity is at the granularity of individual elements, and only relaxed
337+
semantics are implied.
338+
}];
339+
340+
let arguments = (ins
341+
TT_DescriptorReduceKindAttr:$kind,
342+
Arg<TT_PtrType, "", [MemRead<GlobalMemory>]>:$desc_ptr,
343+
Variadic<I32>:$coord,
344+
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
345+
);
346+
347+
let assemblyFormat = [{
348+
$kind `,` $desc_ptr `[` $coord `]` $src
349+
attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($src))
350+
}];
351+
}
352+
330353
def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather"> {
331354
let summary = "gather data based on descriptor from global memory to local memory asynchronously";
332355

@@ -365,7 +388,7 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> {
365388
}];
366389

367390
let arguments = (ins
368-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc_ptr,
391+
Arg<TT_PtrType, "", [MemWrite<GlobalMemory>]>:$desc_ptr,
369392
RankedTensorOf<[I32]>:$x_offsets,
370393
I32:$y_offset,
371394
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
620620
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
621621
GenericOpPattern<triton::DescriptorLoadOp>,
622622
GenericOpPattern<triton::DescriptorStoreOp>,
623+
GenericOpPattern<triton::DescriptorReduceOp>,
623624
GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
624625
GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
625626
// this assumes the right layout will be set later for dot scaled.

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -969,16 +969,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
969969
//-- MakeTensorDescOp --
970970
void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
971971
Value base, ValueRange shape, ValueRange strides,
972-
ArrayRef<int32_t> blockShape) {
972+
ArrayRef<int32_t> blockShape,
973+
bool isSignedInteger) {
973974
auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
974975
if (!ptrTy) {
975976
llvm::report_fatal_error("Expected pointer type");
976977
}
977978
auto elemTy = ptrTy.getPointeeType();
978-
979979
SmallVector<int64_t> blockShape64(blockShape);
980980
auto blockTy = RankedTensorType::get(blockShape64, elemTy);
981-
auto descTy = TensorDescType::get(builder.getContext(), blockTy);
981+
auto descTy =
982+
TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
982983
return build(builder, state, descTy, base, shape, strides);
983984
}
984985

@@ -1333,21 +1334,23 @@ static LogicalResult verifyGatherScatterOp(Operation *op,
13331334
}
13341335

13351336
LogicalResult DescriptorGatherOp::verify() {
1336-
return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
1337+
return verifyGatherScatterOp(*this,
1338+
getDesc().getType().getSignlessBlockType(),
13371339
getResult().getType(), getXOffsets().getType());
13381340
}
13391341

13401342
// -- DescriptorScatterOp --
13411343
LogicalResult DescriptorScatterOp::verify() {
1342-
return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
1344+
return verifyGatherScatterOp(*this,
1345+
getDesc().getType().getSignlessBlockType(),
13431346
getSrc().getType(), getXOffsets().getType());
13441347
}
13451348

13461349
// -- DescriptorLoadOp --
13471350
static LogicalResult verifyDescriptorLoadStoreType(Operation *op,
13481351
TensorDescType desc,
13491352
RankedTensorType tensor) {
1350-
RankedTensorType block = desc.getBlockType();
1353+
RankedTensorType block = desc.getSignlessBlockType();
13511354
ArrayRef<int64_t> blockShape = block.getShape();
13521355
ArrayRef<int64_t> tensorShape = tensor.getShape();
13531356
if (blockShape.size() > tensorShape.size()) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@ static SmallVector<TMAStore> getTMAStores(scf::ForOp forOp) {
1818
SmallVector<TMAStore> tmaStores;
1919

2020
forOp.getBody()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
21-
if (auto storeOp = dyn_cast<tt::DescriptorStoreOp>(op)) {
21+
if (auto storeOp = dyn_cast<tt::DescriptorStoreLikeOpInterface>(op)) {
2222
tmaStores.push_back({storeOp, storeOp.getDesc(), storeOp.getSrc()});
23-
} else if (auto scatterOp = dyn_cast<tt::DescriptorScatterOp>(op)) {
24-
tmaStores.push_back({scatterOp, scatterOp.getDesc(), scatterOp.getSrc()});
25-
2623
// Don't walk into nested loops.
2724
} else if (isa<scf::ForOp>(op)) {
2825
return WalkResult::skip();
@@ -77,6 +74,13 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,
7774
storeOp.getIndices());
7875
builder.create<ttng::AsyncTMACopyLocalToGlobalOp>(
7976
loc, tmaPtr, storeOp.getIndices(), alloc);
77+
} else if (auto reduceOp = dyn_cast<tt::DescriptorReduceOp>(store.op)) {
78+
auto indices = ttng::translateTMAIndices(
79+
builder, reduceOp.getLoc(),
80+
reduceOp.getDesc().getType().getBlockType().getEncoding(),
81+
reduceOp.getIndices());
82+
builder.create<ttng::AsyncTMAReduceOp>(loc, reduceOp.getKind(), tmaPtr,
83+
reduceOp.getIndices(), alloc);
8084
} else {
8185
auto scatterOp = cast<tt::DescriptorScatterOp>(store.op);
8286
builder.create<ttng::AsyncTMAScatterOp>(

0 commit comments

Comments
 (0)