@@ -50,6 +50,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
5050 col_offsets = tl .arange (0 , BLOCK_SIZE )
5151 tl .assume (input_row_stride >= 0 )
5252 tl .assume (output_row_stride >= 0 )
53+ tl .assume (row_start >= 0 )
5354
5455 if USE_BLOCKED :
5556
@@ -61,68 +62,63 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
6162 # Accumulate sum of squares
6263 sum_squares = tl .zeros ([1 ], dtype = tl .float32 )
6364 n_cols_blks = tl .cdiv (n_cols , BLOCK_SIZE ) - 1
64- for blk_idx in range (n_cols_blks , num_stages = 1 ):
65+ for blk_idx in tl . range (0 , n_cols_blks , num_stages = 2 ):
6566 cols = blk_idx * BLOCK_SIZE + col_offsets
6667 input_ptrs = row_input_ptr + cols
6768 input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
68- x = tl .load (input_ptrs )
69+ x = tl .load (input_ptrs ). to ( tl . float32 )
6970 sum_squares += tl .sum (x * x , axis = 0 )
7071
7172 # Handle remainder
7273 cols = n_cols_blks * BLOCK_SIZE + col_offsets
7374 mask = cols < n_cols
7475 input_ptrs = row_input_ptr + cols
7576 input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
76- x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" )
77+ x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" ). to ( tl . float32 )
7778 sum_squares += tl .sum (x * x , axis = 0 )
7879
7980 # Compute normalization factor
8081 mean_square = sum_squares / n_cols
8182 norm_factor = tl .rsqrt (mean_square + epsilon )
8283
8384 # Normalize and write output
84- for blk_idx in range (n_cols_blks , num_stages = 1 ):
85+ for blk_idx in tl . range (0 , n_cols_blks , num_stages = 2 ):
8586 cols = blk_idx * BLOCK_SIZE + col_offsets
8687 input_ptrs = row_input_ptr + cols
8788 input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
88- x = tl .load (input_ptrs )
89+ x = tl .load (input_ptrs ). to ( tl . float32 )
8990 g_ptrs = g_ptr + cols
90- g = tl .load (g_ptrs )
91+ g = tl .load (g_ptrs ). to ( tl . float32 )
9192 rms_norm = x * norm_factor * g
9293 output_ptrs = row_output_ptr + cols
93- tl .store (output_ptrs , rms_norm )
94+ tl .store (output_ptrs , rms_norm . to ( output_ptr . type . element_ty ) )
9495
9596 # Handle remainder
9697 cols = n_cols_blks * BLOCK_SIZE + col_offsets
9798 mask = cols < n_cols
9899 input_ptrs = row_input_ptr + cols
99- x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" )
100+ x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" ). to ( tl . float32 )
100101 g_ptrs = g_ptr + cols
101- g = tl .load (g_ptrs , mask = mask , other = 0.0 )
102+ g = tl .load (g_ptrs , mask = mask , other = 0.0 ). to ( tl . float32 )
102103 rms_norm = x * norm_factor * g
103104 output_ptrs = row_output_ptr + cols
104- tl .store (output_ptrs , rms_norm , mask = mask )
105+ tl .store (output_ptrs , rms_norm . to ( output_ptr . type . element_ty ) , mask = mask )
105106
106107 else :
107108 mask = col_offsets < n_cols
108- for row_idx in tl .range (row_start , n_rows , NUM_PRGMS , num_stages = 1 ):
109- row_start_ptr = input_ptr + row_idx * input_row_stride
110- input_ptrs = row_start_ptr + col_offsets
109+ for row_idx in tl .range (row_start , n_rows , NUM_PRGMS , num_stages = 2 ):
110+ input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
111111 input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
112- row = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" )
113- g = tl .load (g_ptr + col_offsets , mask = mask , other = 0.0 )
112+ row = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" ). to ( tl . float32 )
113+ g = tl .load (g_ptr + col_offsets , mask = mask , other = 0.0 ). to ( tl . float32 )
114114 row_norm = row * row
115115 row_norm = tl .sum (row_norm , axis = - 1 )
116- row_norm = row_norm / n_cols
117- row_norm = row_norm + epsilon
118- row_norm = tl .rsqrt (row_norm )
119- rms_norm = row * row_norm
120- rms_norm = rms_norm * g
121-
122- output_row_start_ptr = output_ptr + row_idx * output_row_stride
123- output_ptrs = output_row_start_ptr + col_offsets
116+ row_norm = tl .math .rsqrt ((row_norm / n_cols ) + epsilon )
117+ rms_norm = row * row_norm * g
118+
119+ output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
124120 output_ptrs = tl .multiple_of (output_ptrs , (16 , ))
125- tl .store (output_ptrs , rms_norm , mask = mask )
121+ tl .store (output_ptrs , rms_norm . to ( output_ptr . type . element_ty ) , mask = mask )
126122
127123
128124def triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS , epsilon = 1e-6 ):
0 commit comments