
Commit 565808e

Merge OpenAI Triton commit 717997b (#4497)
This PR changes the Triton base from 88a2851 to 717997b (Jun 11). Pass rate: 97.11%.
2 parents c7b3773 + 7efdd03 commit 565808e


13 files changed, +120 −21 lines changed


include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 3 deletions
@@ -35,9 +35,11 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
 // Return true if the Load uses block pointer.
 bool isLoadFromTensorPtr(triton::LoadOp op);
 
-// Return an array of indices enumerating the elements of 'arr' in descending
-// order (so that result[i] is the index of the i-th largest element of 'arr')
-SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr);
+// Gets the order of a tensor from its contiguity. Places the dimensions with
+// the largest contiguity as the inner most dimension. If the contiguity is
+// all ones, returns the order {dim - 1, dim - 2, ..., 0}
+SmallVector<unsigned, 4>
+getOrderFromContiguity(const SmallVector<int64_t> &contiguity);
 
 // Return the operand used to access the memory in the operation
 Value getMemAccessPtr(Operation *op);

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     });
 
     auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
-    SmallVector<unsigned> order = argSort(contiguity);
+    SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
     LDBG("order=[" << triton::join(order, ", ") << "]");
 
     auto matchesShape = [&refTensorType](const Value &val) {
@@ -55,8 +55,8 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
       Value val = getMemAccessPtr(use);
       if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
         continue;
-      auto currOrder =
-          argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
+      auto currOrder = getOrderFromContiguity(
+          axisInfoAnalysis.getAxisInfo(val)->getContiguity());
       if (order == currOrder) {
         LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
         memAccessesSameOrder.insert(use);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 1 deletion
@@ -92,9 +92,11 @@ bool isLoadFromTensorPtr(triton::LoadOp op) {
   return mlir::triton::isTensorPointerType(op.getPtr().getType());
 }
 
-SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr) {
+SmallVector<unsigned, 4>
+getOrderFromContiguity(const SmallVector<int64_t> &arr) {
   SmallVector<unsigned, 4> ret(arr.size());
   std::iota(ret.begin(), ret.end(), 0);
+  std::reverse(ret.begin(), ret.end());
   std::stable_sort(ret.begin(), ret.end(),
                    [&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
   return ret;
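
The behavioral change here is the tie-breaking: reversing the index vector before the stable descending sort makes dimensions with equal contiguity come out as {dim - 1, ..., 0} rather than {0, ..., dim - 1}. A minimal Python re-statement of the logic (a sketch, not the C++ source; the inputs are illustrative):

    def get_order_from_contiguity(contiguity):
        # Start from the reversed identity order so that ties resolve to
        # {dim - 1, ..., 0} under the stable descending sort.
        order = list(reversed(range(len(contiguity))))
        order.sort(key=lambda i: contiguity[i], reverse=True)  # list.sort is stable
        return order

    print(get_order_from_contiguity([1, 1, 1]))   # [2, 1, 0]; the old argSort returned [0, 1, 2]
    print(get_order_from_contiguity([4, 1, 16]))  # [2, 0, 1]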

python/test/gluon/test_frontend.py

Lines changed: 44 additions & 0 deletions
@@ -881,3 +881,47 @@ def test_tensor_reshape():
     expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1],
                                                        [2, 1, 0])
     ttgl.static_assert(v.type.layout == expect_layout)
+
+
+@filecheck_test
+@gluon.jit
+def test_zeros():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2]
+    # CHECK: [[BLOCKED2D:#.*]] = #ttg.blocked<{sizePerThread = [1, 2]
+    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+    layout_2d: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
+
+    # CHECK: arith.constant dense<0.000000e+00> : tensor<32xf32, [[BLOCKED]]>
+    a = ttgl.zeros([32], ttgl.float32, layout)
+
+    # CHECK: arith.constant dense<7.000000e+00> : tensor<32xf32, [[BLOCKED]]>
+    ttgl.full_like(a, 7)
+
+    # CHECK: arith.constant dense<0.000000e+00> : tensor<32xf32, [[BLOCKED]]>
+    ttgl.zeros_like(a)
+
+    # CHECK: arith.constant dense<0.000000e+00> : tensor<64xf32, [[BLOCKED]]>
+    ttgl.zeros_like(a, shape=[64])
+
+    # CHECK: arith.constant dense<0> : tensor<16x16xi8, [[BLOCKED2D]]>
+    ttgl.zeros_like(a, shape=[16, 16], dtype=ttgl.int8, layout=layout_2d)
+
+    # CHECK: arith.constant dense<7> : tensor<8x8xi16, [[BLOCKED2D]]>
+    ttgl.full_like(a, 7, shape=[8, 8], dtype=ttgl.int16, layout=layout_2d)
+
+
+@filecheck_test
+@gluon.jit
+def test_barrier():
+    # CHECK: gpu.barrier
+    ttgl.thread_barrier()
+
+
+@filecheck_test
+@gluon.jit
+def test_fence_async_shared():
+    # CHECK: ttng.fence_async_shared {bCluster = false}
+    blackwell.fence_async_shared()
+
+    # CHECK-NEXT: ttng.fence_async_shared {bCluster = true}
+    blackwell.fence_async_shared(cluster=True)

python/triton/experimental/gluon/language/_core.py

Lines changed: 6 additions & 0 deletions
@@ -91,6 +91,7 @@
     "tensor",
     "tuple",
     "tuple_type",
+    "thread_barrier",
     "arange",
     "full",
     "convert_layout",
@@ -313,3 +314,8 @@ def warp_specialize(args, default_partition, worker_partitions, worker_num_warps
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
     return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
                                      worker_num_regs, _generator)
+
+
+@builtin
+def thread_barrier(_semantic=None):
+    return _semantic.debug_barrier()

python/triton/experimental/gluon/language/_standard.py

Lines changed: 27 additions & 1 deletion
@@ -3,6 +3,7 @@
 import triton.language.standard as tl_standard
 from .._runtime import jit
 from triton import knobs
+from . import _core as ttgl
 
 _IMPORT_FROM_TRITON = [
     "sum",
@@ -12,10 +13,35 @@
     "xor_sum",
 ]
 
-__all__ = _IMPORT_FROM_TRITON
+__all__ = [
+    "full_like",
+    "zeros",
+    "zeros_like",
+    *_IMPORT_FROM_TRITON,
+]
 
 for name in _IMPORT_FROM_TRITON:
     # Convert JITFunction -> GluonJitFunction
     fn = getattr(tl_standard, name)
     assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
     globals()[name] = jit(fn.fn)
+
+
+@jit
+def zeros(shape, dtype, layout):
+    return ttgl.full(shape, 0, dtype, layout)
+
+
+@jit
+def full_like(input, value, shape=None, dtype=None, layout=None):
+    return ttgl.full(
+        input.shape if shape is None else shape,
+        value,
+        input.dtype if dtype is None else dtype,
+        input.type.layout if layout is None else layout,
+    )
+
+
+@jit
+def zeros_like(input, shape=None, dtype=None, layout=None):
+    return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
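
A hedged usage sketch of the new helpers inside a Gluon-jitted function; the kernel name, shape, and layout below are illustrative and not taken from the patch:

    @gluon.jit
    def example_kernel():
        layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
        a = ttgl.zeros([128], ttgl.float32, layout)           # zero-filled tensor with an explicit layout
        b = ttgl.full_like(a, 3)                              # keep shape/dtype/layout, fill with 3
        c = ttgl.zeros_like(a, shape=[64], dtype=ttgl.int8)   # override shape and dtype, keep the layout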

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -6,18 +6,19 @@
 from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
 
 from . import tma
-from ..hopper import mbarrier
+from ..hopper import mbarrier, fence_async_shared
 
 if TYPE_CHECKING:
     from triton._C.libtriton.gluon_ir import GluonOpBuilder
     from triton._C.libtriton import gluon_ir as ir
     from ..._semantic import GluonSemantic
 
 __all__ = [
-    "TensorMemoryLayout",
-    "tensor_memory_descriptor",
     "allocate_tensor_memory",
+    "fence_async_shared",
     "mbarrier",
+    "tensor_memory_descriptor",
+    "TensorMemoryLayout",
     "tma",
 ]
 

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,11 @@
 from . import mbarrier
 from . import tma
+from ... import _core
 
-__all__ = ["mbarrier", "tma"]
+__all__ = ["fence_async_shared", "mbarrier", "tma"]
+
+
+@_core.builtin
+def fence_async_shared(cluster=False, _semantic=None):
+    cluster = _core._unwrap_if_constexpr(cluster)
+    _semantic.builder.create_fence_async_shared(cluster)

python/triton_kernels/triton_kernels/matmul_ogs_details/_finalize_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ def _finalize_matmul(
         if src_idx != -1:
             As = A + src_idx.to(tl.int64) * stride_a_m + offs_n
             for ki in tl.static_range(K):
-                acc += tl.load(As, mask=n_mask, other=0.0)
+                acc += tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0)
                 As += stride_a_k
     else:
         As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :]
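
The fix masks out rows whose source index is -1 so their load contributes 0.0 instead of reading through an invalid pointer. A rough NumPy analogue of the masking (shapes and values are made up for illustration; this is not the Triton kernel):

    import numpy as np

    src_idxs = np.array([3, -1, 0])          # hypothetical source-row indices; -1 means "no source"
    n_mask = np.array([True, True, False])   # hypothetical per-column bounds mask
    A = np.arange(12, dtype=np.float32).reshape(4, 3)

    mask = (src_idxs != -1)[:, None] & n_mask[None, :]
    rows = A[np.where(src_idxs != -1, src_idxs, 0)]  # clamp -1 to a valid row before gathering
    acc = np.where(mask, rows, 0.0)                  # masked lanes behave like other=0.0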

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ def _compute_writeback_idx(
     is_src_active = (src_idxs != -1).to(tl.int32)
     num_src_active = tl.sum(is_src_active, axis=1)
 
-    need_finalize_scatter = mask_m & (num_src_active > 1)
+    need_finalize_scatter = mask_m & (num_src_active != 1)
     finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
     if finalize_scatter_count == 0:
         return
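
Switching the predicate from num_src_active > 1 to num_src_active != 1 means rows with zero active sources now take the finalize-scatter path as well, not only rows with multiple sources. A plain-Python illustration of the predicate (values made up):

    num_src_active = [0, 1, 2, 3]
    need_finalize_scatter = [n != 1 for n in num_src_active]  # [True, False, True, True]
    # the previous (n > 1) predicate would have skipped rows with 0 active sources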
