Skip to content

Commit 40fd289

Browse files
Merge OpenAI Triton commit 75fe113 (#4390)
This PR changes the Triton base from 2d6fb76 to 75fe113 (May 23). Pass rate: 97.23% -> 97.23% Note: * didn't apply changes from triton-lang/triton@747e205#diff-f0e7e61833e691ba711378a49219d2c4392e7ca9aed419b2d26a47d4abc1e6da to `06-fused-attention.py` (seems related to #4283) * reverted the llvm pin update because it is incompatible with the llvm-spirv translator: e7d0e43. The issue tracking this: #4391
2 parents 99065a2 + e7d0e43 commit 40fd289

File tree

56 files changed

+920
-795
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+920
-795
lines changed

.github/workflows/test-backends.yml

Lines changed: 0 additions & 83 deletions
This file was deleted.

docs/python-api/triton.language.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Programming Model
1212
:nosignatures:
1313

1414
tensor
15+
tensor_descriptor
1516
program_id
1617
num_programs
1718

@@ -71,6 +72,9 @@ Memory/Pointer Ops
7172

7273
load
7374
store
75+
make_tensor_descriptor
76+
load_tensor_descriptor
77+
store_tensor_descriptor
7478
make_block_ptr
7579
advance
7680

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#include "mlir/Support/LLVM.h"
55
#include "llvm/ADT/ArrayRef.h"
66
#include "llvm/ADT/DenseMap.h"
7-
#include "llvm/ADT/GraphTraits.h"
8-
#include "llvm/ADT/MapVector.h"
97
#include "llvm/ADT/SmallVector.h"
108

119
namespace mlir {
@@ -26,43 +24,38 @@ static constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages";
2624
//===----------------------------------------------------------------------===//
2725

2826
namespace mlir::triton::gpu {
27+
// A partition has a stage and contains some operation. The stage of a
28+
// partition determines how many cycles the partition's outputs are buffered
29+
// relative to its consumers.
30+
class Partition {
31+
public:
32+
Partition(int idx, int stage) : idx(idx), stage(stage) {}
33+
34+
int getIndex() const { return idx; }
35+
int getStage() const { return stage; }
36+
ArrayRef<Operation *> getOps() const { return ops; }
37+
38+
private:
39+
void setIndex(int idx) { this->idx = idx; }
40+
friend class WarpSchedule;
41+
42+
// The partition number.
43+
int idx;
44+
// The stage of the partition.
45+
int stage;
46+
// The ops in the partition.
47+
SmallVector<Operation *> ops;
48+
};
49+
2950
// A warp schedule divides a loop into multiple partitions. Ops in a loop are
3051
// assigned at most one partition. A warp schedule represents asynchronous
3152
// execution of the loop body, where partitions may execute simultaneously.
3253
class WarpSchedule {
3354
static constexpr int kSentinel = -1;
3455

3556
public:
36-
// A partition has a stage and contains some operation. The stage of a
37-
// partition determines how many cycles the partition's outputs are buffered
38-
// relative to its consumers.
39-
class Partition {
40-
public:
41-
Partition(int idx, int stage) : idx(idx), stage(stage) {}
42-
43-
int getIndex() const { return idx; }
44-
int getStage() const { return stage; }
45-
ArrayRef<Operation *> getOps() const { return ops; }
46-
47-
void insert(Operation *op) { ops.push_back(op); }
48-
void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }
49-
50-
private:
51-
void setIndex(int idx) { this->idx = idx; }
52-
friend class WarpSchedule;
53-
54-
// The partition number.
55-
int idx;
56-
// The stage of the partition.
57-
int stage;
58-
// The ops in the partition.
59-
SmallVector<Operation *> ops;
60-
};
61-
6257
// Create a new partition with a stage.
6358
Partition *addPartition(unsigned stage);
64-
// Update the op to partition mapping.
65-
void updatePartitions();
6659

6760
// Get the partition the op belongs to.
6861
Partition *getPartition(Operation *op);

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,27 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
2626
];
2727
}
2828

29-
def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> {
30-
let summary = "test assigning latencies to interesting ops ahead of pipelining";
29+
def TritonGPUAssignLatencies : Pass<"tritongpu-assign-latencies", "mlir::ModuleOp"> {
30+
let summary = "assign latencies to interesting ops ahead of pipelining";
3131

3232
let description = [{
33-
This is a test pass that tests `assignLatencies` method of `TritonGPUPipeline`.
33+
The `tritongpu-assign-latencies` pass assigns latencies to latency ops based
34+
on the number of stages.
3435
}];
3536

36-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
37-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
38-
"mlir::scf::SCFDialect",
39-
"mlir::arith::ArithDialect"];
40-
4137
let options = [
42-
Option<"numStages", "num-stages",
43-
"int32_t", /*default*/"3",
38+
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
4439
"number of pipeline stages">
4540
];
4641
}
4742

48-
def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> {
49-
let summary = "test scheduling a loop for software pipelining";
43+
def TritonGPUScheduleLoops : Pass<"tritongpu-schedule-loops", "mlir::ModuleOp"> {
44+
let summary = "software pipeline loop scheduling";
5045

5146
let description = [{
52-
This is a test pass that tests `scheduleLoop` method of `TritonGPUPipeline`.
47+
The `tritongpu-schedule-loops` pass performs scheduling for loop pipelining
48+
for loops with latency ops.
5349
}];
54-
55-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
56-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
57-
"mlir::scf::SCFDialect",
58-
"mlir::arith::ArithDialect"];
5950
}
6051

6152
def TritonGPUHoistTMEMAlloc : Pass<"tritongpu-hoist-tmem-alloc", "mlir::ModuleOp"> {

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22-
static const char *kAssignedStageAttrName = "ttg.assigned_stage";
23-
static const char *kAssignedClusterAttrName = "ttg.assigned_cluster";
2422

2523
//===----------------------------------------------------------------------===//
2624
// Hoisting Utilities
@@ -133,42 +131,12 @@ int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);
133131

134132
// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
135133
// single buffer slice (leading dimension equal to 1), at the given index.
136-
template <typename TBuilder>
137134
TypedValue<triton::gpu::MemDescType>
138-
createSingleBufferView(TBuilder &builder, Value alloc, Value idx) {
139-
assert(isa<triton::gpu::MemDescType>(alloc.getType()) &&
140-
"Expected MemDescType");
141-
auto allocDescType = cast<triton::gpu::MemDescType>(alloc.getType());
142-
SmallVector<int64_t> shape;
143-
if (allocDescType.getShape().size() > 1) {
144-
shape.insert(shape.end(), allocDescType.getShape().begin() + 1,
145-
allocDescType.getShape().end());
146-
} else {
147-
shape.push_back(1);
148-
}
149-
auto viewDescType = triton::gpu::MemDescType::get(
150-
shape, allocDescType.getElementType(), allocDescType.getEncoding(),
151-
allocDescType.getMemorySpace(), allocDescType.getMutableMemory(),
152-
/*allocShape=*/allocDescType.getAllocShape());
153-
SmallVector<Value> idxs = {idx};
154-
if (allocDescType.getShape().size() > 1) {
155-
Value zero =
156-
builder.template create<arith::ConstantIntOp>(alloc.getLoc(), 0, 32);
157-
for (unsigned i = 1; i < allocDescType.getShape().size(); i++) {
158-
idxs.push_back(zero);
159-
}
160-
}
161-
return builder.template create<triton::gpu::MemDescSubviewOp>(
162-
alloc.getLoc(), viewDescType, alloc, idxs);
163-
}
164-
165-
template <typename TBuilder>
135+
createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
136+
// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
137+
// single buffer slice (leading dimension equal to 1), at the given index.
166138
TypedValue<triton::gpu::MemDescType>
167-
createSingleBufferView(TBuilder &builder, Value alloc, int idx) {
168-
return createSingleBufferView(
169-
builder, alloc,
170-
builder.template create<arith::ConstantIntOp>(alloc.getLoc(), idx, 32));
171-
}
139+
createSingleBufferView(OpBuilder &builder, Value alloc, int idx);
172140

173141
} // namespace triton
174142
} // namespace mlir

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@ namespace triton {
1313

1414
namespace gpu {
1515

16-
/// Discover operations that should become async and assign latencies to them
17-
/// based on the numStages value provided by the user.
18-
void assignLatencies(ModuleOp moduleOp, int numStages);
19-
20-
/// Schedule the loops based on the latencies assigned to the operations.
21-
void scheduleLoops(ModuleOp moduleOp);
22-
2316
/// Lower the loops to prepare them for pipeline expansion.
2417
void lowerLoops(ModuleOp moduleOp);
2518

@@ -115,6 +108,10 @@ class CoarseSchedule {
115108
bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
116109
bool includeArg, bool insertIfEarlier = false);
117110

111+
// Remove empty stages and clusters from the schedule, adjusting the maximum
112+
// number of stages as appropriate.
113+
void shrinkToFit();
114+
118115
void erase(Operation *op) { opToStageAndCluster.erase(op); }
119116

120117
int count(Operation *op) { return opToStageAndCluster.count(op); }

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -247,47 +247,25 @@ SetVector<Value> getNestedOperands(Operation *op);
247247
// Erase the given loop carried values from the loop, where `loop` is replaced
248248
// with a new loop.
249249
void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
250+
251+
// Get a boolean if the Value is an arith::ConstantOp
252+
std::optional<bool> getBoolFromConstant(Value cst);
250253
} // namespace mlir
251254

252255
namespace mlir::triton {
253-
254256
/// Replace all uses of `oldUse` with `val` and propagate the type if needed.
255257
/// This is useful when we need to change a memory descriptor from immutable to
256258
/// mutable.
257259
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
258260
Value val);
259261

260-
template <typename BuilderT>
262+
/// Replace all uses of `old` with a local load from `alloc` unless the use is a
263+
/// `ttg.local_alloc` with a matching shared encoding, in which case the shared
264+
/// memory is forwarded directly into the use.
261265
void replaceUsesWithLocalLoad(
262-
BuilderT &builder, OpResult old, TypedValue<triton::gpu::MemDescType> alloc,
263-
TypedValue<triton::gpu::AsyncTokenType> token = {}) {
264-
// Remove redundant local_load -> local_alloc
265-
namespace ttg = triton::gpu;
266-
using triton::gpu::LocalAllocOp;
267-
auto allocTy = alloc.getType();
268-
SmallVector<LocalAllocOp> allocsToErase;
269-
for (Operation *user : old.getUsers()) {
270-
if (auto userAlloc = dyn_cast<LocalAllocOp>(user)) {
271-
if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) {
272-
replaceUsesAndPropagateType(builder, userAlloc, alloc);
273-
allocsToErase.push_back(userAlloc);
274-
}
275-
}
276-
}
277-
278-
// If there are some uses that were not local_allocs, we need to create a
279-
// local_load for them.
280-
if (std::distance(old.getUsers().begin(), old.getUsers().end()) >
281-
allocsToErase.size()) {
282-
auto loc = old.getOwner()->getLoc();
283-
auto sharedLoad = builder.template create<ttg::LocalLoadOp>(
284-
loc, old.getType(), alloc, token);
285-
old.replaceAllUsesWith(sharedLoad.getResult());
286-
}
287-
for (auto alloc : allocsToErase) {
288-
alloc.erase();
289-
}
290-
}
266+
OpBuilder &builder, OpResult old,
267+
TypedValue<triton::gpu::MemDescType> alloc,
268+
TypedValue<triton::gpu::AsyncTokenType> token = {});
291269
} // namespace mlir::triton
292270

293271
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ namespace scf {
88
class ForOp;
99
} // namespace scf
1010
namespace triton::gpu {
11-
// Identify load-mma dependencies and specialize them to different partitions.
12-
LogicalResult specializeLoadMMADependencies(scf::ForOp &loop,
13-
int defaultNumStages);
1411
// This is the final step to prepare a loop for warp specialization. This takes
1512
// a loop with a partition schedule and rewrites the loop such that all SSA
1613
// dependencies between partitions are passed through shared memory and

0 commit comments

Comments
 (0)