9
9
from triton .experimental .gluon import language as ttgl
10
10
from triton .experimental .gluon .language .nvidia import blackwell
11
11
from triton .experimental .gluon .language .nvidia .blackwell import mbarrier , tma , TensorMemoryLayout
12
+ from triton .experimental .gluon .nvidia .hopper import TensorDescriptor
12
13
from triton ._filecheck import filecheck_test , run_parser
13
14
import triton .language as tl
14
15
from triton ._internal_testing import is_cuda
15
- from triton .tools .tensor_descriptor import TensorDescriptor
16
16
from triton .compiler .errors import CompilationError
17
17
18
18
TARGET_PAT = re .compile ('ttg.target = "[^"]*"' )
@@ -434,8 +434,8 @@ def test_tcgen05_mma(fresh_knobs):
434
434
435
435
436
436
@gluon .jit
437
- def async_tma_kernel (input_desc , XBLOCK : ttgl .constexpr , smem_layout : ttgl . constexpr ):
438
- smem = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], smem_layout )
437
+ def async_tma_kernel (input_desc , XBLOCK : ttgl .constexpr ):
438
+ smem = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc . layout )
439
439
bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 ], mbarrier .MBarrierLayout ())
440
440
mbarrier .init (bar , count = 1 )
441
441
@@ -455,25 +455,25 @@ def test_async_tma(fresh_knobs):
455
455
456
456
input = torch .randn ((1024 , 1024 ), device = "cuda" , dtype = torch .float16 )
457
457
XBLOCK = 128
458
- input_desc = TensorDescriptor .from_tensor (input , [XBLOCK , XBLOCK ])
459
458
shared_layout = ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 )
459
+ input_desc = TensorDescriptor .from_tensor (input , [XBLOCK , XBLOCK ], shared_layout )
460
460
461
- h = async_tma_kernel .warmup (input_desc , XBLOCK , shared_layout , grid = (1 , ), num_warps = 4 )
461
+ h = async_tma_kernel .warmup (input_desc , XBLOCK , grid = (1 , ), num_warps = 4 )
462
462
expecttest .assert_expected_inline (
463
463
anonymize_ir (h .asm ["source" ]), """\
464
464
#loc = loc(unknown)
465
465
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
466
466
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
467
467
#smem = #ttg.shared_memory
468
468
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
469
- tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
469
+ tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16, #shared >> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
470
470
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
471
471
%1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
472
472
ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
473
473
%c0_i32 = arith.constant 0 : i32 loc(#loc)
474
474
%c0_i32_0 = arith.constant 0 : i32 loc(#loc)
475
475
%true = arith.constant true loc(#loc)
476
- ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
476
+ ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared >>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
477
477
%true_1 = arith.constant true loc(#loc)
478
478
ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
479
479
%c0_i32_2 = arith.constant 0 : i32 loc(#loc)
@@ -482,7 +482,7 @@ def test_async_tma(fresh_knobs):
482
482
ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
483
483
%c0_i32_4 = arith.constant 0 : i32 loc(#loc)
484
484
%c0_i32_5 = arith.constant 0 : i32 loc(#loc)
485
- ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
485
+ ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16, #shared >>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
486
486
ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
487
487
tt.return loc(#loc)
488
488
} loc(#loc)
@@ -491,8 +491,8 @@ def test_async_tma(fresh_knobs):
491
491
492
492
493
493
@gluon .jit
494
- def async_tma_blackwell_kernel (input_desc , XBLOCK : ttgl .constexpr , smem_layout : ttgl . constexpr ):
495
- smem = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], smem_layout )
494
+ def async_tma_blackwell_kernel (input_desc , XBLOCK : ttgl .constexpr ):
495
+ smem = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc . layout )
496
496
bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 ], mbarrier .MBarrierLayout ())
497
497
mbarrier .init (bar , count = 1 )
498
498
@@ -514,10 +514,10 @@ def test_async_tma_blackwell(fresh_knobs):
514
514
515
515
input = torch .randn ((1024 , 1024 ), device = "cuda" , dtype = torch .float16 )
516
516
XBLOCK = 128
517
- input_desc = TensorDescriptor .from_tensor (input , [1 , XBLOCK ])
518
517
shared_layout = ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 )
518
+ input_desc = TensorDescriptor .from_tensor (input , [1 , XBLOCK ], shared_layout )
519
519
520
- h = async_tma_blackwell_kernel .warmup (input_desc , XBLOCK , shared_layout , grid = (1 , ), num_warps = 4 )
520
+ h = async_tma_blackwell_kernel .warmup (input_desc , XBLOCK , grid = (1 , ), num_warps = 4 )
521
521
expecttest .assert_expected_inline (
522
522
anonymize_ir (h .asm ["source" ]), """\
523
523
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -526,22 +526,22 @@ def test_async_tma_blackwell(fresh_knobs):
526
526
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
527
527
#smem = #ttg.shared_memory
528
528
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
529
- tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
529
+ tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16, #shared >> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
530
530
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
531
531
%1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
532
532
ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
533
533
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
534
534
%true = arith.constant true loc(#loc)
535
535
%c0_i32 = arith.constant 0 : i32 loc(#loc)
536
- ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
536
+ ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared >>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
537
537
%true_0 = arith.constant true loc(#loc)
538
538
ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
539
539
%c0_i32_1 = arith.constant 0 : i32 loc(#loc)
540
540
%true_2 = arith.constant true loc(#loc)
541
541
ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
542
542
ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
543
543
%c0_i32_3 = arith.constant 0 : i32 loc(#loc)
544
- ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
544
+ ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16, #shared >>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
545
545
ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
546
546
tt.return loc(#loc)
547
547
} loc(#loc)
0 commit comments