Skip to content

Commit 5942691

Browse files
Merge commit '3523ab4b4a8116922f400ef93e2d71514236eac2'
2 parents 3796f3f + 3523ab4 commit 5942691

File tree

78 files changed

+4089
-2794
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+4089
-2794
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ dev-install: dev-install-requires dev-install-triton
8989

9090
.PHONY: golden-samples
9191
golden-samples: triton-opt
92-
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
92+
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
9393
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
9494
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
95-
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
95+
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \
9696
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
9797
-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir

bin/RegisterTritonDialects.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6868
mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
6969
mlir::triton::intel::registerTritonIntelRemoveMasks();
7070
mlir::triton::intel::registerTritonRaiseBlockPointer();
71-
mlir::triton::registerAllocateSharedMemoryPass();
72-
mlir::triton::registerTritonGPUGlobalScratchAllocationPass();
71+
mlir::triton::gpu::registerAllocateSharedMemoryPass();
72+
mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
73+
mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
7374
mlir::triton::registerConvertTritonGPUToLLVMPass();
7475
mlir::triton::registerConvertNVGPUToLLVMPass();
7576
mlir::registerLLVMDIScope();

include/triton/Analysis/Allocation.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,22 +191,19 @@ class Allocation {
191191
/// Virtual: triton.call
192192
enum class BufferKind { Explicit, Scratch, Virtual };
193193

194-
/// MT: thread-safe
195-
inline static std::atomic<BufferId> nextId = 0;
196-
197194
BufferKind kind;
198195
BufferId id;
196+
Operation *owner;
199197
size_t size;
200198
size_t alignment;
201199
size_t offset;
202200

203201
bool operator==(const BufferT &other) const { return id == other.id; }
204202
bool operator<(const BufferT &other) const { return id < other.id; }
205203

206-
BufferT() : BufferT(BufferKind::Explicit, 0) {}
207-
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
208-
size_t offset = 0)
209-
: kind(kind), id(nextId++), size(size), alignment(alignment),
204+
BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size,
205+
size_t alignment = 4, size_t offset = 0)
206+
: kind(kind), id(id), owner(owner), size(size), alignment(alignment),
210207
offset(offset) {}
211208

212209
size_t setOffsetAligned(size_t newOffset) {
@@ -226,14 +223,16 @@ class Allocation {
226223
private:
227224
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
228225
void addBuffer(KeyType &key, Args &&...args) {
229-
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
230-
bufferSet[buffer.id] = std::move(buffer);
226+
BufferId nextId = bufferIdCounter++;
227+
auto [it, inserted] = bufferSet.insert_or_assign(
228+
nextId, BufferT(Kind, nextId, key, std::forward<Args>(args)...));
229+
BufferT *buffer = &it->second;
231230
if constexpr (Kind == BufferT::BufferKind::Explicit) {
232-
valueBuffer[key] = &bufferSet[buffer.id];
231+
valueBuffer[key] = buffer;
233232
} else if constexpr (Kind == BufferT::BufferKind::Virtual) {
234-
opVirtual[key] = &bufferSet[buffer.id];
233+
opVirtual[key] = buffer;
235234
} else {
236-
opScratch[key] = &bufferSet[buffer.id];
235+
opScratch[key] = buffer;
237236
}
238237
}
239238

@@ -250,6 +249,8 @@ class Allocation {
250249
BufferSetT bufferSet;
251250
size_t sharedMemorySize = 0;
252251

252+
size_t bufferIdCounter = 0;
253+
253254
friend class triton::AllocationAnalysis;
254255
};
255256

include/triton/Analysis/Membar.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ struct BlockInfo {
9797
// Shared Memory Barrier Analysis
9898
//===----------------------------------------------------------------------===//
9999
class MembarAnalysis {
100+
using VirtualBlock = std::pair<Block *, Block::iterator>;
101+
100102
public:
101103
using FuncBlockInfoMapT = CallGraph<BlockInfo>::FuncDataMapT;
102104
/// Creates a new Membar analysis that generates the shared memory barrier
@@ -143,7 +145,8 @@ class MembarAnalysis {
143145
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
144146

145147
/// Collects the successors of the terminator
146-
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);
148+
void visitTerminator(Operation *operation,
149+
SmallVector<VirtualBlock> &successors);
147150

148151
void insertBarrier(Operation *operation, OpBuilder *builder);
149152

include/triton/Conversion/TritonGPUToLLVM/Passes.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H
22
#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H
33

4-
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
54
#include "mlir/Pass/Pass.h"
6-
#include "mlir/Transforms/DialectConversion.h"
75

86
#include <memory>
97

@@ -12,22 +10,15 @@ namespace mlir {
1210
class ModuleOp;
1311
template <typename T> class OperationPass;
1412

15-
namespace triton {
13+
namespace triton::gpu {
1614

1715
#define GEN_PASS_DECL
1816
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
1917

20-
namespace gpu {
21-
std::unique_ptr<OperationPass<ModuleOp>> createAllocateSharedMemoryPass();
22-
23-
std::unique_ptr<Pass> createTritonGPUGlobalScratchAllocationPass();
24-
25-
} // namespace gpu
26-
2718
#define GEN_PASS_REGISTRATION
2819
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
2920

30-
} // namespace triton
21+
} // namespace triton::gpu
3122

3223
} // namespace mlir
3324

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
include "mlir/Pass/PassBase.td"
55

66
def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> {
7-
let summary = "Add metadata for shared memory allocation";
8-
let description = [{
9-
This pass uses the `ModuleAllocation` analysis to:
10-
- Annotate modules with an attribute with the amount of shared/local
11-
memory used.
12-
- Annotate operations with an offset into the total shared/local memory.
13-
}];
14-
15-
let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()";
7+
let summary = "Add metadata for shared memory allocation";
8+
9+
let description = [{
10+
This pass uses the `ModuleAllocation` analysis to:
11+
- Annotate modules with an attribute with the amount of shared/local
12+
memory used.
13+
- Annotate operations with an offset into the total shared/local memory.
14+
}];
1615
}
1716

1817
def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> {
@@ -22,11 +21,25 @@ def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory
2221
Decide on global scratch space memory allocation and assign attributes to each allocation.
2322
}];
2423

25-
let constructor = "mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass()";
26-
2724
let dependentDialects = [
2825
"mlir::triton::gpu::TritonGPUDialect"
2926
];
3027
}
3128

29+
def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::ModuleOp"> {
30+
let summary = "Allocate warp groups";
31+
32+
let description = [{
33+
The `tritongpu-allocate-warp-groups` pass performs warpgroup allocation for
34+
a GPU program. When a GPU program contains warp specialization, additional
35+
warps are launched in addition to the "default" warp group. The "default"
36+
warpgroup executes top-level code in a `tt.func` and its size is specified
37+
by the user via the `num_warps` argument.
38+
39+
This pass analyzes `ttg.warp_specialize` ops in the program and determines
40+
the total number of needed warps, then attaches the range of warp IDs to
41+
each warpgroup function.
42+
}];
43+
}
44+
3245
#endif

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,15 +575,18 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
575575
// Hardware Indices
576576
// -----------------------------------------------------------------------
577577

578+
// If an operation is contained within a warp specialize region, this returns
579+
// the thread ID offset of that warpgroup.
580+
std::optional<int> getWarpGroupStartThreadId(Block *block);
581+
578582
// Returns CTA level thread ID.
579583
Value getThreadId(OpBuilder &rewriter, Location loc);
580584

581585
// Get the lane ID, which is index of the thread within its warp.
582-
Value getLaneId(OpBuilder &rewriter, Location loc, unsigned threadsPerWarp);
586+
Value getLaneId(OpBuilder &rewriter, Location loc);
583587

584588
// Get the lane ID and warp ID.
585-
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc,
586-
unsigned threadsPerWarp);
589+
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
587590

