@@ -41,25 +41,25 @@ def decode_add_rms( # pylint: disable=too-many-locals
4141 annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 },
4242 ):
4343 for i in range (add_local_size ):
44- with T .block ("T_add" ):
44+ with T .sblock ("T_add" ):
4545 bx = T .axis .spatial (batch_size , v_bx )
4646 h = T .axis .spatial (hidden_size , i * TX + v_tx )
4747 add_local [h // TX ] = A [bx , 0 , h ] + B [bx , 0 , h ]
48- with T .block ("T_write_back" ):
48+ with T .sblock ("T_write_back" ):
4949 bx = T .axis .spatial (batch_size , v_bx )
5050 v_ax1 = T .axis .spatial (1 , 0 )
5151 h = T .axis .spatial (hidden_size , i * TX + v_tx )
5252 add [bx , v_ax1 , h ] = add_local [h // TX ]
53- with T .block ("T_multiply_red_rf_init" ):
53+ with T .sblock ("T_multiply_red_rf_init" ):
5454 tx , bx = T .axis .remap ("SS" , [v_tx , v_bx ])
5555 sum_local [tx , bx , 0 ] = T .float32 (0 )
5656 for v_i , _j in T .grid (add_local_size , 1 ):
57- with T .block ("T_multiply_red_rf_update" ):
57+ with T .sblock ("T_multiply_red_rf_update" ):
5858 tx , bx , i = T .axis .remap ("SSR" , [v_tx , v_bx , v_i ])
5959 sum_local [tx , bx , 0 ] += T .float32 (add_local [i ]) * T .float32 (add_local [i ])
6060 for _j in range (1 ):
6161 for v_tx_2 in T .thread_binding (TX , thread = "threadIdx.x" ):
62- with T .block ("T_multiply_red" ):
62+ with T .sblock ("T_multiply_red" ):
6363 tx , bx = T .axis .remap ("RS" , [v_tx_2 , v_bx ])
6464 T .reads (sum_local [tx , bx , 0 ])
6565 T .writes (sum_shared [bx , 0 ])
@@ -68,7 +68,7 @@ def decode_add_rms( # pylint: disable=too-many-locals
6868 sum_shared [bx , 0 ] += sum_local [tx , bx , 0 ]
6969 for i in range (add_local_size ):
7070 for v_tx_2 in T .thread_binding (TX , thread = "threadIdx.x" ):
71- with T .block ("T_cast_2" ):
71+ with T .sblock ("T_cast_2" ):
7272 bx = T .axis .spatial (batch_size , v_bx )
7373 h = T .axis .spatial (hidden_size , i * TX + v_tx_2 )
7474 O [bx , 0 , h ] = T .cast (
@@ -109,31 +109,31 @@ def prefill_add_rms( # pylint: disable=too-many-locals
109109 annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 },
110110 ):
111111 for v_i in range (add_local_size ):
112- with T .block ("T_add" ):
112+ with T .sblock ("T_add" ):
113113 bx = T .axis .spatial (seq_len , v_bx )
114114 h = T .axis .spatial (hidden_size , v_i * TX + v_tx )
115115 add_local [h // TX ] = A [0 , bx , h ] + B [0 , bx , h ]
116- with T .block ("T_write_back" ):
116+ with T .sblock ("T_write_back" ):
117117 bx = T .axis .spatial (seq_len , v_bx )
118118 h = T .axis .spatial (hidden_size , v_i * TX + v_tx )
119119 add [0 , bx , h ] = add_local [h // TX ]
120- with T .block ("T_multiply_red_rf_init" ):
120+ with T .sblock ("T_multiply_red_rf_init" ):
121121 tx , bx = T .axis .remap ("SS" , [v_tx , v_bx ])
122122 sum_local [tx , 0 , bx ] = T .float32 (0 )
123123 for v_i , _j in T .grid (add_local_size , 1 ):
124- with T .block ("T_multiply_red_rf_update" ):
124+ with T .sblock ("T_multiply_red_rf_update" ):
125125 tx , bx , i = T .axis .remap ("SSR" , [v_tx , v_bx , v_i ])
126126 sum_local [tx , 0 , bx ] += T .float32 (add_local [i ]) * T .float32 (add_local [i ])
127127 for _j in range (1 ):
128128 for v_tx_2 in T .thread_binding (TX , thread = "threadIdx.x" ):
129- with T .block ("T_multiply_red" ):
129+ with T .sblock ("T_multiply_red" ):
130130 tx , bx = T .axis .remap ("RS" , [v_tx_2 , v_bx ])
131131 with T .init ():
132132 sum_shared [0 , bx ] = T .float32 (0 )
133133 sum_shared [0 , bx ] = sum_shared [0 , bx ] + sum_local [tx , 0 , bx ]
134134 for v_i in range (add_local_size ):
135135 for v_tx_2 in T .thread_binding (TX , thread = "threadIdx.x" ):
136- with T .block ("T_cast_2" ):
136+ with T .sblock ("T_cast_2" ):
137137 bx = T .axis .spatial (seq_len , v_bx )
138138 v1 = T .axis .spatial (hidden_size , v_i * TX + v_tx_2 )
139139 O [0 , bx , v1 ] = T .cast (
0 commit comments