Commit a9cd5c7

Merge OpenAI Triton commit 9f21c06 (#5469)
This PR changes the Triton base from 40dd0c4 to 9f21c06 (Oct 27). Pass rate: 94.91%
2 parents (8d220b9, b308b6f), commit a9cd5c7

File tree

28 files changed: +243 -149 lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 5 additions & 0 deletions
@@ -809,6 +809,11 @@ LogicalResult MemDescIndexOp::verify() {
     return emitError("src and dst must have the same type of encoding");
   }

+  if (dstTy.getAllocShape() != dstTy.getShape() ||
+      srcTy.getAllocShape() != srcTy.getShape()) {
+    return emitError("alloc shape must match shape for both result and src");
+  }
+
   if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(srcEnc)) {
     // We support only 3D -> 2D subviews with only first offset being non-zero.
     if (srcTy.getRank() != 3 || dstTy.getRank() != 2) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 2 deletions
@@ -687,8 +687,7 @@ triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) {
                           allocDescType.getShape().end());
   auto viewDescType = ttg::MemDescType::get(
       shape, allocDescType.getElementType(), allocDescType.getEncoding(),
-      allocDescType.getMemorySpace(), allocDescType.getMutableMemory(),
-      /*allocShape=*/allocDescType.getAllocShape());
+      allocDescType.getMemorySpace(), allocDescType.getMutableMemory());
   return builder.create<ttg::MemDescIndexOp>(alloc.getLoc(), viewDescType,
                                              alloc, idx);
 }

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
@@ -1566,7 +1566,7 @@ void replaceUsesAndPropagateType(
     bool isMutable = cast<ttg::MemDescType>(val.getType()).getMutableMemory();
     Type newDstType = ttg::MemDescType::get(
         oldType.getShape(), oldType.getElementType(), oldType.getEncoding(),
-        oldType.getMemorySpace(), isMutable, oldType.getAllocShape());
+        oldType.getMemorySpace(), isMutable);
     newVal = builder.create<ttg::MemDescIndexOp>(subview.getLoc(), newDstType,
                                                  val, subview.getIndex());
   } else if (auto subslice = dyn_cast<ttg::MemDescSubsliceOp>(user)) {

python/test/gluon/test_frontend.py

Lines changed: 9 additions & 9 deletions
@@ -258,8 +258,8 @@ def test_tensor_memory():
     %4 = arith.bitcast %c1_i32 : i32 to i32
     %5 = ub.poison : i32
     scf.for %arg0 = %2 to %3 step %4 : i32 {
-      %6 = ttg.memdesc_index %result_2[%arg0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128>
-      %result_4 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked>
+      %6 = ttg.memdesc_index %result_2[%arg0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %result_4 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
     }
     tt.return
   }
@@ -338,8 +338,8 @@ def test_shared_memory_index(target):
     %3 = arith.bitcast %c1_i32 : i32 to i32
     %4 = ub.poison : i32
     scf.for %arg0 = %1 to %2 step %3 : i32 {
-      %5 = ttg.memdesc_index %0[%arg0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256>
-      %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> -> tensor<256xi32, #blocked>
+      %5 = ttg.memdesc_index %0[%arg0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable>
+      %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>
     }
     tt.return
   }
@@ -408,15 +408,15 @@ def test_shared_memory_cast(target):
   tt.func public @shared_memory_cast_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable>
     %c0_i32 = arith.constant 0 : i32
-    %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
-    %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
-    tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
+    %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable>
+    %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>
+    tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> ()
     %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
     %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable>
     %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
     tt.return
   }
-  tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} {
+  tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} {
     tt.return
   }
 }
@@ -930,7 +930,7 @@ def test_tmem_index_constexpr():
   tt.func public @tmem_index_kernel() attributes {noinline = false} {
     %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
     %c0_i32 = arith.constant 0 : i32
-    %0 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256>
+    %0 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable>
     tt.return
   }
 }

python/test/unit/language/test_core.py

Lines changed: 37 additions & 0 deletions
@@ -6703,3 +6703,40 @@ def kernel():
         tl.device_assert(tl.sum(x) == x.sum())

     kernel[(1, )]()
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("rank", [2, 3, 4, 5, 6])
+@pytest.mark.parametrize("trans_a", [False, True])
+@pytest.mark.parametrize("trans_b", [False, True])
+def test_dot_multidim(rank, trans_a, trans_b, device):
+
+    if is_interpreter():
+        pytest.xfail("bfloat16 is not supported in the interpreter")
+
+    @triton.jit
+    def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
+        x = tl.load(X + tl.arange(0, 256 << RANK)).reshape([2] * (RANK - 2) + [32, 32])
+        y = tl.load(Y + tl.arange(0, 256 << RANK)).reshape([2] * (RANK - 2) + [32, 32])
+        if TRANS_A:
+            x = tl.trans(x)
+        if TRANS_B:
+            y = tl.trans(y)
+        z = tl.dot(x, y)
+        tl.store(Z + tl.arange(0, 256 << RANK), z.reshape([256 << RANK]))
+
+    shape = (2, ) * (rank - 2) + (32, 32)
+
+    a = torch.randint(-4, 5, shape, dtype=torch.bfloat16, device=device)
+    b = torch.randint(-4, 5, shape, dtype=torch.bfloat16, device=device)
+    c = torch.empty(shape, dtype=torch.float32, device=device)
+    kernel[(1, )](a, b, c, rank, trans_a, trans_b)
+
+    if trans_a:
+        a = torch.transpose(a, -1, -2)
+    if trans_b:
+        b = torch.transpose(b, -1, -2)
+
+    d = a.to(torch.float32) @ b.to(torch.float32)
+
+    assert torch.equal(c, d)

python/triton/experimental/gluon/language/_core.py

Lines changed: 11 additions & 4 deletions
@@ -497,13 +497,20 @@ def warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _sema
     """
     Create a warp-specialized execution region, partitioning work across warps.

+    This forks the current execution into a "default partition" and an arbitrary number of
+    "worker partitions". The default partition is executed in the same :code:`num_warps` warps as
+    the parent region, and may accept tensor arguments and return tensors. Worker partitions are
+    executed in additional warps, which sit idle while executing the parent region.
+
+    Note that calling warp_specialize recursively is not supported.
+
     Args:
-        functions_and_args (List[Tuple[Callable, Any]]): List of functions and arguments for each partition.
-        worker_num_warps (List[int]): Number of warps per partition.
-        worker_num_regs (List[int]): Number of registers per partition.
+        functions_and_args (List[Tuple[Callable, Any]]): List of functions and arguments for each partition, the first of which is the default partition.
+        worker_num_warps (List[int]): Number of warps used for each worker partition.
+        worker_num_regs (List[int]): Number of registers for each worker partition.

     Returns:
-        Tuple[Any, ...]: Results from the default region.
+        Tuple[Any, ...]: Results from the default partition.
     """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 1 addition & 1 deletion

@@ -276,7 +276,7 @@ def memdesc_index(self, mem_desc, index):
         shape = mem_desc.shape[1:]
         index = self.to_tensor(index).handle
         layout = mem_desc.layout
-        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, shape)
         builder = self.builder
         handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip
         builder = _semantic.builder
         shape = self.shape[1:]
         layout = self.layout
-        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, shape)
         ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle)
         return ret

python/triton/language/core.py

Lines changed: 38 additions & 5 deletions

@@ -1731,8 +1731,9 @@ def trans(input: tensor, *dims, _semantic=None):
     """
     Permutes the dimensions of a tensor.

-    If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation,
-    effectively transposing a 2D tensor.
+    If the parameter :code:`dims` is not specified, the function defaults to
+    swapping the last two axes, thereby performing an (optionally batched)
+    2D transpose.

     :param input: The input tensor.
     :param dims: The desired ordering of dimensions. For example,
@@ -1749,7 +1750,10 @@ def trans(input: tensor, *dims, _semantic=None):
     """
     dims = _unwrap_iterable(dims)
     if not dims:
-        dims = (1, 0)
+        n = len(input.shape)
+        if n < 2:
+            raise ValueError("tl.trans invoked with a 0- or 1-dimensional tensor")
+        dims = list(builtins.range(n - 2)) + [n - 1, n - 2]
     return _semantic.permute(input, dims)


@@ -1771,7 +1775,7 @@ def permute(input, *dims, _semantic=None):
         permute(x, 2, 1, 0)

     :py:func:`trans` is equivalent to this function, except when
-    :code:`dims` is empty, it tries to do a (1,0) permutation.
+    :code:`dims` is empty, it tries to swap the last two axes.
     """
     dims = _unwrap_iterable(dims)
     return _semantic.permute(input, dims)
@@ -2024,7 +2028,36 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
     out_dtype = _unwrap_if_constexpr(out_dtype)
     max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
     acc = _unwrap_if_constexpr(acc)
-    return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
+
+    # check shapes make sense:
+    a_shape = list(input.shape)
+    b_shape = list(other.shape)
+    assert len(a_shape) == len(b_shape) >= 2, "input and other must have equal ranks >= 2"
+    assert a_shape[:-2] == b_shape[:-2], "input and other must have equal batch shapes"
+    assert a_shape[-1] == b_shape[-2], "input and other must have equal reduction dimensions"
+
+    # compute shape of accumulator:
+    c_shape = a_shape[:-1] + [b_shape[-1]]
+    if acc is not None:
+        assert list(acc.shape) == c_shape, "accumulator shape is incompatible"
+    rank = len(c_shape)
+
+    if rank >= 4:
+        batch_size = 1
+        for i in builtins.range(rank - 2):
+            batch_size *= c_shape[i]
+        input = _semantic.reshape(input, [batch_size] + a_shape[-2:], can_reorder=False)
+        other = _semantic.reshape(other, [batch_size] + b_shape[-2:], can_reorder=False)
+        if acc is not None:
+            acc = _semantic.reshape(acc, [batch_size] + c_shape[-2:], can_reorder=False)
+
+    res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
+
+    if rank >= 4:
+        res = _semantic.reshape(res, c_shape, can_reorder=False)
+
+    assert list(res.shape) == c_shape, "output shape is unexpected"
+    return res


 @builtin
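
Taken together, the core.py changes make tl.trans default to swapping the last two axes and let tl.dot accept batched (rank >= 3) operands by flattening the leading dimensions around the underlying dot. Below is a small end-to-end sketch in the spirit of the new test_dot_multidim test; it assumes a CUDA device with bf16 tensor-core support, and the kernel name and X/Y/Z pointer arguments are illustrative, not part of this commit:

import torch
import triton
import triton.language as tl


@triton.jit
def batched_dot_kernel(X, Y, Z):
    # Load 2*2*32*32 = 4096 elements and view them as rank-4 tensors.
    x = tl.load(X + tl.arange(0, 4096)).reshape([2, 2, 32, 32])
    y = tl.load(Y + tl.arange(0, 4096)).reshape([2, 2, 32, 32])
    y = tl.trans(y)   # no dims given: only the last two axes are swapped
    z = tl.dot(x, y)  # batch dims [2, 2] are flattened for the dot, then restored
    tl.store(Z + tl.arange(0, 4096), z.reshape([4096]))


a = torch.randint(-4, 5, (2, 2, 32, 32), dtype=torch.bfloat16, device="cuda")
b = torch.randint(-4, 5, (2, 2, 32, 32), dtype=torch.bfloat16, device="cuda")
c = torch.empty((2, 2, 32, 32), dtype=torch.float32, device="cuda")
batched_dot_kernel[(1, )](a, b, c)
assert torch.equal(c, a.to(torch.float32) @ b.to(torch.float32).transpose(-1, -2))

Small integer-valued operands keep the bf16 products exact, so the fp32 results can be compared with torch.equal, mirroring the committed test.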
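
For the Gluon warp_specialize docstring updated above (python/triton/experimental/gluon/language/_core.py), here is a minimal call-shape sketch that follows only what the docstring states; default_fn, worker0, worker1 and their argument tuples are hypothetical placeholders, not APIs introduced by this commit:

# Hypothetical sketch of the documented call shape; the partition functions
# and their argument tuples below are placeholders.
outs = warp_specialize(
    [(default_fn, (a, b)),        # first entry: the default partition, runs in the parent's warps
     (worker0, (buf, bar)),       # worker partition 0
     (worker1, (buf, bar))],      # worker partition 1
    worker_num_warps=[4, 4],      # one entry per worker partition
    worker_num_regs=[192, 192],   # one entry per worker partition
)
# outs holds the results returned by the default partition.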

python/triton/runtime/autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
-        self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
+        self.cache_results = (cache_results or knobs.autotuning.cache) and not knobs.runtime.interpret

         # Reset to zero or restore values
         self.reset_to_zero = []
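
The effect of this one-line autotuner change is easiest to see for the interpreter case, shown as a small boolean sketch (the variable names mirror the knobs in the diff; the values are illustrative):

# Caller passed cache_results=True, autotune cache knob off, running under the interpreter.
cache_results, autotune_cache_knob, interpret = True, False, True

old = cache_results or (autotune_cache_knob and not interpret)   # True  -> results were cached even in interpreter runs
new = (cache_results or autotune_cache_knob) and not interpret   # False -> interpreter runs never cache results

print(old, new)  # True False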
