Commit bf46a53

Merge OpenAI Triton commit 882a02e (#3776)

This PR changes the Triton base from a39389a to 882a02e (Mar 27). Pass rate: 89.99%.

2 parents fed3b87 + 6c196eb; commit bf46a53

File tree: 58 files changed, +2025 −795 lines


.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ jobs:
       - name: Detect if build deps (e.g. LLVM hash) changed
         id: detect-change
         if: github.event_name == 'push'
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
         with:
           files: |
             cmake/*.txt

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ jobs:
       - name: Detect if build deps (e.g. LLVM hash) changed
         id: detect-change
         if: github.event_name == 'push'
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v46
         with:
           files: |
             cmake/*.txt

README.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 | **`Documentation`** | **`Nightly Wheels`** |
 |-------------------- | -------------------- |
-| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
+| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
 
 # Triton

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 74 additions & 9 deletions
@@ -1,40 +1,105 @@
 #ifndef TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_
 #define TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_
 
-#include <functional>
-#include <optional>
-#include <tuple>
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
 namespace mlir {
+
 class OpBuilder;
-class Operation;
+class DominanceInfo;
 
 namespace scf {
 class ForOp;
-}
+} // namespace scf
 namespace triton::nvidia_gpu {
-class MMAv5OpInterface;
-class TMEMAllocOp;
-class TMEMLoadOp;
+
+//===----------------------------------------------------------------------===//
+// MMAInfo
+//===----------------------------------------------------------------------===//
+
+// This struct contains analysis information about an MMAv5 operation inside a
+// loop used for pipelining MMA ops.
+struct MMAInfo {
+  // This struct contains information about when the MMA's accumulator is
+  // overridden in the loop, if it is at all.
+  struct AccOverridePoint {
+    // The operation which overrides the accumulator.
+    Operation *op;
+    // The condition on which the accumulator is reset.
+    Value condition = nullptr;
+    // The initial value of the accumulator and the value after a reset.
+    Value initValue = nullptr;
+    // The number of loop iterations ago the accumulator was reset.
+    int distance = 0;
+    // Whether the accumulator is reset via setting the `useAcc` flag to false
+    // or by clearing the accumulator tensor value.
+    bool isFlag = false;
+  };
+
+  // The TMEM allocation of the accumulator, which directly precedes the dot op.
+  TMEMAllocOp accAlloc;
+  // The TMEM load of the accumulator value out of TMEM, which directly follows
+  // the dot op.
+  TMEMLoadOp accLoad;
+  // The override point of the accumulator value, if it is overridden in the
+  // loop. E.g. this is typically present for persistent kernels.
+  std::optional<AccOverridePoint> accDef;
+  // If the accumulator is used in future iterations of the loop, this is the
+  // iter arg number.
+  std::optional<int> yieldArgNo;
+  // Whether the accumulator needs to be multibuffered.
+  bool accIsMultiBuffered;
+
+  Value phase = nullptr;
+  Value barrierIdx = nullptr;
+  Value accInsertIdx = nullptr;
+  Value accExtractIdx = nullptr;
+  Value barrierAlloc = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
+// MMA Pipeline Analysis
+//===----------------------------------------------------------------------===//
 
 // Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the
 // accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must
 // be in the same region as the MMA operation.
 std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
 getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
+// Get immediate users of the accumulator within the current loop iteration.
+SmallVector<Operation *> getDirectAccUses(TMEMLoadOp accDef);
+// Analyze an MMA op inside a loop to determine information about how it can be
+// pipelined. Returns `std::nullopt` if it cannot be pipelined.
+std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
+                                  DominanceInfo &domInfo);
+
+//===----------------------------------------------------------------------===//
+// MMA Pipeline Rewriters
+//===----------------------------------------------------------------------===//
+
 // Create a new TMEMAllocOp to use for the pipelined MMA operation. It is
 // optionally multi-buffered based on the number of stages.
 TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
                             bool multiBufferred, int numStages);
 
+// Create a store op of the initial value of the accumulator into the
+// potentially multi-buffered accumulator.
+void createInitStore(OpBuilder &builder, TMEMAllocOp allocOp, Value initVal,
+                     bool multiBufferred);
+
 // Return true if operands of the MMA operation are/are going to be pipelined
 // and multibuffered, enabling the MMA operation to be pipelined.
 bool mmaHasPipelineableOperands(
     MMAv5OpInterface mma, scf::ForOp forOp,
     std::function<bool(Operation *)> isLoadPipelineable);
 
-// Return true if the loop has a read-modify-write access to the accumulator.
+// Return true if the accumulator of an mma in subsequent iterations is either
+// independent from the previous iteration (overwritten) or completely reused,
+// without read-modify-write.
+// Otherwise, we can not pipeline the MMA, as we need to insert a wait after the
+// mma to read back the accumulator for RMW.
 bool hasAccReadModifyWrite(MMAv5OpInterface mma, scf::ForOp forOp);
+
 } // namespace triton::nvidia_gpu
 } // namespace mlir
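Taken together, the analysis and rewrite entry points suggest a pipelining flow along the following lines. This is a minimal sketch, not code from the commit: the driver function and the `numStages` plumbing are assumptions; only the helper signatures come from the header above.

// Hypothetical driver composing the helpers declared in
// MMAv5PipelineUtility.h; the scaffolding around them is illustrative.
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"

using namespace mlir;
using namespace mlir::triton::nvidia_gpu;

static void tryPipelineMMA(scf::ForOp forOp, MMAv5OpInterface mmaOp,
                           DominanceInfo &domInfo, int numStages) {
  // Analyze the MMA op inside the loop; bail out if it is not pipelineable.
  std::optional<MMAInfo> info = getMMAInfo(forOp, mmaOp, domInfo);
  if (!info)
    return;
  // A read-modify-write of the accumulator would force a wait right after
  // the MMA, defeating the pipeline.
  if (hasAccReadModifyWrite(mmaOp, forOp))
    return;

  OpBuilder builder(info->accAlloc);
  // Replace the accumulator allocation, multi-buffering it if needed.
  TMEMAllocOp newAlloc = createTMemAlloc(builder, info->accAlloc,
                                         info->accIsMultiBuffered, numStages);
  // Seed the (possibly multi-buffered) accumulator with its initial value.
  if (info->accDef && info->accDef->initValue)
    createInitStore(builder, newAlloc, info->accDef->initValue,
                    info->accIsMultiBuffered);
  // ...rewrite the loop body to issue the MMA asynchronously, rotating
  // phase/barrierIdx and accInsertIdx/accExtractIdx across iterations.
}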

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 5 additions & 0 deletions
@@ -47,6 +47,9 @@ class WarpSchedule {
   void insert(Operation *op) { ops.push_back(op); }
 
 private:
+  void setIndex(int idx) { this->idx = idx; }
+  friend class WarpSchedule;
+
   // The partition number.
   int idx;
   // The stage of the partition.
@@ -57,6 +60,8 @@ class WarpSchedule {
 
   // Create a new partition with a stage.
   Partition *addPartition(unsigned stage);
+  // Give each partition a new index and order. The indices must be unique.
+  void reorderPartitions(ArrayRef<unsigned> order);
 
   // Get the partition the op belongs to.
   Partition *getPartition(Operation *op);
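For intuition, the new `reorderPartitions` declaration, together with the private `setIndex` and the friend declaration, suggests an implementation shaped roughly like the sketch below. The `partitions` container is a guess at the surrounding class state, not something visible in this diff.

// Hypothetical shape of the new method; `partitions` is an assumed member.
void WarpSchedule::reorderPartitions(ArrayRef<unsigned> order) {
  assert(order.size() == partitions.size() &&
         "need one unique index per partition");
  // The friend declaration lets WarpSchedule call Partition::setIndex.
  for (auto [partition, newIdx] : llvm::zip(partitions, order))
    partition->setIndex(newIdx);
}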

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 15 deletions
@@ -14,6 +14,7 @@ namespace triton {
 static const char *kNumStagesAttrName = "tt.num_stages";
 static const char *kDisallowAccMultiBufferAttrName =
     "tt.disallow_acc_multi_buffer";
+static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
@@ -38,17 +39,6 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 // `tt.disallow_acc_multi_buffer` set to true.
 bool getDisallowAccMultiBuffer(scf::ForOp forOp);
 
-/// Visit the operands of `op` and the operands of any nested ops defined
-/// outside of `op`.
-void visitNestedOperands(Operation *op,
-                         function_ref<void(OpOperand &)> visitor);
-/// Visit the operands of `op` and the operands of any nested ops defined
-/// outside of `op`.
-void visitNestedOperands(Operation *op, function_ref<void(Value)> visitor);
-/// Get the operands of `op` and the operands of any nested ops defined outside
-/// of `op`.
-SetVector<Value> getNestedOperands(Operation *op);
-
 // Return the definition of the given value. If the value is a loop-carried
 // dependency, return the definition and the distance to it.
 std::pair<OpResult, int64_t> getDefinitionAndDistance(scf::ForOp forOp,
@@ -90,10 +80,6 @@ gpu::SharedEncodingTrait getSharedEncoding(RankedTensorType ty);
 // Get a shared encoding for a tensor based on its uses.
 gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp);
 
-// Erase the given loop carried values from the loop, where `loop` is replaced
-// with a new loop.
-void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
-
 // Get the number of stages to pipeline the loop with, if it is explicitly
 // specified.
 int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);
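The new `kWarpSpecializeAttrName` constant names a unit attribute placed on loops. A minimal sketch of tagging and querying it follows; the two helper functions are illustrative, not part of the commit, while `setAttr`/`hasAttr` are standard MLIR Operation APIs.

// Hypothetical helpers around the new attribute name; only
// kWarpSpecializeAttrName comes from this header.
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"

static void markWarpSpecialized(mlir::scf::ForOp forOp, mlir::OpBuilder &b) {
  // Tag the loop with a unit attribute requesting warp specialization.
  forOp->setAttr(mlir::triton::kWarpSpecializeAttrName, b.getUnitAttr());
}

static bool isWarpSpecialized(mlir::scf::ForOp forOp) {
  // Later passes can key their behavior off the same attribute name.
  return forOp->hasAttr(mlir::triton::kWarpSpecializeAttrName);
}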

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 24 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include <numeric>
 
 namespace mlir {
+class DominanceInfo;
 
 namespace triton {
 class ModuleAxisInfoAnalysis;
@@ -135,6 +136,8 @@ scf::ForOp replaceForOpWithNewSignature(
     SmallVectorImpl<std::tuple<Value, Value>> &replacements);
 scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
                                         ValueRange newIterOperands);
+Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
+                                          ValueRange newIterOperands);
 
 // Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not
 // updated and needs to be updated separately for the loop to be correct.
@@ -213,6 +216,27 @@ triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
 SmallVector<Operation *>
 getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
                                   SmallVector<Operation *> &mmaOps);
+
+// Given a list of ops, find the nearest common dominator of all ops or return
+// null if one could not be found. The ops are allowed to be in different
+// regions. The result op is not necessarily one of the ops in the list.
+Operation *findNearestCommonDominator(ArrayRef<Operation *> ops,
+                                      DominanceInfo &domInfo);
+
+/// Visit the operands of `op` and the operands of any nested ops defined
+/// outside of `op`.
+void visitNestedOperands(Operation *op,
+                         function_ref<void(OpOperand &)> visitor);
+/// Visit the operands of `op` and the operands of any nested ops defined
+/// outside of `op`.
+void visitNestedOperands(Operation *op, function_ref<void(Value)> visitor);
+/// Get the operands of `op` and the operands of any nested ops defined outside
+/// of `op`.
+SetVector<Value> getNestedOperands(Operation *op);
+
+// Erase the given loop carried values from the loop, where `loop` is replaced
+// with a new loop.
+void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
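As a usage sketch of the relocated helpers: the two functions below are hypothetical, but they call only declarations visible in the header above.

// Hypothetical: pick one insertion point dominating a set of MMA ops and
// count the values an op captures from enclosing regions.
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "mlir/IR/Dominance.h"

static mlir::Operation *
getCommonInsertPoint(llvm::ArrayRef<mlir::Operation *> mmaOps,
                     mlir::DominanceInfo &domInfo) {
  // May return an op in an enclosing region; null if no dominator exists.
  return mlir::findNearestCommonDominator(mmaOps, domInfo);
}

static unsigned countCapturedValues(mlir::Operation *op) {
  // Operands of `op` and of its nested ops whose values are defined outside.
  return mlir::getNestedOperands(op).size();
}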

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -850,7 +850,7 @@ LogicalResult WarpYieldOp::verify() {
 static size_t getSharedMemorySize(Type type) {
   if (isa<IntegerType, FloatType>(type))
     return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8);
-  if (isa<PointerType>(type))
+  if (isa<PointerType, TensorDescType>(type))
     return 8;
   if (auto desc = dyn_cast<MemDescType>(type)) {
     if (!isa<SharedMemorySpaceAttr>(desc.getMemorySpace()))

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ add_triton_library(TritonGPUTransforms
   OptimizeThreadLocality.cpp
   Pipeliner/AssignLatencies.cpp
   Pipeliner/LowerLoops.cpp
+  Pipeliner/MMAv5PipelineUtility.cpp
   Pipeliner/ScheduleLoops.cpp
   Pipeliner/WGMMAPipeline.cpp
   Pipeliner/PipelineExpander.cpp

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 7 additions & 0 deletions
@@ -900,6 +900,13 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
     epilogueIf.erase();
   }
 
+  // Propagate warp specialization flags.
+  if (outer->hasAttr(kWarpSpecializeAttrName) ||
+      llvm::any_of(innerLoops, [](scf::ForOp loop) {
+        return loop->hasAttr(kWarpSpecializeAttrName);
+      }))
+    fused->setAttr(kWarpSpecializeAttrName, b.getUnitAttr());
+
   // Propagate the `tt.disallow_acc_multi_buffer` attribute to the parent loop.
   bool disallowAccMultiBuffer = getDisallowAccMultiBuffer(outer);
   for (scf::ForOp loop : innerLoops) {
