Skip to content

Commit c6c7ac0

Browse files
pchen7e2 authored and meta-codesync[bot] committed
[2/N][TLX-2cta] Expose cluster_cta_rank (#638)
Summary: This will expose the capability for executing CTA to know whether it's leader CTA in the pair or not. It will be necessary if we want non leader CTA to arrive a barrier for leader CTA and synchronize the two before issuing MMA. ``` % make test-lit ninja -C /data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11 check-triton-lit-tests ninja: Entering directory `/data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11' [0/1] Running the triton regression tests Testing Time: 7.81s Total Discovered Tests: 208 Passed : 207 (99.52%) Expectedly Failed: 1 (0.48%) % third_party/tlx/run_all.sh Need to build triton in this script? {y|n}n Run all LITs? {y|n}n Run core Triton python unit tests? {y|n}n Run all TLX unit tests? {y|n}y Running TLX Unit Tests ... ====================================================================================== 31 passed, 76 skipped in 19.55s ====================================================================================== Run TLX tutorial kernels (correctness|performance|no)? {c|p|n} c Verifying correctness of TLX tutorial kernels (all passing) ``` Pull Request resolved: #638 Reviewed By: htyu Differential Revision: D86249537 Pulled By: pchen7e2 fbshipit-source-id: 1becf189deab327d33ea64d32963723668bae257
1 parent ca34989 commit c6c7ac0

File tree

5 files changed

+62
-0
lines changed

5 files changed

+62
-0
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ While this approach places more responsibility on the user, it reduces the compi
5050

5151
Slice an `M x N` tensor at an `m x n` offset.
5252

53+
### Remote buffer operations
54+
55+
- `buffer = tlx.remote_view(buffer, remote_cta_rank)`
56+
57+
Return a remote view of the `buffer` living in another CTA in the same cluster with ID `remote_cta_rank`. NOTE: for
58+
now we only support barrier as `buffer`, not general SMEM.
59+
5360
### Async memory access
5461

5562

@@ -171,6 +178,11 @@ Examples: how mbarriers are communicated in warp specialization
171178

172179
`tlx.async_task(num_warps=4)` defines a warp-specialized asynchronous task that explicitly reserves 4 warps in addition to those used by the trunk task.
173180

181+
### Other operations
182+
183+
- `tlx.cluster_cta_rank()`
184+
185+
Returns the rank (unique ID) of the current CTA within the cluster.
174186

175187
- `tlx.thread_id(axis)`
176188

python/test/unit/language/test_tlx.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,37 @@ def store_from_thread_0_kernel(
552552
torch.testing.assert_close(output, expected_output)
553553

554554

555+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
556+
def test_custer_cta_rank(device):
557+
558+
@triton.jit
559+
def test_cta_0_kernel(
560+
output_ptr,
561+
n_elements,
562+
BLOCK_SIZE: tl.constexpr,
563+
):
564+
pid = tl.program_id(axis=0)
565+
block_start = pid * BLOCK_SIZE
566+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
567+
mask = offsets < n_elements
568+
# without multi-cta cluster launch, this test does not validate much except
569+
# the fact that the IR lowering flow works
570+
cta_id = tlx.cluster_cta_rank()
571+
tl.store(output_ptr + offsets, cta_id, mask=mask)
572+
573+
tensor_size = 32
574+
# init with 1, expected to be filled with 0
575+
output = torch.ones(tensor_size, dtype=torch.int32, device=device)
576+
kernel = test_cta_0_kernel[(1, )](output, tensor_size, tensor_size, num_warps=1)
577+
578+
ttgir = kernel.asm["ttgir"]
579+
assert ttgir.count("nvgpu.cluster_id") == 1
580+
581+
torch.cuda.synchronize()
582+
expected_output = torch.zeros(tensor_size, dtype=torch.int32, device=device)
583+
torch.testing.assert_close(output, expected_output)
584+
585+
555586
def test_clock64(device):
556587

557588
@triton.jit

third_party/tlx/dialect/triton_tlx.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "Transforms/Passes.h"
33
#include "ir.h" // TritonOpBuilder
44
#include "mlir/Pass/PassManager.h"
5+
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
56
#include "passes.h"
67
#include "tlx/dialect/include/Transforms/Passes.h"
78
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -555,6 +556,14 @@ void init_triton_tlx_ir(py::module &&m) {
555556
self.getBuilder().getI32Type(), threadId);
556557
return threadId;
557558
})
559+
.def("create_cluster_cta_rank",
560+
[](TritonOpBuilder &self) -> Value {
561+
// The naming of ClusterCTAIdOp is bad. It actually returns the
562+
// cluster CTA rank (1D) instead of cluster CTA ID (3D)
563+
Value rank = self.create<triton::nvgpu::ClusterCTAIdOp>(
564+
self.getBuilder().getI32Type());
565+
return rank;
566+
})
558567
.def("create_map_to_remote_buffer",
559568
[](TritonOpBuilder &self, Value &src,
560569
Value &clusterCTARank) -> Value {

third_party/tlx/language/tlx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
tcgen05_commit,
3333
)
3434
from .utility import (
35+
cluster_cta_rank,
3536
thread_id,
3637
async_task_replica_id,
3738
dtype_of,
@@ -96,6 +97,7 @@
9697
"async_dot_wait",
9798
"tcgen05_commit",
9899
# utility
100+
"cluster_cta_rank",
99101
"thread_id",
100102
"async_task_replica_id",
101103
"dtype_of",

third_party/tlx/language/tlx/utility.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ def cuda_parse_arch(arch):
1717
return int(match.group(1))
1818

1919

20+
@tl.builtin
21+
def cluster_cta_rank(_semantic=None):
22+
"""
23+
:return the unique CTA ID within a cluster across all dims
24+
"""
25+
return tl.tensor(_semantic.builder.create_cluster_cta_rank(), tl.int32)
26+
27+
2028
@tl.builtin
2129
def thread_id(axis, _semantic=None):
2230
"""

0 commit comments

Comments
 (0)