 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
-from triton.experimental.gluon.language.nvidia.blackwell import mbarrier
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
+from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.compiler.errors import CompilationError


@@ -408,6 +409,126 @@ def test_tcgen05_mma(fresh_knobs):
 """)


+@gluon.jit
+def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(bar, count=1)
+
+    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
+    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
+    mbarrier.wait(bar, 0)
+
+    mbarrier.invalidate(bar)
+
+    tma.async_copy_shared_to_global(input_desc, [0, 0], smem)
+    tma.store_wait(0)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="TMA requires at least Hopper")
+def test_async_tma(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
+    XBLOCK = 128
+    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK])
+    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+
+    h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    expecttest.assert_expected_inline(
+        h.asm["source"], """\
+#loc = loc(unknown)
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %2 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<128x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_copy_global_to_local %2[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.ptr<i8>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %true_1 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
+    %true_3 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_2, %true_3 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
+    %3 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<128x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_copy_local_to_global %3[%c0_i32_4, %c0_i32_5] %0 : !tt.ptr<i8>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")
+
+
+@gluon.jit
+def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(bar, count=1)
+
+    offset_layout: tl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
+    x_offsets = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(0, offset_layout))
+    tma.async_gather(input_desc, x_offsets, 0, bar, smem)
+    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
+    mbarrier.wait(bar, 0)
+
+    mbarrier.invalidate(bar)
+
+    tma.async_scatter(input_desc, x_offsets, 0, smem)
+    tma.store_wait(0)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10, reason="Requires Blackwell")
+def test_async_tma_blackwell(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
+    XBLOCK = 128
+    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK])
+    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+
+    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    expecttest.assert_expected_inline(
+        h.asm["source"], """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#loc = loc(unknown)
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %3 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<1x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_gather %3[%2, %c0_i32] %0, %1, %true : !tt.ptr<i8>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
+    %true_0 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
+    %true_2 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %4 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<1x128xf16>> to !tt.ptr<i8> loc(#loc)
+    %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
+    ttng.async_tma_scatter %4[%2, %c0_i32_3] %0 : !tt.ptr<i8>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")
+
+
 def test_mlir_attr_error():

     @gluon.jit
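
For context, here is a minimal sketch of how the plain copy kernel added above could be launched directly, rather than through the warmup/expecttest harness the test uses. The tensor shape, block size, and shared-memory layout mirror test_async_tma; the launch syntax assumes the usual Triton/Gluon kernel[grid](...) convention on a Hopper-or-newer GPU, so treat it as illustrative rather than part of the commit:

    import torch
    from triton.tools.tensor_descriptor import TensorDescriptor
    from triton.experimental.gluon import language as ttgl

    x = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
    XBLOCK = 128
    # Describe 128x128 tiles of the tensor for TMA transfers.
    desc = TensorDescriptor.from_tensor(x, [XBLOCK, XBLOCK])
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
    # One program suffices: the kernel TMA-copies a single 128x128 tile into
    # shared memory, waits on the mbarrier, then writes the tile straight back.
    async_tma_kernel[(1, )](desc, XBLOCK, layout, num_warps=4)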