@@ -514,13 +514,12 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
     torch.testing.assert_close(z, torch.div(x, y), atol=1e-5, rtol=1e-4)
 
 
-@pytest.mark.xfail(reason="copy to tmem with scale layout is currently broken in Gluon.")
 @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_2d():
     device = "cuda"
 
-    smem_h = 256
-    smem_w = 4
+    smem_h = 64
+    smem_w = 16
     num_rows = 128
     num_cols = smem_h * smem_w // 32
 
@@ -530,13 +529,14 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
         in_ptrs = in_ptr + ttgl.arange(0, smem_h)[:, None] * smem_w + ttgl.arange(0, smem_w)[None, :]
         out_ptrs = out_ptr + ttgl.arange(0, num_rows)[:, None] * num_cols + ttgl.arange(0, num_cols)[None, :]
 
-        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [4, 1], [0, 1])
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [4, 1], [1, 0])
         value = ttgl.load(ttgl.set_auto_layout(in_ptrs, blocked))
 
-        smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=2)
+        smem_layout: ttgl.constexpr = ttgl.SharedLinearLayout(
+            offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]])
         tmem_layout: ttgl.constexpr = TensorMemoryScalesLayout()
         smem = ttgl.allocate_shared_memory(ttgl.int8, (smem_h, smem_w), layout=smem_layout)
-        tmem = allocate_tensor_memory(ttgl.int8, (num_rows, num_cols), layout=tmem_layout)
+        tmem = allocate_tensor_memory(ttgl.int8, (smem_h, smem_w), layout=tmem_layout)
 
         barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], ttgl.constexpr(mbarrier.MBarrierLayout()))
         mbarrier.init(barrier, count=1)
@@ -546,22 +546,30 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
         tcgen05_copy(smem, tmem)
         tcgen05_commit(barrier)
         mbarrier.wait(barrier, phase=0)
-        tmem_alias: ttgl.constexpr = TensorMemoryLayout((128, 32), col_stride=1)
+        tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1)
         tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias)
         value = tmem.load(blocked)
+        ttgl.static_print(ttgl.to_linear_layout(blocked, (smem_h, smem_w)))
+        ttgl.static_print(ttgl.to_linear_layout(blocked, (num_rows, num_cols)))
         ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value)
 
+    torch.manual_seed(0)
     x = torch.randint(size=(smem_h, smem_w), low=-100, high=100, dtype=torch.int8).to(device)
+    #x = torch.arange(smem_h * smem_w, dtype=torch.int8, device=device).reshape(smem_h, smem_w)
     z_tri = torch.zeros(size=(num_rows, num_cols), dtype=torch.int8).to(device)
     kernel[(1, )](x, z_tri, smem_h, smem_w, num_rows, num_cols)
 
-    num_rep_m = smem_h // 32
-
-    for m in range(num_rep_m):
-        col_offset = m * 4
-        for i in range(4):
-            # Copied values are duplicated across warps
-            assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])
+    # offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]],
+    # Split into contiguous shmem chunks
+    x_res = x.reshape(2, 32, 2, 2, 4)
+    # Put tmem cols first then rows
+    x_res = x_res.permute(1, 2, 3, 0, 4)
+    # Reshape as 32xnum_cols
+    x_res = x_res.reshape(num_rows // 4, num_cols)
+
+    warps = torch.chunk(z_tri, chunks=4, dim=0)
+    for warp in warps:
+        torch.testing.assert_close(x_res, warp)
 
 
 @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
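The reference check above leans on how the new SharedLinearLayout lays the 64 x 16 int8 scale tile out in shared memory (num_cols = 64 * 16 // 32 = 32, and the 128 x 32 output is compared as four identical 32 x 32 warp copies). The snippet below is not part of the commit: it is a minimal sketch that assumes offset_bases[i] gives the (row, col) contribution of bit i of the linear shared-memory offset and that Gluon linear layouts combine bases by XOR; the helper name coord is made up for illustration. Under those assumptions it checks two sanity properties: the mapping is a bijection onto the tile, and each aligned 8-byte chunk of shared memory holds four contiguous columns of some row r followed by the same four columns of row r + 32, mirroring the innermost grouping of the reshape(2, 32, 2, 2, 4).permute(1, 2, 3, 0, 4) reference.

offset_bases = [[0, 1], [0, 2], [32, 0], [0, 4], [1, 0],
                [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]]

def coord(offset):
    # (row, col) stored at this shared-memory byte offset, XOR-combining the
    # bases selected by the set bits of the offset (assumed convention).
    row = col = 0
    for bit, (r, c) in enumerate(offset_bases):
        if offset & (1 << bit):
            row ^= r
            col ^= c
    return row, col

coords = [coord(o) for o in range(64 * 16)]

# The ten bases are linearly independent, so every element of the 64 x 16
# tile is stored exactly once.
assert len(set(coords)) == 64 * 16

# Every aligned 8-byte chunk holds x[r, c:c+4] followed by x[r + 32, c:c+4],
# i.e. the same grouping the reshape/permute reference uses innermost.
for base in range(0, 64 * 16, 8):
    rows = [coords[base + k][0] for k in range(8)]
    cols = [coords[base + k][1] for k in range(8)]
    assert rows == [rows[0]] * 4 + [rows[0] + 32] * 4
    assert cols == list(range(cols[0], cols[0] + 4)) * 2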