Skip to content

Commit d67959e

Browse files
committed
Merge commit 'd183197524bb6abec1e22657e96b65ae039363c8'
2 parents 052efe4 + d183197 commit d67959e

File tree

20 files changed

+557
-270
lines changed

20 files changed

+557
-270
lines changed

.github/workflows/integration-tests-nvidia.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ jobs:
1313
integration-tests-nvidia:
1414
runs-on: ${{ matrix.runner }}
1515
timeout-minutes: 60
16+
# Let A100 and H100 continue even if GB200 fails, as it's a bit flaky
17+
continue-on-error: ${{ matrix.runner[0] == 'nvidia-gb200'}}
1618
strategy:
1719
matrix:
1820
runner: ${{ fromJson(inputs.matrix) }}

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,12 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
260260

261261
/// Replace all uses of `old` with a local load from `alloc` unless the use is a
262262
/// `ttg.local_alloc` with a matching shared encoding, in which case the shared
263-
/// memory is forwarded directly into the use.
264-
void replaceUsesWithLocalLoad(
265-
OpBuilder &builder, OpResult old,
266-
TypedValue<triton::gpu::MemDescType> alloc,
267-
TypedValue<triton::gpu::AsyncTokenType> token = {});
263+
/// memory is forwarded directly into the use. Returns the `ttg.local_load` if
264+
/// it created one.
265+
triton::gpu::LocalLoadOp
266+
replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
267+
TypedValue<triton::gpu::MemDescType> alloc,
268+
TypedValue<triton::gpu::AsyncTokenType> token = {});
268269

269270
// Return true if the value comes from a load or a block argument.
270271
// This will skip convert layouts and memdesc views.

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,9 +1532,10 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
15321532
op->erase();
15331533
}
15341534

