@@ -460,3 +460,83 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]
     if BLOCK_M >= 64 and BLOCK_N >= 64:
         assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
+
+
+@triton.jit
+def kernel_make_tensor_descriptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr):
+    # Test that descriptors work when carried across loop iterations.
+    pid = tl.program_id(0)
+    moffset = MBLOCK * pid
+
+    a_desc = tl._experimental_make_tensor_descriptor(
+        a_ptr,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[MBLOCK, NBLOCK],
+    )
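+    # Describes A as an M x N row-major tensor read and written in MBLOCK x NBLOCK tiles.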
+
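+    # Re-create the descriptor every third block so the compiler must treat it
+    # as a loop-carried value.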
+    for i in range(0, N, NBLOCK):
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        if i % (3 * NBLOCK) == 0:
+            a_desc = tl._experimental_make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        a = a_desc.load([moffset, i])
+        a_desc.store([moffset, i], a + 10)
+
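+    # Same pattern with a while loop, adding 5 per block this time.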
+    n = 0
+    while n < N:
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        if n % (3 * NBLOCK) == 0:
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+            a_desc = tl._experimental_make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        a = a_desc.load([moffset, n])
+        a_desc.store([moffset, n], a + 5)
+
+        n += NBLOCK
+
+
+@requires_tma
+def test_experimental_make_tensor_descriptor_loop_carried():
+    device = "cuda"
+    M, N = 8192, 8192
+    torch.manual_seed(42)
+    A = torch.randn((M, N), dtype=torch.float32, device=device)
+    MBLOCK, NBLOCK = 8, 128
+    grid = (triton.cdiv(M, MBLOCK), )
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
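+        # TMA tensormaps are 128 bytes, 128-byte aligned: one per program in the grid.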
+        assert size == 128 * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="cuda")
+
+    triton.set_allocator(alloc_fn)
+
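+    # The kernel updates A in place: +10 per block in the for loop, +5 in the while loop.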
+    ref_out = A + 15
+    kernel = kernel_make_tensor_descriptor_loop_carried[grid](
+        A,
+        M,
+        N,
+        MBLOCK,
+        NBLOCK,
+    )
+    torch.testing.assert_close(ref_out, A)
+    assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]