@@ -203,7 +203,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 def test_tensor_descriptor_store3d(dtype_str, K_BLOCK):
 
     if dtype_str == 'bfloat16':
-        return pytest.skip("FIXME: bfloat16 test fails verification")
+        return pytest.skip("FIXME: issue #4137")
 
     @triton.jit
     def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr,
@@ -248,6 +248,138 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     torch.testing.assert_close(expect, actual)
 
 
+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1])
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128])
+def test_tensor_descriptor_load_nd(dtype_str, num_ctas, ndim, INNER_BLOCK):
+
+    if ndim not in [1] or dtype_str not in ["uint16", "uint32"]:
+        return pytest.skip("FIXME: issue #4139")
+
+    @triton.jit
+    def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
+        desc = tl.make_tensor_descriptor(
+            a_ptr,
+            shape=shape,
+            strides=strides,
+            block_shape=BLOCK_SHAPE,
+        )
+        ndim: tl.constexpr = len(BLOCK_SHAPE)
+
+        offs = (0, ) * ndim
+        block = desc.load(offs)
+
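+        # Build a row-major linear index over BLOCK_SHAPE: for each axis k (last to first),
+        # broadcast an arange along that axis with expand_dims and scale it by the running
+        # stride of the trailing axes, so out_ptr + idx addresses a dense row-major buffer.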
+        idx = tl.full(BLOCK_SHAPE, 0, tl.int32)
+        stride = 1
+        for k in tl.static_range(ndim - 1, -1, -1):
+            arange = tl.arange(0, BLOCK_SHAPE[k])
+            for _ in tl.static_range(k):
+                arange = tl.expand_dims(arange, 0)
+            for _ in tl.static_range(k + 1, ndim):
+                arange = tl.expand_dims(arange, -1)
+
+            idx += arange * stride
+            stride *= BLOCK_SHAPE[k]
+
+        tl.store(out_ptr + idx, block)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    alloc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:]
+    inp = to_triton(numpy_random(alloc_shape, dtype_str), device="xpu", dst_type=dtype_str)
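+    # Trim the last dimension so the logical tensor is smaller than the block; the
+    # out-of-bounds tail of the loaded block is expected to read back as zero.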
+    inp.data = inp.data[..., :INNER_BLOCK - 3]
+
+    if INNER_BLOCK * inp.element_size() < 32:
+        return pytest.xfail("Invalid last dim size")
+
+    BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:]
+    out = inp.new_empty(BLOCK_SHAPE)
+
+    constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE)
+    kernel[(1, )](out, inp, inp.shape, inp.stride(), constexpr_block_shape, num_ctas=num_ctas)
+
+    # Check in-bounds
+    actual = unwrap_tensor(out)
+    expect = unwrap_tensor(inp)
+    idx = [slice(None, s) for s in inp.shape]
+    torch.testing.assert_close(expect, actual[idx])
+
+    # Check out-of-bounds
+    actual[idx].zero_()
+    expect = expect.new_zeros(BLOCK_SHAPE)
+    torch.testing.assert_close(expect, actual)
+
+
+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1])
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128])
+def test_tensor_descriptor_store_nd(dtype_str, num_ctas, ndim, INNER_BLOCK):
+
+    if ndim not in [1]:
+        return pytest.skip("FIXME: issue #4140")
+
+    @triton.jit
+    def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
+        desc = tl.make_tensor_descriptor(
+            out_ptr,
+            shape=shape,
+            strides=strides,
+            block_shape=BLOCK_SHAPE,
+        )
+        ndim: tl.constexpr = len(BLOCK_SHAPE)
+
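+        # Same row-major index construction as the load test: gather a dense block from
+        # a_ptr and write it through the descriptor at offset zero.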
+        idx = tl.full(BLOCK_SHAPE, 0, tl.int32)
+        stride = 1
+        for k in tl.static_range(ndim - 1, -1, -1):
+            arange = tl.arange(0, BLOCK_SHAPE[k])
+            for _ in tl.static_range(k):
+                arange = tl.expand_dims(arange, 0)
+            for _ in tl.static_range(k + 1, ndim):
+                arange = tl.expand_dims(arange, -1)
+
+            idx += arange * stride
+            stride *= BLOCK_SHAPE[k]
+
+        block = tl.load(a_ptr + idx)
+
+        offs = (0, ) * ndim
+        desc.store(offs, block)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:]
+    inp = to_triton(numpy_random(BLOCK_SHAPE, dtype_str), device="xpu", dst_type=dtype_str)
+
+    if INNER_BLOCK * inp.element_size() < 32:
+        return pytest.xfail("Invalid last dim size")
+
+    out = inp.new_empty(BLOCK_SHAPE)
+    out.data.fill_(-1)
+
+    desc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:]
+    constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE)
+    kernel[(1, )](out, inp, desc_shape, out.stride(), constexpr_block_shape, num_ctas=num_ctas)
+
+    # Check in-bounds
+    actual = unwrap_tensor(out)
+    expect = unwrap_tensor(inp)
+    idx = [slice(None, s) for s in desc_shape]
+    torch.testing.assert_close(expect[idx], actual[idx])
+
+    # Check out-of-bounds
+    actual[idx].fill_(-1)
+    expect = expect.new_full(BLOCK_SHAPE, -1)
+    torch.testing.assert_close(expect, actual)
+
+
 @triton.jit(noinline=False)
 def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
     in_desc = tl.make_tensor_descriptor(
@@ -465,6 +597,186 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
 
 
+@triton.jit
+def kernel_make_tensor_descriptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr):
+    # Test that descriptors work as loop-carried values, including when they are
+    # conditionally rebuilt inside the loop body.
+    pid = tl.program_id(0)
+    moffset = MBLOCK * pid
+
+    a_desc = tl.make_tensor_descriptor(
+        a_ptr,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[MBLOCK, NBLOCK],
+    )
+
+    for i in range(0, N, NBLOCK):
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        if i % (3 * NBLOCK) == 0:
+            a_desc = tl.make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl.tensor_descriptor)
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        a = a_desc.load([moffset, i])
+        a_desc.store([moffset, i], a + 10)
+
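+    # Repeat the pattern with a while loop; together the two loops add 15 to every element.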
+    n = 0
+    while n < N:
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        if n % (3 * NBLOCK) == 0:
+            assert isinstance(a_desc, tl.tensor_descriptor)
+            a_desc = tl.make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        a = a_desc.load([moffset, n])
+        a_desc.store([moffset, n], a + 5)
+
+        n += NBLOCK
+
+
+@pytest.mark.interpreter
+def test_make_tensor_descriptor_loop_carried():
+    return pytest.skip("FIXME: issue #4132")
+
+    device = "xpu"
+    M, N = 64, 512
+    torch.manual_seed(42)
+    A = torch.randn((M, N), dtype=torch.float32, device=device)
+    MBLOCK, NBLOCK = 8, 128
+    grid = (triton.cdiv(M, MBLOCK), )
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        assert size == 128 * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    ref_out = A + 15
+    kernel_make_tensor_descriptor_loop_carried[grid](
+        A,
+        M,
+        N,
+        MBLOCK,
+        NBLOCK,
+    )
+    torch.testing.assert_close(ref_out, A)
+
+
+@triton.jit
+def batched_gemm_2d_tma_kernel(a_ptr, b_ptr, c_ptr,  #
+                               B, M, N, K,  #
+                               dtype: tl.constexpr,  #
+                               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
+                               NUM_SMS: tl.constexpr):
+    start_pid = tl.program_id(axis=0)
+    num_tiles_m = tl.cdiv(M, BLOCK_M)
+    num_tiles_n = tl.cdiv(N, BLOCK_N)
+    k_tiles = tl.cdiv(K, BLOCK_K)
+    num_tiles_per_batch = num_tiles_m * num_tiles_n
+    num_tiles = B * num_tiles_per_batch
+
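+    # Persistent-kernel style: the B * num_tiles_m * num_tiles_n output tiles are
+    # distributed round-robin over NUM_SMS programs.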
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    tile_m = 0
+    tile_n = 0
+    tile_b = 0
+
+    offs_m = 0
+    offs_n = 0
+    offs_b = 0
+
+    a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
+    b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
+    c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
+
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
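+    # Single flattened loop over all K-slices of all assigned tiles: ki wraps back to 0
+    # every k_tiles iterations, which is the signal to advance to the next output tile.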
+    for _ in range(k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            tile_b = tile_id // num_tiles_per_batch
+            tile_m = (tile_id // num_tiles_n) % num_tiles_m
+            tile_n = tile_id % num_tiles_n
+
+            offs_b = tile_b
+            offs_m = tile_m * BLOCK_M
+            offs_n = tile_n * BLOCK_N
+
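+            # The descriptor base pointers encode the batch, so rebuild all three
+            # descriptors whenever the loop moves on to a new output tile.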
+            a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
+            b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
+            c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
+
+        offs_k = ki * BLOCK_K
+
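+        # b is laid out as [N, K], so a [BLOCK_N, BLOCK_K] tile is loaded and transposed for the dot.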
+        a = a_desc.load([offs_m, offs_k])
+        b = b_desc.load([offs_n, offs_k])
+        accumulator = tl.dot(a, b.T, accumulator)
+
+        if ki == k_tiles - 1:
+            c = accumulator.to(dtype)
+
+            c_desc.store([offs_m, offs_n], c)
+            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+
+@pytest.mark.interpreter
+def test_tensor_descriptor_batched_gemm_2d_tma():
+    return pytest.skip("FIXME: issue #4132")
+
+    device = "xpu"
+    BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
+    if is_interpreter():
+        B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K
+    else:
+        B, M, N, K = 2, 1024, 1024, 128
+    NUM_SMS = 96
+    num_stages = 3
+
+    grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), )
+
+    a = torch.randn((B, M, K), device=device, dtype=torch.float16)
+    b = torch.randn((B, N, K), device=device, dtype=torch.float16)
+    c = torch.empty((B, M, N), device=device, dtype=torch.float16)
+
+    expect = torch.bmm(a, b.mT)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        # TODO: should only need num_stages * 3 descriptors per SM
+        assert size == 128 * 3 * (num_stages + 1) * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    batched_gemm_2d_tma_kernel[grid](
+        a, b, c,  #
+        B, M, N, K,  #
+        tl.float16,  #
+        BLOCK_M, BLOCK_N, BLOCK_K,  #
+        NUM_SMS,  #
+        num_stages=num_stages, num_warps=8)
+    torch.xpu.synchronize()
+
+    torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3)
+
+
 @triton.jit
 def batched_gemm_3d_tma_kernel(a_ptr, b_ptr, c_ptr,  #
                                B, M, N, K,  #