
Commit 5cf00e1

Merge OpenAI Triton commit 4bcdbde (#4903)
This PR changes the Triton base from cf0db92 to 4bcdbde (Jul 31). Pass rate: 98.83%
2 parents 3f2cc86 + 097a106 commit 5cf00e1

59 files changed, +2394 −1448 lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 11 deletions
@@ -357,9 +357,6 @@ class SharedMemoryObject {
 
   SmallVector<Type> getTypes() const;
 
-  SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                                RewriterBase &rewriter) const;
-
   // Returns a mask representing all the bits of the memdesc offsets that
   // may be modified by an affine offset coming from a memdesc_subslice.
   // The offsets are considered to be in the type of the memdesc.
@@ -385,14 +382,6 @@ class SharedMemoryObject {
   Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;
 
 private:
-  static SmallVector<unsigned> getOrderForShape(ArrayRef<int64_t> shape,
-                                                ArrayRef<unsigned> layoutOrder);
-
-  static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
-                                               ArrayRef<unsigned> layoutOrder,
-                                               Location loc,
-                                               RewriterBase &rewriter);
-
   Value base; // i32 ptr. The start address of the shared memory object.
   Type baseElemType;
   SmallVector<Value>
include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 168 deletions
@@ -99,174 +99,6 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
     MLIRContext *ctx, ArrayRef<unsigned> tensorShape,
     ArrayRef<unsigned> repShape, ArrayRef<unsigned> order);
 
-// This function constructs a linear layout that maps
-// <register, lane, warp> to <shared memory offset, iteration>.
-// The primary goal is to efficiently store 2D tiles of a tensor into shared
-// memory using the `stmatrix` instruction, with each thread responsible for
-// storing `N` elements. If `stmatrix` cannot be used for the given tensor
-// encoding, this function returns `std::nullopt`.
-//
-// Unlike standard vectorized stores, such as `st.shared.v4 [%offset],
-// %vec_reg`, where `%vec_reg` contains four consecutive data elements, the
-// `stmatrix` instruction allows `N` registers to point to non-contiguous
-// locations within a tensor tile.
-//
-// For instance, the `stmatrix [%offset], %mat_reg` instruction on NVIDIA GPUs
-// enables `%mat_reg` to store `N` elements that do not need to be consecutive.
-// However, it is crucial that the address (`%offset`) of each row in a tensor
-// tile should be aligned to `N` * `elemBitWidth`. The `%offset` of each thread
-// is calculated based on the provided tensor encoding.
-//
-// Currently, we support only the NVIDIA MMAv3 encoding and the `stmatrix.x4`
-// instruction. Each `stmatrix.x4` instruction stores eight 16-bit elements per
-// thread, resulting in a total of 8 * 32 = 256 elements per warp, or 16 * 16
-// elements per warp when distributed across four 8x8 tiles. Each thread's
-// `%offset` points to an address aligned with 8 * 16 bits, denoting a row in
-// the 8x8 tile. The values in `%mat_reg` are non-consecutive elements,
-// composed of 4 pairs of consecutive elements. These matrix addresses are
-// distributed as follows:
-//
-//            col[0-7]     col[8-15]
-// row[0-7]   lane[0-7]    lane[16-23]
-// row[8-15]  lane[8-15]   lane[24-31]
-//
-// The matrix elements of thread 0 are distributed in the following pattern:
-//
-//        col0       col8
-// row0   reg[0-1]   reg[4-5]
-// row8   reg[2-3]   reg[6-7]
-//
-// When `swizzleByteSize` is non-zero, the layout is constructed
-// differently due to leading dimension offset and swizzling.
-// There are two key concepts to understand:
-//
-// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
-//    into chunks, where each chunk's size is determined by `swizzleByteSize`.
-// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
-//    rows to optimize memory access.
-//
-// - Concept 1: Chunks
-//
-// In the swizzled layout, the leading dimension is strided by
-// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
-// spans a certain number of columns.
-//
-// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
-// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
-// elements * 2 bytes per element = 32 bytes per row).
-//
-// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
-// calculated as:
-//
-//   Number of tiles per chunk = swizzleByteSize / (bytes per row) = 128 bytes /
-//                               32 bytes = 4 tiles
-//
-// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
-// (since each tile is 16 columns):
-//
-//           col0-15   col16-31   col32-47   col48-63
-// row0-15   tile0     tile1      tile2      tile3
-//
-// For a tensor of size 128x128 elements (#rows x #columns), and each element
-// being 16 bits, the tensor can be divided into multiple chunks both
-// horizontally and vertically. Chunks are stored in memory in a "column-major"
-// order based on chunks, meaning chunk1's address follows chunk0's.
-//
-// Assuming we have 8 warps, and we assign each warp to process a chunk of 16
-// rows (rows per tile) and 128 columns (the width of two chunks). This results
-// in each warp handling one horizontal slice of the tensor.
-//
-// The overall layout can be visualized as:
-//
-//                        |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
-//                          columns 0-63          columns 64-127
-// warp0 | rows 0-15        chunk0                 chunk8
-// warp1 | rows 16-31       chunk1                 chunk9
-// warp2 | rows 32-47       chunk2                 chunk10
-// warp3 | rows 48-63       chunk3                 chunk11
-// warp4 | rows 64-79       chunk4                 chunk12
-// warp5 | rows 80-95       chunk5                 chunk13
-// warp6 | rows 96-111      chunk6                 chunk14
-// warp7 | rows 112-127     chunk7                 chunk15
-//
-// - Concept 2: Swizzling within tiles
-//
-// Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
-// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`. at the
-// level of each 16x16 tile rather than the entire tensor.
-//
-// Key parameters for swizzling:
-//
-// - `perPhase`: The number of rows over which to apply a XOR operation at
-//   each phase.
-// - `maxPhase`: The total number of phases.
-// - `vectorWidth`: The number of elements per vector, which is 8 in this case
-//   because `stmatrix` stores 8 contiguous elements per thread.
-//
-// The offset of each element within a tile is calculated using the formula:
-//
-//   offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
-//            maxPhase)) * elementSize
-//
-// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
-// elements).
-//
-// For example, consider the element at index `(row=1, col=0)` in chunk0:
-//
-// Without swizzling:
-//
-//   offset = row * swizzleByteSize + col * elementSize
-//          = 1 * 128 bytes + 0 * 2 bytes
-//          = 128 bytes
-//
-// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
-//
-//   offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
-//            maxPhase)) * elementSize
-//          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
-//          = 128 bytes + (8 * (1 % 8)) * 2 bytes
-//          = 128 bytes + 8 * 2 bytes
-//          = 128 bytes + 16 bytes
-//          = 144 bytes
-//
-// This swizzling ensures that elements are stored in a way that optimizes for
-// memory bandwidth and reduces bank conflicts.
-//
-// - Verification through Linear Layout
-//
-// We can verify the offsets with the following outputs of the corresponding
-// linear layout, where each element is 16 bits (2 bytes):
-//
-// - register=1 -> offset=1
-//   register=2 -> offset=2
-//   register=4 -> offset=4
-//   register=8 -> offset=16
-//   register=16 -> offset=32
-//   register=32 -> offset=8192
-// - lane=1 -> offset=72
-//   lane=2 -> offset=144
-//   lane=4 -> offset=288
-//   lane=8 -> offset=512
-//   lane=16 -> offset=8
-// - warp=1 -> offset=1024
-//   warp=2 -> offset=2048
-//   warp=4 -> offset=4096
-//
-// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
-// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
-// matches our earlier calculation.
-//
-// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
-// bit width of the tensor in the future to support more flexible tensor
-// encodings
-LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                  int swizzleByteSize);
-
-// The primary goal of this function is to efficiently store 2D tiles of a
-// tensor into shared memory using the `ldmatrix` instruction.
-LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
-                                  bool needTrans, int32_t elemBitWidth);
-
 // The primary goal of this function is to efficiently load 2D tiles of a
 // tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
 LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
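
Side note on the arithmetic in the comment removed above: the quoted offset formula and its worked example at `(row=1, col=0)` can be checked with a few lines of standalone C++. This is only an illustration of the formula as written in the comment, not Triton code; every name below is local to the sketch.

#include <cassert>
#include <cstdint>

// Reproduces the swizzled-offset formula quoted in the deleted comment.
// All identifiers are local to this sketch; they are not Triton APIs.
int64_t swizzledRowOffsetBytes(int64_t row, int64_t swizzleByteSize,
                               int64_t perPhase, int64_t maxPhase,
                               int64_t vectorWidth, int64_t elementSize) {
  return row * swizzleByteSize +
         (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize;
}

int main() {
  // Values from the worked example: 16-bit elements, swizzleByteSize = 128,
  // perPhase = 1, maxPhase = 8, vectorWidth = 8 -> 144 bytes for row 1.
  assert(swizzledRowOffsetBytes(/*row=*/1, 128, 1, 8, 8, /*elementSize=*/2) == 144);
  // Concept 1 arithmetic: 128-byte swizzle / (16 elements * 2 bytes per row) = 4 tiles per chunk.
  int64_t bytesPerTileRow = 16 * 2;
  assert(128 / bytesPerTileRow == 4);
  return 0;
}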

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 3 deletions
@@ -64,9 +64,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
   let results = (outs TTG_AsyncToken:$asyncToken);
   let arguments = (ins Variadic<TTG_AsyncToken>:$inputTokens);
 
-  let assemblyFormat = [{
-    $inputTokens attr-dict
-  }];
+  let assemblyFormat = [{($inputTokens ^)?attr-dict}];
 
   let extraClassDeclaration = [{
     static bool isSupported(int computeCapability) {

include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
-#ifndef TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
-#define TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
+#ifndef TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H
+#define TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H
 
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 
@@ -33,4 +33,4 @@ StageCluster getStageCluster(Operation *op);
 
 } // namespace mlir::triton::gpu
 
-#endif // TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
+#endif // TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 24 additions & 1 deletion
@@ -132,7 +132,14 @@ void combineRedundantWaitOps(
     llvm::SmallSetVector<gpu::AsyncWaitOp, 8> &waitOps);
 
 // Get the type of the view of a multi-buffered tensor value.
-gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy);
+gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy,
+                                   bool mutableMemory = true);
+
+// Get a mutable, multi-buffered version of the given memdesc type, with
+// multiplicity "depth".
+gpu::MemDescType getMultiBufferedType(gpu::MemDescType memDescType,
+                                      int32_t depth);
+
 // Get a generic shared encoding for a tensor.
 gpu::SharedEncodingTrait getSharedEncoding(RankedTensorType ty);
 // Get a shared encoding for a tensor based on its uses.
@@ -157,6 +164,22 @@ Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,
 
 scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule);
 
+DenseSet<Operation *>
+getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp,
+                       std::function<bool(Operation *)> filter = nullptr);
+
+// Return the "first" op in terms of the stage and cluster ordering.
+Operation *
+getFirstUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                         CoarseSchedule &schedule,
+                         std::function<bool(Operation *)> filterUse = nullptr);
+
+// Return the "last" op in terms of the stage and cluster ordering.
+Operation *
+getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                        CoarseSchedule &schedule,
+                        std::function<bool(Operation *)> filterUse = nullptr);
+
 } // namespace triton
 } // namespace mlir
 
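For the two new declarations above, "first" and "last" are with respect to the (stage, cluster) order assigned by the CoarseSchedule. A hypothetical, self-contained sketch of that selection rule follows; the struct and function names are invented for illustration and this is not the Triton implementation, which is not shown in this diff.

#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical sketch: pick the use that executes earliest under a
// (stage, cluster) ordering. `op` is a stand-in for an operation handle.
struct Use { int op; int stage; int cluster; };

const Use *firstUseByStageCluster(const std::vector<Use> &uses) {
  if (uses.empty())
    return nullptr;
  return &*std::min_element(uses.begin(), uses.end(),
                            [](const Use &a, const Use &b) {
                              // Earlier stage wins; ties are broken by cluster.
                              return std::make_pair(a.stage, a.cluster) <
                                     std::make_pair(b.stage, b.cluster);
                            });
}
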
include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 2 deletions
@@ -255,8 +255,11 @@ namespace mlir::triton {
 /// Replace all uses of `oldUse` with `val` and propagate the type if needed.
 /// This is useful when we need to change a memory descriptor from immutable to
 /// mutable.
-void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
-                                 Value val);
+/// The callback is invoked for each pair of an old and a cloned memdesc op
+/// as the type is propagated.
+void replaceUsesAndPropagateType(
+    OpBuilder &builder, Operation *oldUse, Value val,
+    std::function<void(Operation *, Operation *)> callback = nullptr);
 
 /// Replace all uses of `old` with a local load from `alloc` unless the use is a
 /// `ttg.local_alloc` with a matching shared encoding, in which case the shared
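
A hypothetical call site for the extended signature, showing one way the new callback could be used (tracking which cloned op replaced which original). Only the `replaceUsesAndPropagateType` declaration comes from this diff; the helper wrapped around it is an assumption.

#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/DenseMap.h"

// Hypothetical helper: propagate the new memdesc type from `val` through the
// users of `oldUse`, remembering which cloned op replaced which original.
static llvm::DenseMap<mlir::Operation *, mlir::Operation *>
propagateAndTrack(mlir::OpBuilder &builder, mlir::Operation *oldUse,
                  mlir::Value val) {
  llvm::DenseMap<mlir::Operation *, mlir::Operation *> replacements;
  mlir::triton::replaceUsesAndPropagateType(
      builder, oldUse, val,
      [&](mlir::Operation *oldOp, mlir::Operation *newOp) {
        // Record the old -> cloned mapping as propagation proceeds.
        replacements[oldOp] = newOp;
      });
  return replacements;
}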

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let arguments = (ins
     TTG_TensorOrMemDesc:$a,
-    TTG_TensorOrMemDesc:$b,
+    TTG_MemDescType:$b,
     TT_FpIntTensor:$c,
     Optional<I1>:$useC,
     DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
@@ -99,7 +99,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let assemblyFormat = [{
     $a`,` $b`,` $c (`,` $useC^)? attr-dict
-    `:` type($a) `*` type($b) `->` type($d)
+    `:` type($a) `*` qualified(type($b)) `->` type($d)
   }];
 
   let extraClassDeclaration = [{

include/triton/Tools/LayoutUtils.h

Lines changed: 0 additions & 8 deletions
@@ -10,14 +10,6 @@ namespace mlir::triton {
 bool squareSublayoutIsIdentity(const LinearLayout &ll,
                                ArrayRef<StringAttr> dimNames);
 
-// Is the sublayout defined from dimNames to dimNames a subpermutation matrix?
-// I.e. the layout matrix is formed by selecting unique columns from the
-// identity matrix and adding zero columns. A zero column in the layout means
-// that changing a bit in the inputs does not change the bits of the outputs
-// (broadcasting).
-bool squareSublayoutIsPermutation(const LinearLayout &ll,
-                                  ArrayRef<StringAttr> dimNames);
-
 // For each output dimension d, ensure that the layout's output size (i.e., its
 // codomain) does not exceed shape[d]. Do this without changing the size of the
 // layout's inputs (i.e., leave its domain unchanged).
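
The removed comment describes the subpermutation property: every column of the layout matrix is either a zero column (a broadcast input bit) or a distinct column of the identity. A small standalone sketch of that predicate on packed bit-columns, independent of `LinearLayout`; the names below are illustrative only.

#include <cstdint>
#include <unordered_set>
#include <vector>

// Illustrative only: each entry of `columns` is one column of the layout
// matrix over GF(2), packed into a uint64_t. The matrix is a subpermutation
// iff every column is zero or a single identity column, and the nonzero
// columns are pairwise distinct.
bool isSubpermutation(const std::vector<uint64_t> &columns) {
  std::unordered_set<uint64_t> seen;
  for (uint64_t col : columns) {
    if (col == 0)
      continue; // zero column: this input bit does not affect the output (broadcast)
    if ((col & (col - 1)) != 0)
      return false; // more than one bit set: not a column of the identity
    if (!seen.insert(col).second)
      return false; // duplicate identity column: columns are not unique
  }
  return true;
}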

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 49 deletions
@@ -1079,19 +1079,6 @@ SmallVector<Type> SharedMemoryObject::getTypes() const {
   return types;
 }
 
-SmallVector<Value>
-SharedMemoryObject::getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                               RewriterBase &rewriter) const {
-  auto allocShape = memDesc.getAllocShape();
-  auto allocShapePerCTA =
-      triton::gpu::getAllocationShapePerCTA(memDesc.getEncoding(), allocShape);
-  auto layoutOrder = triton::gpu::getOrder(memDesc);
-  auto allocStrides = SharedMemoryObject::getStridesForShape(
-      allocShapePerCTA, layoutOrder, loc, rewriter);
-  return SmallVector<Value>(allocStrides.end() - offsets.size(),
-                            allocStrides.end());
-}
-
 Value SharedMemoryObject::getBaseBeforeSlice(int dim, Location loc,
                                              RewriterBase &rewriter) const {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -1101,42 +1088,6 @@ Value SharedMemoryObject::getBaseBeforeSlice(int dim, Location loc,
   return b.gep(type, baseElemType, base, offset);
 }
 
-SmallVector<unsigned>
-SharedMemoryObject::getOrderForShape(ArrayRef<int64_t> shape,
-                                     ArrayRef<unsigned> layoutOrder) {
-  SmallVector<unsigned> order(shape.size());
-  // Default minor-to-major order
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (layoutOrder.size() > 0) {
-    // If a layout order is provided, we assume it specifies the order in
-    // which the dimensions are first accessed, and unspecified dimensions
-    // retain the minor-to-major order. For example, if order = [2, 1, 0] and
-    // layoutOrder = [0, 1], we need to shift `layoutOrder`
-    // by -1 (move them right). The resulting order will then be [1, 2, 0].
-    int rankDiff = layoutOrder.size() - shape.size();
-    auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
-    for (size_t i = 0; i < minRank; ++i)
-      order[i] = layoutOrder[i] - rankDiff;
-    assert(isPermutationOfIota(order) && "Invalid order");
-  }
-  return order;
-}
-
-SmallVector<Value>
-SharedMemoryObject::getStridesForShape(ArrayRef<int64_t> shape,
-                                       ArrayRef<unsigned> layoutOrder,
-                                       Location loc, RewriterBase &rewriter) {
-  SmallVector<Value> strides(shape.size());
-  auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
-  int64_t stride = 1;
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  for (auto idx : order) {
-    strides[idx] = b.i32_val(stride);
-    stride *= shape[idx];
-  }
-  return strides;
-}
-
 uint64_t
 SharedMemoryObject::getMaskSpanOffsets(triton::gpu::MemDescType srcTy) {
   auto ctx = srcTy.getContext();
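
The deleted `getStridesForShape` walks the dimensions in the given order, fastest-varying first, and assigns the running product of sizes as each dimension's stride (materialized as i32 constants via the op builder). The same arithmetic on plain integers, as a rough MLIR-free sketch; the function name is illustrative.

#include <cstdint>
#include <vector>

// Illustrative reconstruction of the arithmetic in the deleted helper:
// visit dimensions from fastest- to slowest-varying, assigning the running
// product of dimension sizes as each dimension's stride.
std::vector<int64_t> stridesForShape(const std::vector<int64_t> &shape,
                                     const std::vector<unsigned> &order) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned idx : order) { // order[0] is the fastest-varying dimension
    strides[idx] = stride;
    stride *= shape[idx];
  }
  return strides;
}

// Example: shape = [4, 8, 16] with order = [2, 1, 0] (last dimension fastest)
// yields strides [128, 16, 1].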
