@@ -8,11 +8,11 @@
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
-from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
-from triton._internal_testing import is_cuda
+from triton._internal_testing import is_ampere_or_newer, is_blackwell, is_hopper
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
@@ -117,8 +117,7 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr): |
         buffers.index(i).load(layout)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
-                    reason="Requires blackwell tensor cores")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor cores")
 def test_tensor_memory(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -373,13 +372,13 @@ def mbarrier_kernel(): |
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
     mbarrier.expect(bar, 4)
-    mbarrier.arrive(bar, 1)
+    mbarrier.arrive(bar, count=1)
     phase = 0
     mbarrier.wait(bar, phase, deps=[bar])
     mbarrier.invalidate(bar)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.skipif(not is_hopper(), reason="Requires hopper or newer")
 def test_mbarrier(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -415,8 +414,7 @@ def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr) |
     blackwell.tcgen05_mma(a, b, acc)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
-                    reason="Requires blackwell tensor core")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor core")
 def test_tcgen05_mma(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -460,7 +458,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr): |
     tma.store_wait(0)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="TMA requires at least Hopper")
+@pytest.mark.skipif(not is_hopper(), reason="TMA requires at least Hopper")
 def test_async_tma(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -519,7 +517,7 @@ def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr): |
     tma.store_wait(0)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10, reason="Requires Blackwell")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_async_tma_blackwell(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -955,3 +953,53 @@ def test_inline_asm_elementwise(): |
     x = ttgl.arange(0, 16, layout)
     # CHECK: elementwise_inline_asm {{.*}} : tensor<16xi32, [[BLOCKED:#.*]]> -> tensor<16xi32, [[BLOCKED]]>
     ttgl.inline_asm_elementwise("mov $0, $0;", "=r,r", [x], dtype=x.dtype, is_pure=True, pack=1)
+
+
+@gluon.jit
+def async_copy_kernel(inp, xnumel, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK], ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
+    block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+    xindex = ttgl.arange(0, XBLOCK, block_layout)
+    mask = tl.max_constancy(xindex < xnumel, 2)
+
+    async_copy.async_copy_global_to_shared(smem, inp + xindex, mask)
+    async_copy.async_copy_global_to_shared(smem, inp + xindex, mask, cache_modifier=".ca", eviction_policy="evict_last",
+                                           volatile=True)
+
+    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    async_copy.mbarrier_arrive(mbar)
+    async_copy.mbarrier_arrive(mbar, increment_count=False)
+    async_copy.commit_group()
+    async_copy.wait_group(0)
+
+
+@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires ampere")
+def test_async_copy(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = async_copy_kernel.warmup(MockTensor(ttgl.float16), xnumel=100, XBLOCK=128, sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["ttgir"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#loc = loc(unknown)
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @async_copy_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown), %arg1: i32 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
+    %2 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked> loc(#loc)
+    %3 = arith.cmpi slt, %1, %2 {tt.constancy = dense<2> : tensor<1xi32>} : tensor<128xi32, #blocked> loc(#loc)
+    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
+    %5 = tt.addptr %4, %1 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked> loc(#loc)
+    %6 = ttg.async_copy_global_to_local %5, %0 mask %3 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
+    %7 = ttg.async_copy_global_to_local %5, %0 mask %3 cacheModifier = ca evictionPolicy = evict_last {isVolatile = true} : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
+    %8 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.async_copy_mbarrier_arrive %8 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.async_copy_mbarrier_arrive %8 {noIncrement} : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    %9 = ttg.async_commit_group loc(#loc)
+    %10 = ttg.async_wait {num = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")