Skip to content

Commit 7355662

Browse files
Merge OpenAI Triton commit 3523ab4 (#3569)
This PR changes the Triton base from 0a8e3cc to 3523ab4 (Feb 25). Pass rate: 89.74%
2 parents cd4f49b + 602f559 commit 7355662

File tree

87 files changed

+4167
-2797
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+4167
-2797
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ dev-install: dev-install-requires dev-install-triton
8989

9090
.PHONY: golden-samples
9191
golden-samples: triton-opt
92-
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
92+
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
9393
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
9494
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
95-
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
95+
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \
9696
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
9797
-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir

bin/RegisterTritonDialects.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6868
mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
6969
mlir::triton::intel::registerTritonIntelRemoveMasks();
7070
mlir::triton::intel::registerTritonRaiseBlockPointer();
71-
mlir::triton::registerAllocateSharedMemoryPass();
72-
mlir::triton::registerTritonGPUGlobalScratchAllocationPass();
71+
mlir::triton::gpu::registerAllocateSharedMemoryPass();
72+
mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
73+
mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
7374
mlir::triton::registerConvertTritonGPUToLLVMPass();
7475
mlir::triton::registerConvertNVGPUToLLVMPass();
7576
mlir::registerLLVMDIScope();

include/triton/Analysis/Allocation.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,22 +191,19 @@ class Allocation {
191191
/// Virtual: triton.call
192192
enum class BufferKind { Explicit, Scratch, Virtual };
193193

194-
/// MT: thread-safe
195-
inline static std::atomic<BufferId> nextId = 0;
196-
197194
BufferKind kind;
198195
BufferId id;
196+
Operation *owner;
199197
size_t size;
200198
size_t alignment;
201199
size_t offset;
202200

203201
bool operator==(const BufferT &other) const { return id == other.id; }
204202
bool operator<(const BufferT &other) const { return id < other.id; }
205203

206-
BufferT() : BufferT(BufferKind::Explicit, 0) {}
207-
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
208-
size_t offset = 0)
209-
: kind(kind), id(nextId++), size(size), alignment(alignment),
204+
BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size,
205+
size_t alignment = 4, size_t offset = 0)
206+
: kind(kind), id(id), owner(owner), size(size), alignment(alignment),
210207
offset(offset) {}
211208

