@@ -383,7 +383,9 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
383383
384384
385385@pytest .mark .interpreter
386- def test_tensor_descriptor_padding ():
386+ def test_tensor_descriptor_padding (device ):
387+ if not is_cuda ():
388+ pytest .xfail ("padding is unsupported" )
387389
388390 @triton .jit
389391 def device_tma_load (in_ptr , out_ptr , IM , IN , YM , YN , M_BLOCK : tl .constexpr , N_BLOCK : tl .constexpr ,
@@ -414,7 +416,7 @@ def host_tma_load(in_desc, out_ptr, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.c
414416
415417 # TMA descriptors require a global memory allocation
416418 def alloc_fn (size : int , alignment : float , stream : float ):
417- return torch .ones (size , device = "cuda" , dtype = torch .float32 )
419+ return torch .ones (size , device = device , dtype = torch .float32 )
418420
419421 triton .set_allocator (alloc_fn )
420422
@@ -423,16 +425,16 @@ def alloc_fn(size: int, alignment: float, stream: float):
423425 M_BLOCK = 32
424426 N_BLOCK = 32
425427 padding = "nan"
426- input = torch .arange (IM * IN , device = "cuda" , dtype = torch .float32 )
428+ input = torch .arange (IM * IN , device = device , dtype = torch .float32 )
427429 input = input .reshape (IM , IN )
428- out_device_tma = torch .zeros ((OM , ON ), device = "cuda" , dtype = torch .float32 )
429- out_host_tma = torch .zeros ((OM , ON ), device = "cuda" , dtype = torch .float32 )
430+ out_device_tma = torch .zeros ((OM , ON ), device = device , dtype = torch .float32 )
431+ out_host_tma = torch .zeros ((OM , ON ), device = device , dtype = torch .float32 )
430432 dummy_block = [M_BLOCK , N_BLOCK ]
431433 in_desc = TensorDescriptor (input , input .shape , input .stride (), dummy_block , padding = padding )
432434 grid = (triton .cdiv (OM , M_BLOCK ), triton .cdiv (ON , N_BLOCK ))
433435 device_tma_load [grid ](input , out_device_tma , IM , IN , OM , ON , M_BLOCK , N_BLOCK , padding )
434436 host_tma_load [grid ](in_desc , out_host_tma , OM , ON , M_BLOCK , N_BLOCK )
435- expected = torch .zeros ((OM , ON ), device = "cuda" , dtype = torch .float32 )
437+ expected = torch .zeros ((OM , ON ), device = device , dtype = torch .float32 )
436438 expected [0 :IN , 0 :IM ] = input
437439 expected [:, IN :ON ] = float ('nan' )
438440 expected [IM :OM , :] = float ('nan' )