1535-
void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
1536-
TypedValue<ttg::MemDescType> alloc,
1537-
TypedValue<ttg::AsyncTokenType> token) {
1535+
ttg::LocalLoadOp
1536+
replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
1537+
TypedValue<ttg::MemDescType> alloc,
1538+
TypedValue<ttg::AsyncTokenType> token) {
15381539
// Remove redundant local_load -> local_alloc
15391540
auto allocTy = alloc.getType();
15401541
SmallVector<ttg::LocalAllocOp> allocsToErase;
@@ -1549,16 +1550,18 @@ void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
15491550

15501551
// If there are some uses that were not local_allocs, we need to create a
15511552
// local_load for them.
1553+
ttg::LocalLoadOp maybeLocalLoad;
15521554
if (std::distance(old.getUsers().begin(), old.getUsers().end()) >
15531555
allocsToErase.size()) {
15541556
auto loc = old.getOwner()->getLoc();
1555-
auto sharedLoad = builder.template create<ttg::LocalLoadOp>(
1557+
maybeLocalLoad = builder.template create<ttg::LocalLoadOp>(
15561558
loc, old.getType(), alloc, token);
1557-
old.replaceAllUsesWith(sharedLoad.getResult());
1559+
old.replaceAllUsesWith(maybeLocalLoad);
15581560
}
15591561
for (auto alloc : allocsToErase) {
15601562
alloc.erase();
15611563
}
1564+
return maybeLocalLoad;
15621565
}
15631566

15641567
bool comesFromLoadOrBlockArg(Value v) {

lib/Tools/GenericSwizzling.cpp

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ SmallVector<int32_t> flatten(const LinearLayout &ll, StringAttr dim) {
4747
return vec;
4848
};
4949

50+
SmallVector<int32_t> removeZeros(ArrayRef<int32_t> vec) {
51+
SmallVector<int32_t> result;
52+
for (int32_t r : vec) {
53+
if (r != 0) {
54+
result.push_back(r);
55+
}
56+
}
57+
return result;
58+
}
59+
5060
// [1, 2, 4, 8] -> [[1], [2], [4], [8]]
5161
std::vector<std::vector<int32_t>> unflatten(ArrayRef<int32_t> basis) {
5262
std::vector<std::vector<int32_t>> unflattened;
@@ -279,6 +289,7 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
279289
auto *ctx = src.getInDimNames().begin()->getContext();
280290
auto kReg = StringAttr::get(ctx, "register");
281291
auto kLane = StringAttr::get(ctx, "lane");
292+
auto kWarp = StringAttr::get(ctx, "warp");
282293

283294
// We work on the flattened tensors as the tensor dimensions are not relevant
284295
const LinearLayout srcFlat = src.flattenOuts();
@@ -307,6 +318,65 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
307318
if (vbasis.size() > maxVecBases) {
308319
vbasis.resize(maxVecBases);
309320
}
321+
// We fill up vbasis until it covers 32 bits, as best we can
322+
auto vecFillsBank = (1 << vbasis.size()) * bitwidth >= 32;
323+
if (!vecFillsBank) {
324+
auto warpSrc = removeZeros(flatten(srcFlat, kWarp));
325+
auto warpDst = removeZeros(flatten(dstFlat, kWarp));
326+
auto removeVec = [&vbasis](ArrayRef<int32_t> vec) {
327+
SmallVector<int32_t> result;
328+
for (int32_t r : vec) {
329+
if (!llvm::is_contained(vbasis, r)) {
330+
result.push_back(r);
331+
}
332+
}
333+
return result;
334+
};
335+
auto regSrcWarp = intersectionBasis(removeVec(regSrc), warpDst, dim);
336+
auto regDstWarp = intersectionBasis(removeVec(regDst), warpSrc, dim);
337+
// Maximise vectorisation in the load or the store without creating
338+
// conflicts
339+
SmallVector<int32_t> largest;
340+
if (regSrcWarp.size() == regDstWarp.size() && regSrcWarp.size() > 0) {
341+
// We choose the one with the lowest basis in the hope that it will
342+
// avoid PRMTs. The comparison of the mins will be strict as the sets
343+
// removeVec(regSrc) and removeVec(regDst) don't intersect
344+
if (*llvm::min_element(regSrcWarp) < *llvm::min_element(regDstWarp)) {
345+
largest = regSrcWarp;
346+
} else {
347+
largest = regDstWarp;
348+
}
349+
} else {
350+
largest = regSrcWarp.size() > regDstWarp.size() ? regSrcWarp : regDstWarp;
351+
}
352+
vbasis.append(largest.begin(), largest.end());
353+
if (vbasis.size() > maxVecBases) {
354+
vbasis.resize(maxVecBases);
355+
}
356+
// We allow vbasis.size > Log2_32(32 / bitwidth) at this point, as it is in
357+
// general good, but note that b32Vec below is still conservatively taken as 0.
358+
if (vbasis.size() < llvm::Log2_32(32 / bitwidth)) {
359+
// Pad the vectorisation to 32 bits with warp bases
360+
auto warpSrcWarp = intersectionBasis(warpSrc, warpDst, dim);
361+
vbasis.append(warpSrcWarp.begin(), warpSrcWarp.end());
362+
}
363+
364+
int i = 0;
365+
while (vbasis.size() < llvm::Log2_32(32 / bitwidth) &&
366+
(i < warpSrc.size() || i < warpDst.size())) {
367+
// If we have not filled up a whole bank, we add more warp bases
368+
// until we have 32 bits. They will at least avoid bank conflicts in one
369+
// direction
370+
if (i < warpSrc.size() && !llvm::is_contained(vbasis, warpSrc[i])) {
371+
vbasis.push_back(warpSrc[i]);
372+
}
373+
if (vbasis.size() < llvm::Log2_32(32 / bitwidth) && i < warpDst.size() &&
374+
!llvm::is_contained(vbasis, warpDst[i])) {
375+
vbasis.push_back(warpDst[i]);
376+
}
377+
++i;
378+
}
379+
}
310380

311381
// Bits in a bank segment: 32 banks x 32 bits
312382
constexpr int32_t bankBits = 32 * 32;
@@ -321,8 +391,11 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
321391
auto bankDst = llvm::to_vector(llvm::concat<int32_t>(vbasis, laneDst));
322392

323393
// Whether we'll use b32.v1 / b32.v2 / b32.v4
324-
auto b32Vec =
325-
llvm::Log2_32(std::max<int32_t>((1 << vbasis.size()) * bitwidth / 32, 1));
394+
// FIXME: With !vecFillsBank we may use b32.v2 or b32.v4 for the load or
395+
// store, but we pessimistically assume we don't.
396+
auto b32Vec = !vecFillsBank ? 0
397+
: llvm::Log2_32(std::max<int32_t>(
398+
(1 << vbasis.size()) * bitwidth / 32, 1));
326399
// Drop the last vec bases of the banks
327400
bankSrc.resize(bankSrc.size() - b32Vec);
328401
bankDst.resize(bankDst.size() - b32Vec);

python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import triton.language.core as tl_core
1313
from triton.language.core import (
1414
constexpr,
15+
constexpr_function,
1516
base_value,
1617
base_type,
1718
dtype,
@@ -38,6 +39,7 @@
3839
float64,
3940
_unwrap_if_constexpr,
4041
_unwrap_shape,
42+
static_range,
4143
tensor,
4244
tuple,
4345
tuple_type,
@@ -68,6 +70,7 @@
6870

6971
__all__ = [
7072
"constexpr",
73+
"constexpr_function",
7174
"base_value",
7275
"base_type",
7376
"dtype",
@@ -105,6 +108,7 @@
105108
"allocate_shared_memory",
106109
"set_auto_layout",
107110
"shared_memory_descriptor",
111+
"static_range",
108112
"warp_specialize",
109113
*_IMPORT_FROM_TRITON,
110114
]

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ def type(self):
233233
return constexpr_type(self)
234234

235235

236+
def _get_shape_per_cta(shape, cta_split_num):
237+
shape_per_cta = shape
238+
if cta_split_num is not None:
239+
assert len(cta_split_num) == len(shape)
240+
for dim in range(len(shape_per_cta)):
241+
shape_per_cta[dim] /= cta_split_num[dim]
242+
return shape_per_cta
243+
244+
236245
@dataclass(frozen=True)
237246
class NVMMASharedLayout(SharedLayout):
238247
"""
@@ -286,6 +295,47 @@ def _to_ir(self, builder):
286295
self.cta_order,
287296
)
288297

298+
@staticmethod
299+
def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
300+
cta_order=None):
301+
"""Returns an NVMMASharedLayout with default swizzling for a given shape.
302+
303+
This picks the largest swizzle pattern compatible with the shape, which
304+
allows emitting the fewest TMA or MMA messages.
305+
"""
306+
packing_factor = 2 if fp4_padded else 1
307+
shape_per_cta = _get_shape_per_cta(block_shape, cta_split_num)
308+
rank = len(block_shape)
309+
if transposed:
310+
shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]
311+
contig_dim_size = shape_per_cta[-1] * packing_factor
312+
contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8
313+
if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0:
314+
swizzle_byte_width = 128
315+
elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0:
316+
swizzle_byte_width = 64
317+
elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0:
318+
swizzle_byte_width = 32
319+
else:
320+
swizzle_byte_width = 0
321+
322+
flatten_outer_dim = 1
323+
for size in shape_per_cta[:-1]:
324+
flatten_outer_dim *= size
325+
if len(block_shape) < 2 or flatten_outer_dim < 8:
326+
swizzle_byte_width = 0
327+
328+
return NVMMASharedLayout(
329+
swizzle_byte_width=swizzle_byte_width,
330+
element_bitwidth=dtype.primitive_bitwidth,
331+
rank=rank,
332+
transposed=transposed,
333+
fp4_padded=fp4_padded,
334+
ctas_per_cga=ctas_per_cga,
335+
cta_split_num=cta_split_num,
336+
cta_order=cta_order,
337+
)
338+
289339
def mangle(self) -> str:
290340
return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"
291341

python/triton/tools/tensor_descriptor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from dataclasses import dataclass
22
from typing import List, Any
33
from triton._utils import validate_block_shape
4-
from torch._subclasses.fake_tensor import FakeTensor
5-
from torch._subclasses.functional_tensor import FunctionalTensor
64

75

86
@dataclass
@@ -18,7 +16,9 @@ def __post_init__(self):
1816
assert len(self.block_shape) == rank, f"rank mismatch: {self}"
1917
assert rank > 0, "rank must not be zero"
2018
assert rank <= 5, "rank cannot be more than 5"
21-
if not isinstance(self.base, (FakeTensor, FunctionalTensor)):
19+
ty = type(self.base)
20+
type_name = f"{ty.__module__}.{ty.__name__}"
21+
if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"):
2222
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
2323
validate_block_shape(self.block_shape)
2424
elem_bytes = self.base.dtype.itemsize

python/tutorials/gluon/01-attention-forward.py

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
import triton
3-
import triton.language as tl
43
import pytest
54
import itertools
65

@@ -25,7 +24,7 @@
2524
# ===-----------------------------------------------------------------------===#
2625

2726

28-
@tl.constexpr_function
27+
@gl.constexpr_function
2928
def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
3029
assert len(shape) == 2, "expected a 2D tensor"
3130
assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -61,45 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
6160
)
6261

6362

64-
@tl.constexpr_function
63+
@gl.constexpr_function
6564
def get_mma_instr_shape(shape, element_ty):
6665
m = 128 if shape[0] >= 128 else 64
6766
n = 256 if shape[1] >= 256 else shape[1]
6867
k = 256 // element_ty.primitive_bitwidth
6968
return (m, n, k)
7069

7170

72-
@tl.constexpr_function
73-
def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
74-
packing_factor = 2 if fp4_padded else 1
75-
76-
contig_dim_size = shape[order[0]] * packing_factor * element_ty.primitive_bitwidth // 8
77-
if contig_dim_size >= 128 and contig_dim_size % 128 == 0:
78-
swizzle_byte_width = 128
79-
elif contig_dim_size >= 64 and contig_dim_size % 64 == 0:
80-
swizzle_byte_width = 64
81-
elif contig_dim_size >= 32 and contig_dim_size % 32 == 0:
82-
swizzle_byte_width = 32
83-
else:
84-
swizzle_byte_width = 0
85-
86-
flatten_outer_dim = 1
87-
for i in range(1, len(shape)):
88-
flatten_outer_dim *= shape[order[i]]
89-
if len(shape) < 2 or flatten_outer_dim < 8:
90-
swizzle_byte_width = 0
91-
transposed = order[0] == 0
92-
93-
return gl.NVMMASharedLayout(
94-
swizzle_byte_width=swizzle_byte_width,
95-
element_bitwidth=element_ty.primitive_bitwidth,
96-
rank=len(shape),
97-
transposed=transposed,
98-
fp4_padded=fp4_padded,
99-
)
100-
101-
102-
@tl.constexpr_function
71+
@gl.constexpr_function
10372
def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
10473
instr_shape = get_mma_instr_shape(shape, dtype)
10574
return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
@@ -133,7 +102,7 @@ def alloc(shape: gl.constexpr, dtype: gl.constexpr, layout: gl.constexpr, num_bu
133102
mem = alloc_fn(dtype, [num_buffers] + shape, layout)
134103
ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
135104
empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
136-
for i in tl.static_range(num_buffers):
105+
for i in gl.static_range(num_buffers):
137106
mbarrier.init(ready_bars.index(i), count=1)
138107
mbarrier.init(empty_bars.index(i), count=num_consumers)
139108
mbarrier.arrive(empty_bars.index(i), count=num_consumers)
@@ -179,7 +148,7 @@ def create_consumer(self):
179148
def release(self):
180149
if isinstance(self.mem, gl.shared_memory_descriptor):
181150
self.mem._keep_alive()
182-
for i in tl.static_range(self.num_buffers):
151+
for i in gl.static_range(self.num_buffers):
183152
mbarrier.invalidate(self.ready_bars.index(i))
184153
mbarrier.invalidate(self.empty_bars.index(i))
185154

@@ -847,7 +816,7 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
847816
mbarrier.arrive(corr_bar, count=1)
848817
alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)
849818

850-
for i in tl.static_range(config.SPLIT_D_FACTOR):
819+
for i in gl.static_range(config.SPLIT_D_FACTOR):
851820
o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
852821
o = o_ref.load(config.o_splitn_layout)
853822
o = _mul_f32x2(o, alpha[:, None])
@@ -882,7 +851,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
882851
SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR
883852

884853
scale = 1 / l_i
885-
for i in tl.static_range(SPLIT_N_FACTOR):
854+
for i in gl.static_range(SPLIT_N_FACTOR):
886855
o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
887856
o = o_ref.load(config.o_splitn_layout)
888857
o = _mul_f32x2(o, scale[:, None])
@@ -992,12 +961,12 @@ def attention_kernel( #
992961
def torch_dtype_to_triton(dtype):
993962
if dtype == torch.float8_e5m2:
994963
return gl.float8e5
995-
return getattr(tl, str(dtype).split('.')[1])
964+
return getattr(gl, str(dtype).split('.')[1])
996965

997966

998967
def make_tensor_desc(x, shape, strides, block_shape):
999-
layout = get_nvmma_layout(block_shape, torch_dtype_to_triton(x.dtype))
1000-
return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout.value)
968+
layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(x.dtype))
969+
return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout)
1001970

1002971

1003972
def attention_forward(q, k, v, causal, sm_scale):

0 commit comments

Comments
 (0)