
Commit bed2f70

Merge branch 'main' into feature/benchmarks-subset
2 parents d10efe0 + 98dca47 commit bed2f70

211 files changed: +10689 -3789 lines changed


.pre-commit-config.yaml
Lines changed: 6 additions & 8 deletions

@@ -1,6 +1,7 @@
+default_stages: [pre-commit, pre-push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v5.0.0
     hooks:
       - id: check-symlinks
       - id: destroyed-symlinks
@@ -17,12 +18,11 @@ repos:
       - id: debug-statements

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.3
+    rev: v0.7.1
     hooks:
       - id: ruff
         files: '^python/.*'
-        args: ["--fix", "--line-length", "120"]
-        stages: [pre-commit, pre-push, manual]
+        args: ["--fix", "--exit-non-zero-on-fix"]
         exclude: |
           (?x)(
             ^python/triton/runtime/.*|
@@ -31,18 +31,16 @@ repos:
           )

   - repo: https://github.com/google/yapf
-    rev: be72557
+    rev: "7e21823"
     hooks:
       - id: yapf
         args: ["-p", "-i"]
-        stages: [pre-commit, pre-push, manual]
         exclude: "python/test/unit/language/test_line_info.py"

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v19.1.2
     hooks:
       - id: clang-format
-        stages: [pre-commit, pre-push, manual]

   # Expand YAML anchors in files used by github workflows, because github can't
   # do this itself. This lets us use anchors, which avoids code duplication.

CMakeLists.txt
Lines changed: 18 additions & 1 deletion

@@ -12,7 +12,7 @@ set(CMAKE_CXX_STANDARD 17)

 set(CMAKE_INCLUDE_CURRENT_DIR ON)

-project(triton)
+project(triton CXX)
 include(CTest)

 if(NOT WIN32)
@@ -26,8 +26,25 @@ option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
 option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
 option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
+option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
 set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

+if(TRITON_BUILD_WITH_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "C compiler launcher")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "CXX compiler launcher")
+  else()
+    message(
+      STATUS
+        "Could not find ccache. Consider installing ccache to speed up compilation."
+    )
+  endif()
+endif()
+
+
 # Ensure Python3 vars are set correctly
 # used conditionally in this file and by lit tests
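A note on the new build option: TRITON_BUILD_WITH_CCACHE defaults to ON but only sets the compiler launchers when a ccache binary is actually found, so configurations without ccache behave as before apart from the status message. It can also be disabled explicitly at configure time with the standard -DTRITON_BUILD_WITH_CCACHE=OFF cache flag.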

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
Lines changed: 3 additions & 2 deletions

@@ -78,10 +78,11 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     start_m = tl.program_id(2)
     off_z = tl.program_id(0)
     off_h = tl.program_id(1)
+    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
     if N_CTX <= 512:
         start_m = tl.program_id(0)
         off_z = tl.program_id(2)
-        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+        qvk_offset = off_z.to(tl.int64) * stride_qh

     # block pointers
     Q_block_ptr = tl.make_block_ptr(
@@ -181,7 +182,7 @@ def forward(q, k, v, causal, sm_scale):
     grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
     n_ctx = q.shape[2]
     if n_ctx <= 512:
-        grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
+        grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), 1, q.shape[0] * q.shape[1])
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

     if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
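To see why the short-context kernel can drop the separate off_h term, here is a minimal hedged sketch of the new launch mapping; the helper name is illustrative and it assumes contiguous (Z, H, N_CTX, D) inputs so that stride_qz == H * stride_qh.

# Hedged sketch of the N_CTX <= 512 launch used above; not part of the benchmark.
import triton

def short_ctx_grid(Z, H, N_CTX, BLOCK_M):
    # axis 0: query block, axis 1: unused, axis 2: flattened batch*head index
    return (triton.cdiv(N_CTX, BLOCK_M), 1, Z * H)

# Inside the kernel, program_id(2) == z * H + h on this path, hence
#   program_id(2) * stride_qh == z * (H * stride_qh) + h * stride_qh
#                             == z * stride_qz       + h * stride_qh,
# which matches the offset the long-context path computes from off_z and off_h.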

benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Lines changed: 3 additions & 3 deletions

@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
-    bid = tl.program_id(axis=0)
-    pid = tl.program_id(axis=1)
+    bid = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     B = a.shape[0]
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (
-        B,
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        B,
     )
     matmul_kernel_with_block_pointers_batched[grid](
         a, b, c, #
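The two hunks above have to move together, since the grid tuple and the tl.program_id axes are positional. A minimal hedged sketch of the new ordering (the helper name is illustrative, not part of the benchmark):

import triton

def batched_gemm_grid(B, M, N, META):
    # axis 0: tile id within one matrix (read as pid in the kernel)
    # axis 1: batch id (read as bid in the kernel)
    tiles = triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N'])
    return (tiles, B)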

bin/RegisterTritonDialects.h
Lines changed: 2 additions & 0 deletions

@@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();
+  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
+  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

cmake/llvm-hash.txt
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-b5cc222d7429fe6f18c787f633d5262fac2e676f
+fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73

docs/conf.py
Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ def documenter(app, obj, parent):
 autosummary_generate = True

 # versioning config
-smv_tag_whitelist = r'^(v3.1.0)$'
+smv_tag_whitelist = r'^(v3.2.0)$'
 smv_branch_whitelist = r'^main$'
 smv_remote_whitelist = None
 smv_released_pattern = r'^tags/.*$'

docs/python-api/triton-semantics.rst
Lines changed: 1 addition & 3 deletions

@@ -14,9 +14,7 @@ The algorithm is as follows:

 2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``

-3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32``
-
-3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.
+3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16``

 4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``
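To make the revised rule 3 concrete, here is a minimal hedged sketch of a kernel exercising it; it relies on the standard tl.arange / tl.full / tl.static_assert builtins, and the kernel and pointer names are illustrative only.

import triton
import triton.language as tl

@triton.jit
def promotion_rule3(out_ptr):
    offs = tl.arange(0, 16)
    a = tl.full((16,), 1.0, dtype=tl.float16)
    b = tl.full((16,), 2.0, dtype=tl.bfloat16)
    c = a + b                                # rule 3: (float16, bfloat16) -> float16
    tl.static_assert(c.dtype == tl.float16)  # expected to hold under the new rule
    tl.store(out_ptr + offs, c)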

docs/python-api/triton.language.rst
Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@ Linear Algebra Ops
     :nosignatures:

     dot
+    dot_scaled


 Memory/Pointer Ops

include/triton/Analysis/Allocation.h
Lines changed: 59 additions & 49 deletions

@@ -93,6 +93,45 @@ class Allocation {
   using BufferIdSetT = DenseSet<BufferId>;
   using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;

+  /// A class that represents a shared memory buffer
+  struct BufferT {
+    /// Explicit: triton_gpu.local_alloc
+    /// Scratch: triton_gpu.convert_layout
+    /// Virtual: triton.call
+    enum class BufferKind { Explicit, Scratch, Virtual };
+
+    /// MT: thread-safe
+    inline static std::atomic<BufferId> nextId = 0;
+
+    BufferKind kind;
+    BufferId id;
+    size_t size;
+    size_t alignment;
+    size_t offset;
+
+    bool operator==(const BufferT &other) const { return id == other.id; }
+    bool operator<(const BufferT &other) const { return id < other.id; }
+
+    BufferT() : BufferT(BufferKind::Explicit, 0) {}
+    BufferT(BufferKind kind, size_t size, size_t alignment = 4,
+            size_t offset = 0)
+        : kind(kind), id(nextId++), size(size), alignment(alignment),
+          offset(offset) {}
+
+    size_t setOffsetAligned(size_t newOffset) {
+      return offset = llvm::alignTo(newOffset, alignment);
+    }
+  };
+
+  /// Op -> Scratch Buffer
+  using OpScratchMapT = DenseMap<Operation *, BufferT *>;
+  /// Value -> Explicit Buffer
+  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
+  /// Value -> Alias Buffer
+  using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
+  /// BufferId -> Buffer
+  using BufferSetT = std::map<BufferId, BufferT>;
+
   static constexpr BufferId InvalidBufferId =
       std::numeric_limits<BufferId>::max();

@@ -102,11 +141,17 @@
   explicit Allocation(Operation *operation) : operation(operation) {}

   /// Runs allocation analysis on the given top-level operation.
-  void run(FuncAllocMapT &funcAllocMap);
+  template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);

   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }

+  const OpScratchMapT &getOpScratch() const { return opScratch; }
+  const OpScratchMapT &getOpVirtual() const { return opVirtual; }
+  const ValueBufferMapT &getValueBuffer() const { return valueBuffer; }
+  const AliasBufferMapT &getAliasBuffer() const { return aliasBuffer; }
+  void setSharedMemorySize(size_t size) { sharedMemorySize = size; }
+
   /// Returns the offset of the given buffer in the shared memory.
   size_t getOffset(BufferId bufferId) const {
     return bufferSet.at(bufferId).offset;
@@ -170,47 +215,6 @@
   /// Returns mapping from operation to list of live LDS buffers
   std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();

-private:
-  /// A class that represents a shared memory buffer
-  struct BufferT {
-    /// Explicit: triton_gpu.local_alloc
-    /// Scratch: triton_gpu.convert_layout
-    /// Virtual: triton.call
-    enum class BufferKind { Explicit, Scratch, Virtual };
-
-    /// MT: thread-safe
-    inline static std::atomic<BufferId> nextId = 0;
-
-    BufferKind kind;
-    BufferId id;
-    size_t size;
-    size_t alignment;
-    size_t offset;
-
-    bool operator==(const BufferT &other) const { return id == other.id; }
-    bool operator<(const BufferT &other) const { return id < other.id; }
-
-    BufferT() : BufferT(BufferKind::Explicit, 0) {}
-    BufferT(BufferKind kind, size_t size, size_t alignment = 4,
-            size_t offset = 0)
-        : kind(kind), id(nextId++), size(size), alignment(alignment),
-          offset(offset) {}
-
-    size_t setOffsetAligned(size_t newOffset) {
-      return offset = llvm::alignTo(newOffset, alignment);
-    }
-  };
-
-  /// Op -> Scratch Buffer
-  using OpScratchMapT = DenseMap<Operation *, BufferT *>;
-  /// Value -> Explicit Buffer
-  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
-  /// Value -> Alias Buffer
-  using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
-  /// BufferId -> Buffer
-  using BufferSetT = std::map<BufferId, BufferT>;
-
-private:
   template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
   void addBuffer(KeyType &key, Args &&...args) {
     auto buffer = BufferT(Kind, std::forward<Args>(args)...);
@@ -236,10 +240,11 @@
   AliasBufferMapT aliasBuffer;
   BufferSetT bufferSet;
   size_t sharedMemorySize = 0;
-
-  friend class triton::AllocationAnalysis;
 };

+template <>
+void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
+
 /// Static analysis that computes the allocation of shared memory buffers
 /// of the entire call graph.
 /// The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +255,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
 public:
   using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

-  explicit ModuleAllocation(ModuleOp moduleOp)
-      : CallGraph<Allocation>(moduleOp) {
-    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+  template <typename AllocationAnalysis = triton::AllocationAnalysis>
+  static ModuleAllocation get(ModuleOp moduleOp) {
+    ModuleAllocation res(moduleOp);
+    res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
         [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
         // Post-order node walk callback
         [&](FunctionOpInterface funcOp) {
-          auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
+          auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
           if (inserted)
-            iter->second.run(funcMap);
+            iter->second.template run<AllocationAnalysis>(res.funcMap);
         });
+    return res;
   }

   size_t getSharedMemorySize() {
@@ -285,6 +292,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
   }

 private:
+  explicit ModuleAllocation(ModuleOp moduleOp)
+      : CallGraph<Allocation>(moduleOp) {}
+
   FuncOffsetMapT sharedMemoryValue;
 };
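In practical terms, this refactor changes how the analysis is driven: BufferT and its map typedefs become public, the per-function analysis is now selected through the run<AllocationAnalysis>() template (with the triton::AllocationAnalysis specialization declared above), and a ModuleAllocation is obtained via the static ModuleAllocation::get(moduleOp) factory, or get<SomeCustomAllocationAnalysis>(moduleOp) for a backend-supplied analysis, since the constructor is now private. SomeCustomAllocationAnalysis is a hypothetical name used only for illustration.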
