@@ -365,9 +365,9 @@ def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K):

 @gluon.jit
 def tensor_copy_kernel(a_ptr, b_ptr, M, N,  #
-                       BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr):
+                       BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr,
+                       BLOCKED_LAYOUT: ttgl.constexpr):
     SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [BLOCK_M, BLOCK_N], [1, 0])
-    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

     pid = ttgl.program_id(axis=0)
     num_pid_m = ttgl.cdiv(M, BLOCK_M)
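Note on the hunk above: BLOCKED_LAYOUT moves from a hard-coded constexpr inside the kernel body to a constexpr parameter, so each caller can pick a register layout that matches its launch configuration. As a minimal sketch (not part of the diff, and assuming the keyword names below match the fields of Gluon's BlockedLayout dataclass), the four constructor arguments in [1, 8], [4, 8], [4, 1], [1, 0] read as:

    # Hedged sketch: field names assumed from Gluon's BlockedLayout; the warp
    # tiling baked into the layout has to agree with num_warps at launch.
    layout = ttgl.BlockedLayout(
        size_per_thread=[1, 8],   # each thread holds a 1x8 slice in registers
        threads_per_warp=[4, 8],  # 32 lanes arranged 4x8 (wave32 on gfx1250)
        warps_per_cta=[4, 1],     # 4 warps tiled along dim 0 -> num_warps=4
        order=[1, 0],             # dim 1 (N) is the fastest-varying dimension
    )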
@@ -400,31 +400,38 @@ def tensor_copy_kernel(a_ptr, b_ptr, M, N,  #
 @pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)])
 @pytest.mark.parametrize("NUM_BUFFERS", [1, 2])
 def test_compile_tensor_copy(BLOCK_M, BLOCK_N, NUM_BUFFERS):
+    BLOCKED_LAYOUT = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
     k = triton.compile(
         gluon._runtime.GluonASTSource(
             fn=tensor_copy_kernel, signature={
                 "a_ptr": "*fp16", "b_ptr": "*fp16", "M": "i32", "N": "i32",  #
-                "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr"
-            }, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS}),
-        target=GPUTarget("hip", 'gfx1250', 32))
+                "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr",  #
+                "BLOCKED_LAYOUT": "constexpr"
+            }, constexprs={
+                "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS, "BLOCKED_LAYOUT": BLOCKED_LAYOUT
+            }), target=GPUTarget("hip", 'gfx1250', 32))

     amdgcn = k.asm["amdgcn"]
     for pattern in ("tensor_load_to_lds", "s_wait_tensorcnt 0x0"):
         assert re.search(pattern, amdgcn)


-@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)])
+@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64), (1, 512), (256, 2)])
 @pytest.mark.parametrize("NUM_BUFFERS", [1, 2])
+@pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("M,N", [(1024, 1024), (1000, 1000)])
-def test_runtime_tensor_copy(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS):
+def test_runtime_tensor_copy(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS):
+    blocked_layout = ttgl.BlockedLayout([1, 8], [4, 8], [NUM_WARPS, 1], [1, 0])
+
     torch.manual_seed(42)
     a = torch.randint(0x0, 0xFFFF, (M, N), dtype=torch.uint16)
     b = torch.zeros_like(a)

     a_device = a.cuda()
     b_device = b.cuda()
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N * NUM_BUFFERS), 1)
-    tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS)
+    tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS,
+                             BLOCKED_LAYOUT=blocked_layout, num_warps=NUM_WARPS)

     b_triton = b_device.cpu()
     assert torch.equal(b_triton, a)
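For reference, the launch grid kept unchanged by this diff assigns each program one BLOCK_M row-band and NUM_BUFFERS consecutive BLOCK_N tiles along N. A worked check with one combination from the test matrix:

    # Worked example of grid = (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N * NUM_BUFFERS), 1)
    # with M=N=1024, BLOCK_M=32, BLOCK_N=64, NUM_BUFFERS=2:
    #   cdiv(1024, 32)     = 32 row-bands
    #   cdiv(1024, 64 * 2) = 8  column groups (two tiles copied per program)
    #   grid               = (32 * 8, 1) = (256, 1) programs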