588591
// -----------------------------------------------------------------------
589592
// Shared memory utilities

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class SameLoadStoreOperandsAndResultEncoding
114114
}
115115
};
116116

117+
// This trait indicates that regions in the op may execute concurrently with
118+
// each other.
119+
template <typename ConcreteType>
120+
struct AsyncRegions : public TraitBase<ConcreteType, AsyncRegions> {};
121+
117122
} // namespace OpTrait
118123
} // namespace mlir
119124

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
1212
def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">;
1313
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
1414
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
15+
def AsyncRegions : NativeOpTrait<"AsyncRegions">;
1516

1617
// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
1718
// equivalence of the layouts of the result rather than just layout equality.

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
333333
}
334334

335335
def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
336-
RecursiveMemoryEffects, RecursivelySpeculatable,
336+
RecursiveMemoryEffects, RecursivelySpeculatable, AsyncRegions,
337337
DeclareOpInterfaceMethods<RegionBranchOpInterface>
338338
]> {
339339
let summary = "asynchronously execute code on multiple warpgroups";
@@ -362,21 +362,24 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
362362
}
363363
partition0(%arg0: i32, %arg1: i32) num_warps(8) {
364364
some_async_dispatch(%arg0, %arg1)
365+
ttg.warp_return
365366
}
366367
partition1(%arg0: i32, %arg1: i32) num_warps(1) {
367368
some_async_dispatch(%arg0, %arg1)
369+
ttg.warp_return
368370
} : (i32, i32) -> i32
369371
```
370372
}];
371373

372374
let arguments = (ins
373375
Variadic<AnyType>:$explicitCaptures,
374-
DenseI32ArrayAttr:$partitionNumWarps
376+
DenseI32ArrayAttr:$partitionNumWarps,
377+
OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds
375378
);
376379
let results = (outs Variadic<AnyType>:$defaultPassthrough);
377380

378381
let regions = (region
379-
SizedRegion<1>:$defaultRegion,
382+
MinSizedRegion<1>:$defaultRegion,
380383
SizedRegion<1>:$partitionOpHolder
381384
);
382385

@@ -390,20 +393,19 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
390393

391394
def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [
392395
IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable,
393-
Terminator, HasParent<"WarpSpecializeOp">,
394-
SingleBlockImplicitTerminator<"WarpReturnOp">
396+
Terminator, HasParent<"WarpSpecializeOp">
395397
]> {
396398
let summary = "container op for `ttg.warp_specialize`";
397399
let description = [{
398400
Because MLIR requires entire operations be isolated from above, this op
399401
contains the actual isolated from above regions of `ttg.warp_specialize`.
400402
}];
401403

402-
let regions = (region VariadicRegion<SizedRegion<1>>:$partitionRegions);
404+
let regions = (region VariadicRegion<MinSizedRegion<1>>:$partitionRegions);
403405
}
404406

405407
def TTG_WarpYieldOp : TTG_Op<"warp_yield", [
406-
Pure, Terminator, HasParent<"WarpSpecializeOp">,
408+
Pure, Terminator, ReturnLike, HasParent<"WarpSpecializeOp">,
407409
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>
408410
]> {
409411
let summary = "yield from the default region of `ttg.warp_specialize`";
@@ -422,6 +424,7 @@ def TTG_WarpYieldOp : TTG_Op<"warp_yield", [
422424
let arguments = (ins Variadic<AnyType>:$values);
423425

424426
let assemblyFormat = "($values^)? attr-dict (`:` type($values)^)?";
427+
let hasVerifier = 1;
425428
}
426429

427430
def TTG_WarpReturnOp : TTG_Op<"warp_return", [

0 commit comments

Comments
 (0)