Commit efc505b
Merge OpenAI Triton commit 4b9efc5 (#4225)
This PR changes the Triton base from 99b5e29 to 4b9efc5 (May 14). Pass rate: 97.25% -> 96.85% (#4267)
2 parents: f415a12 + 29a9608

File tree

36 files changed: +1182 -590 lines changed

.github/workflows/integration-tests-nvidia.yml

Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,7 @@ on:
 jobs:
   integration-tests-nvidia:
     runs-on: ${{ matrix.runner }}
-    timeout-minutes: 30
+    timeout-minutes: 60
     strategy:
       matrix:
         runner: ${{ fromJson(inputs.matrix) }}
@@ -94,6 +94,8 @@ jobs:
       if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
         source /venv/bin/activate
       fi
+      nproc
+      nvidia-smi
       echo "PATH is '$PATH'"
       ccache --zero-stats
       make dev-install

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 7 additions & 0 deletions

@@ -29,6 +29,7 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
 include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
+include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -584,6 +585,12 @@ def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
   );
   let results = (outs Optional<TTG_AsyncToken>:$token);

+  let builders = [
+    OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$pred), [{
+      build($_builder, $_state, Type(), dst, Value(), src, pred);
+    }]>
+  ];
+
   let assemblyFormat = [{
     $src `,` $dst `` custom<Token>($dep, type($token)) `,` $pred
     attr-dict `:` type($src) `->` qualified(type($dst))
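The added builder trims call sites that neither produce an async token nor thread a dependency: the token result type and the $dep operand are filled in as absent. A minimal sketch of such a call site (the rewriter, loc, and value names are assumptions for illustration, not code from this commit):

    // Predicated TMEM store with no token/dependency, via the new builder.
    rewriter.create<triton::nvidia_gpu::TMEMStoreOp>(loc, dst, src, pred);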

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();

 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();

+std::unique_ptr<Pass> createTritonNvidiaGPUInterleaveTMemPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions

@@ -143,6 +143,16 @@ def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-l
                              "mlir::triton::TritonDialect"];
 }

+def TritonNvidiaGPUInterleaveTMemPass : Pass<"triton-nvidia-interleave-tmem", "mlir::ModuleOp"> {
+  let summary = "Interleave TMEM loads/stores.";
+
+  let description = [{
+    The `triton-nvidia-interleave-tmem` pass attempts to sink TMEM loads and
+    hoist TMEM stores, and potentially interleave them, to reduce register
+    pressure.
+  }];
+}
+
 def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-tokens", "mlir::ModuleOp"> {
   let summary = "remove TMEM tokens";
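The declaration in Passes.h and this definition follow the existing pass plumbing, so the new pass can be scheduled like its siblings. A hedged sketch of adding it to a pipeline (assuming the standard mlir::PassManager API and that the factory lives in namespace mlir alongside the neighboring declarations; this is not code from the commit):

    #include "mlir/Pass/PassManager.h"
    #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

    // Schedule the new TMEM interleaving pass on a ModuleOp.
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::createTritonNvidiaGPUInterleaveTMemPass());
    if (mlir::failed(pm.run(module)))
      llvm::report_fatal_error("triton-nvidia-interleave-tmem failed");

Once registered, the pass should also be reachable from pass drivers under its mnemonic, triton-nvidia-interleave-tmem.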

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 27 additions & 9 deletions

@@ -39,14 +39,30 @@ triton::gpu::SharedEncodingTrait
 getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
                           Value desc);

-int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape);
+SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
+                                      int elementBitWidth, int swizzleBytes,
+                                      bool fp4Padded, bool transposed,
+                                      bool packedSize);
+
+inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
+                                             ArrayRef<int64_t> shapePerCTA,
+                                             bool packedSize) {
+  auto mmaEnc = cast<gpu::NVMMASharedEncodingAttr>(encoding);
+  return getTMABlockShape(shapePerCTA, mmaEnc.getElementBitWidth(),
+                          mmaEnc.getSwizzlingByteWidth(), mmaEnc.getFp4Padded(),
+                          mmaEnc.getTransposed(), packedSize);
+}

-inline int64_t getTMAContigDim(RankedTensorType tensorType) {
-  return getTMAContigDim(tensorType.getEncoding(), tensorType.getShape());
+inline SmallVector<int64_t> getTMABlockShape(RankedTensorType ty,
+                                             bool packedSize) {
+  auto shapePerCTA = gpu::getShapePerCTA(ty);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
 }

-inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
-  return getTMAContigDim(memDescType.getEncoding(), memDescType.getShape());
+inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
+                                             bool packedSize) {
+  auto shapePerCTA = gpu::getShapePerCTA(ty);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
 }

 std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);
@@ -74,16 +90,18 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,

   int paddingScale = fp4Padded ? 2 : 1;
   auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape());
-  int32_t contig_dim_size = getTMAContigDim(encoding, op.getTensorShape());
+  auto blockShape =
+      getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
+  auto contigDimSize = blockShape.back();

   llvm::SmallVector<Value> boxDim;
-  if (fp4Padded && contig_dim_size != 128) {
+  if (fp4Padded && contigDimSize != 128) {
     return op->emitError(
         "FP4 padded loads require 128 elements or more in the last dim");
   }
-  boxDim.push_back(mkI32Constant(contig_dim_size));
+  boxDim.push_back(mkI32Constant(contigDimSize));
   for (int k = shapePerCTA.size() - 2; k >= 0; --k)
-    boxDim.push_back(mkI32Constant(shapePerCTA[k]));
+    boxDim.push_back(mkI32Constant(blockShape[k]));

   unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
   if (!mmaEncoding) {
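Taken together, the overloads let callers ask for the TMA block shape at whatever level they hold: raw encoding parameters, an encoding plus per-CTA shape, or a typed value. A hedged usage sketch mirroring the createTMADesc change above (the variable names encoding, shapePerCTA, and tensorTy are assumptions):

    // From an encoding and per-CTA shape, using unpacked element counts:
    auto blockShape =
        getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
    int64_t contigDimSize = blockShape.back();  // innermost (contiguous) dim

    // Or directly from a RankedTensorType (a MemDescType works the same way):
    auto packedShape = getTMABlockShape(tensorTy, /*packedSize=*/true);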

include/triton/Tools/LayoutUtils.h

Lines changed: 11 additions & 0 deletions

@@ -83,6 +83,17 @@ LinearLayout ensureLayoutNotSmallerThan(
     const LinearLayout &layout,
     const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

+inline LinearLayout
+ensureLayoutNotSmallerThan(const LinearLayout &layout,
+                           const llvm::ArrayRef<StringAttr> dimNames,
+                           const llvm::ArrayRef<int64_t> shape) {
+  llvm::SmallDenseMap<StringAttr, int64_t> namedDims;
+  for (auto [dimName, length] : llvm::zip_equal(dimNames, shape))
+    namedDims[dimName] = length;
+  assert(namedDims.size() == shape.size() && "duplicate dimension names given");
+  return ensureLayoutNotSmallerThan(layout, namedDims);
+}
+
 // Return a vector of the standard out dimension names for tensor layouts. These
 // are "dim0", "dim1", etc.
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
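The new overload just adapts parallel name/size arrays into the map the existing entry point expects; llvm::zip_equal additionally asserts the two arrays have the same length. A small usage sketch (assumed context; ctx and layout are not from this commit):

    // Grow `layout` so each named dimension reaches at least the given size.
    SmallVector<StringAttr> dims = standardOutDimNames(ctx, /*rank=*/2);
    LinearLayout padded = ensureLayoutNotSmallerThan(layout, dims, {32, 64});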

include/triton/Tools/LinearLayout.h

Lines changed: 12 additions & 3 deletions

@@ -325,23 +325,32 @@ class LinearLayout {
       bases;

   llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
-  bool surjective;
+  bool surjective = true;

 public:
   using BasesT = decltype(bases);

+  LinearLayout() = default;
+
   // The 0-dimensional layout that maps everything to 0. This is useful as a
   // starting point when doing something like
   //
   //   LinearLayout ret = LinearLayout::empty();
   //   for (...) ret *= ...;
   //   return ret;
-  static LinearLayout empty() { return LinearLayout(BasesT{}, {}); }
+  static LinearLayout empty() { return {}; }
+
+  // Creates a 1D -> 1D layout that's the function L(x) = stride * x
+  // for x in [0, size).
+  static LinearLayout strided1D(int32_t size, int32_t stride, StringAttr inDim,
+                                StringAttr outDim);

   // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x
   // for x in [0, size).
   static LinearLayout identity1D(int32_t size, StringAttr inDim,
-                                 StringAttr outDim);
+                                 StringAttr outDim) {
+    return strided1D(size, /*stride=*/1, inDim, outDim);
+  }

   // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0
   // for x in [0, size). By default this creates a surjective layout where
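strided1D generalizes identity1D: per its doc comment it builds the 1D map L(x) = stride * x over [0, size), and identity1D becomes the stride = 1 case. An illustrative use (assuming power-of-two size and stride, as usual for linear layouts; the dimension names are for illustration only):

    auto in = StringAttr::get(ctx, "in");
    auto out = StringAttr::get(ctx, "dim0");
    // L(x) = 4 * x for x in [0, 8), e.g. L(3) == 12.
    LinearLayout byFour = LinearLayout::strided1D(8, /*stride=*/4, in, out);
    LinearLayout id = LinearLayout::identity1D(8, in, out);  // L(x) == x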
