@@ -116,7 +116,7 @@ def kernel(s_ptr, out_ptr):
116
116
layout : ttgl .constexpr = get_tmem_32x32b_reg_layout (BLOCK_M , BLOCK_N , (BLOCK_M , N ), num_warps = 4 )
117
117
118
118
offsets = ttgl .arange (0 , BLOCK_M )[:, None ] * N + ttgl .arange (0 , N )[None , :]
119
- offsets = ttgl .convert_layout (offsets , layout )
119
+ offsets = ttgl .set_auto_layout (offsets , layout )
120
120
s = ttgl .load (s_ptr + offsets )
121
121
122
122
s_tmem .store (s )
@@ -194,8 +194,8 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
194
194
195
195
a_layout : ttgl .constexpr = get_tmem_32x32b_reg_layout (BLOCK_M , BLOCK_N , (BLOCK_M , N ), num_warps = 4 )
196
196
b_layout : ttgl .constexpr = ttgl .BlockedLayout ([1 , 1 ], [1 , 32 ], [4 , 1 ], [1 , 0 ])
197
- a_offsets = ttgl .convert_layout (a_offsets , a_layout )
198
- b_offsets = ttgl .convert_layout (b_offsets , b_layout )
197
+ a_offsets = ttgl .set_auto_layout (a_offsets , a_layout )
198
+ b_offsets = ttgl .set_auto_layout (b_offsets , b_layout )
199
199
200
200
a = ttgl .load (a_ptr + a_offsets )
201
201
b = ttgl .load (b_ptr + b_offsets )
0 commit comments