Skip to content

Commit 6baca09

Browse files
dshi7 authored and meta-codesync[bot] committed
Fix CLC unit test (#620)
Summary: pytest python/test/unit/language/test_tlx.py::test_cluster_launch_control
Pull Request resolved: #620
Reviewed By: agron911
Differential Revision: D85926674
Pulled By: dshi7
fbshipit-source-id: 771c1e2d19819be0d3e986e2c2f9c2f7444ba684
1 parent c752031 commit 6baca09

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

python/test/unit/language/test_tlx.py

Lines changed: 24 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1682,37 +1682,36 @@ def mul2_clc(
16821682
n_elements,
16831683
BLOCK_SIZE: tl.constexpr,
16841684
):
1685-
pid = tl.program_id(axis=0)
1686-
block_start = pid * BLOCK_SIZE
1687-
1688-
tid = tlx.thread_id(axis=0)
1689-
1690-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
1691-
mask = offsets < n_elements
1685+
tile_id = tl.program_id(axis=0)
16921686

1693-
x = tl.load(x_ptr + offsets, mask=mask)
1694-
y = tl.load(y_ptr + offsets, mask=mask)
1695-
output = x * y
1696-
tl.store(z_ptr + offsets, output, mask=mask)
1687+
# CLC Init
1688+
clc_phase_producer = 1
1689+
clc_phase_consumer = 0
1690+
# NUM_CLC_STAGES=1
1691+
# NUM_CONSUMERS=1
1692+
clc_context = tlx.clc_create_context(1, 1)
16971693

1698-
bars = tlx.alloc_barriers(num_barriers=1)
1699-
clc_mbar = bars[0]
1694+
while tile_id != -1:
1695+
# CLC producer
1696+
tlx.clc_producer(clc_context, 0, clc_phase_producer)
1697+
clc_phase_producer ^= 1
17001698

1701-
responses = tlx.alloc_clc_responses(num_responses=1)
1702-
clc_response = tlx.local_view(responses, 0)
1703-
tlx.barrier_expect_bytes(clc_mbar, 16) # CLC response is 16-byte
1699+
block_start = tile_id * BLOCK_SIZE
17041700

1705-
# Issue async clc.try_cancel for the next available CTA
1706-
tlx.clc_issue(clc_response, clc_mbar)
1701+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
1702+
mask = offsets < n_elements
17071703

1708-
# Wait for clc.try_cancel finishes
1709-
tlx.barrier_wait(clc_mbar, 0)
1704+
x = tl.load(x_ptr + offsets, mask=mask)
1705+
y = tl.load(y_ptr + offsets, mask=mask)
1706+
output = x * y
1707+
tl.store(z_ptr + offsets, output, mask=mask)
17101708

1711-
# Extract CTA ID from CLC response
1712-
res = tlx.clc_query(clc_response)
1709+
# CLC consumer
1710+
tile_id = tlx.clc_consumer(clc_context, 0, clc_phase_consumer)
1711+
clc_phase_consumer ^= 1
17131712

1714-
if tid == 0:
1715-
tl.device_print("Extracted CtaID", res)
1713+
if tlx.thread_id(axis=0) == 0:
1714+
tl.device_print("Extracted CtaID", tile_id)
17161715

17171716
torch.manual_seed(0)
17181717
# number of kernels to launch in a non-persistent mode
@@ -1731,11 +1730,7 @@ def mul2_clc(
17311730
assert re.search((r'clusterlaunchcontrol.query_cancel.is_canceled.pred.b128'), ptx, flags=re.DOTALL)
17321731
assert re.search((r'clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128'), ptx, flags=re.DOTALL)
17331732

1734-
# Each worker uses the {blockIdx.x, blockIdx.y, blockIdx.z} coordinate as the first output tile to process
1735-
# and uses the CLC query for subsequent processing of output tiles.
1736-
# However in our test those CTAs left from the first round won't execute.
1737-
# Its nonzero count MUST be different from original size.
1738-
assert (torch.count_nonzero(output) != size)
1733+
assert (torch.count_nonzero(output) == size)
17391734

17401735

17411736
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")

0 commit comments

Comments
 (0)