212209
size_t setOffsetAligned(size_t newOffset) {
@@ -226,14 +223,16 @@ class Allocation {
226223
private:
227224
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
228225
void addBuffer(KeyType &key, Args &&...args) {
229-
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
230-
bufferSet[buffer.id] = std::move(buffer);
226+
BufferId nextId = bufferIdCounter++;
227+
auto [it, inserted] = bufferSet.insert_or_assign(
228+
nextId, BufferT(Kind, nextId, key, std::forward<Args>(args)...));
229+
BufferT *buffer = &it->second;
231230
if constexpr (Kind == BufferT::BufferKind::Explicit) {
232-
valueBuffer[key] = &bufferSet[buffer.id];
231+
valueBuffer[key] = buffer;
233232
} else if constexpr (Kind == BufferT::BufferKind::Virtual) {
234-
opVirtual[key] = &bufferSet[buffer.id];
233+
opVirtual[key] = buffer;
235234
} else {
236-
opScratch[key] = &bufferSet[buffer.id];
235+
opScratch[key] = buffer;
237236
}
238237
}
239238

@@ -250,6 +249,8 @@ class Allocation {
250249
BufferSetT bufferSet;
251250
size_t sharedMemorySize = 0;
252251

252+
size_t bufferIdCounter = 0;
253+
253254
friend class triton::AllocationAnalysis;
254255
};
255256

include/triton/Analysis/Membar.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ struct BlockInfo {
9797
// Shared Memory Barrier Analysis
9898
//===----------------------------------------------------------------------===//
9999
class MembarAnalysis {
100+
using VirtualBlock = std::pair<Block *, Block::iterator>;
101+
100102
public:
101103
using FuncBlockInfoMapT = CallGraph<BlockInfo>::FuncDataMapT;
102104
/// Creates a new Membar analysis that generates the shared memory barrier
@@ -143,7 +145,8 @@ class MembarAnalysis {
143145
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
144146

145147
/// Collects the successors of the terminator
146-
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);
148+
void visitTerminator(Operation *operation,
149+
SmallVector<VirtualBlock> &successors);
147150

148151
void insertBarrier(Operation *operation, OpBuilder *builder);
149152

include/triton/Conversion/TritonGPUToLLVM/Passes.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H
22
#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H
33

4-
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
54
#include "mlir/Pass/Pass.h"
6-
#include "mlir/Transforms/DialectConversion.h"
75

86
#include <memory>
97

@@ -12,22 +10,15 @@ namespace mlir {
1210
class ModuleOp;
1311
template <typename T> class OperationPass;
1412

15-
namespace triton {
13+
namespace triton::gpu {
1614

1715
#define GEN_PASS_DECL
1816
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
1917

20-
namespace gpu {
21-
std::unique_ptr<OperationPass<ModuleOp>> createAllocateSharedMemoryPass();
22-
23-
std::unique_ptr<Pass> createTritonGPUGlobalScratchAllocationPass();
24-
25-
} // namespace gpu
26-
2718
#define GEN_PASS_REGISTRATION
2819
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
2920

30-
} // namespace triton
21+
} // namespace triton::gpu
3122

3223
} // namespace mlir
3324

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
include "mlir/Pass/PassBase.td"
55

66
def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> {
7-
let summary = "Add metadata for shared memory allocation";
8-
let description = [{
9-
This pass uses the `ModuleAllocation` analysis to:
10-
- Annotate modules with an attribute with the amount of shared/local
11-
memory used.
12-
- Annotate operations with an offset into the total shared/local memory.
13-
}];
14-
15-
let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()";
7+
let summary = "Add metadata for shared memory allocation";
8+
9+
let description = [{
10+
This pass uses the `ModuleAllocation` analysis to:
11+
- Annotate modules with an attribute with the amount of shared/local
12+
memory used.
13+
- Annotate operations with an offset into the total shared/local memory.
14+
}];
1615
}
1716

1817
def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> {
@@ -22,11 +21,25 @@ def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory
2221
Decide on global scratch space memory allocation and assign attributes to each allocation.
2322
}];
2423

25-
let constructor = "mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass()";
26-
2724
let dependentDialects = [
2825
"mlir::triton::gpu::TritonGPUDialect"
2926
];
3027
}
3128

29+
def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::ModuleOp"> {
30+
let summary = "Allocate warp groups";
31+
32+
let description = [{
33+
The `tritongpu-allocate-warp-groups` pass performs warpgroup allocation for
34+
a GPU program. When a GPU program contains warp specialization, additional
35+
warps are launched in addition to the "default" warp group. The "default"
36+
warpgroup executes top-level code in a `tt.func` and its size is specified
37+
by the user via the `num_warps` argument.
38+
39+
This pass analyzes `ttg.warp_specialize` ops in the program and determines
40+
the total number of needed warps, then attaches the range of warp IDs to
41+
each warpgroup function.
42+
}];
43+
}
44+
3245
#endif

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class TargetInfoBase {
8989

9090
virtual int getSharedAddressSpace() const = 0;
9191

92+
virtual int getAddressSpace(Attribute addressSpace) const = 0;
93+
9294
virtual bool supportVectorizedAtomics() const = 0;
9395

9496
// Helper used by targets to annotate store operations during lowering to

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,15 +575,18 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
575575
// Hardware Indices
576576
// -----------------------------------------------------------------------
577577

578+
// If an operation is contained within a warp specialize region, this returns
579+
// the thread ID offset of that warpgroup.
580+
std::optional<int> getWarpGroupStartThreadId(Block *block);
581+
578582
// Returns CTA level thread ID.
579583
Value getThreadId(OpBuilder &rewriter, Location loc);
580584

581585
// Get the lane ID, which is index of the thread within its warp.
582-
Value getLaneId(OpBuilder &rewriter, Location loc, unsigned threadsPerWarp);
586+
Value getLaneId(OpBuilder &rewriter, Location loc);
583587

584588
// Get the lane ID and warp ID.
585-
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc,
586-
unsigned threadsPerWarp);
589+
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
587590

588591
// -----------------------------------------------------------------------
589592
// Shared memory utilities

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class SameLoadStoreOperandsAndResultEncoding
114114
}
115115
};
116116

117+
// This trait indicates that regions in the op may execute concurrently with
118+
// each other.
119+
template <typename ConcreteType>
120+
struct AsyncRegions : public TraitBase<ConcreteType, AsyncRegions> {};
121+
117122
} // namespace OpTrait
118123
} // namespace mlir
119124

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
1212
def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">;
1313
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
1414
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
15+
def AsyncRegions : NativeOpTrait<"AsyncRegions">;
1516

1617
// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
1718
// equivalence of the layouts of the result rather than just layout equality.

0 commit comments

Comments
 (0)