@@ -16,6 +16,13 @@ def is_hip():
1616 return triton .runtime .driver .active .get_current_target ().backend == "hip"
1717
1818
def get_num_sms():
    """Return the multiprocessor count of the currently selected CUDA/HIP device.

    The device is whichever one ``torch.cuda.current_device()`` reports at call
    time; the count is read from that device's properties
    (``multi_processor_count``).
    """
    device_index = torch.cuda.current_device()
    return torch.cuda.get_device_properties(device_index).multi_processor_count
24+
25+
1926def get_cuda_autotune_config ():
2027 return [
2128 triton .Config ({}, num_warps = 4 , num_stages = 1 ),
@@ -25,9 +32,7 @@ def get_cuda_autotune_config():
2532
2633
def get_hip_autotune_config():
    """Return the autotuning search space used for the HIP backend.

    One ``triton.Config`` is produced for every pairing of a ``waves_per_eu``
    value with a ``num_warps`` value (``waves_per_eu`` varies slowest).
    ``num_stages`` is not set here, so ``triton.Config``'s default applies.
    """
    waves_per_eu_options = [0, 1, 2, 4]
    num_warps_options = [4, 8, 16]
    return [
        triton.Config({'waves_per_eu': we}, num_warps=nw)
        for we, nw in product(waves_per_eu_options, num_warps_options)
    ]
3136
3237
3338def get_autotune_config ():
@@ -37,102 +42,92 @@ def get_autotune_config():
3742 return get_hip_autotune_config ()
3843
3944
40- # accumulate sum of squares for a row in a blocked manner
41- @triton .jit
42- def accumulate_sum_squares (input_ptr , input_row_stride , n_cols , BLOCK_SIZE , row_idx ):
43- col_offsets = tl .arange (0 , BLOCK_SIZE )
44- sum_squares = tl .zeros ([1 ], dtype = tl .float32 )
45- row_input_ptr = input_ptr + row_idx * input_row_stride
46-
47- n_cols_blks = tl .cdiv (n_cols , BLOCK_SIZE ) - 1
48- for start in range (0 , n_cols_blks * BLOCK_SIZE , BLOCK_SIZE ):
49- cols = start + col_offsets
50- input_ptrs = row_input_ptr + cols
51- input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
52- x = tl .load (input_ptrs )
53- sum_squares += tl .sum (x * x , axis = 0 )
54-
55- # loop peeling for mask
56- cols = n_cols_blks * BLOCK_SIZE + col_offsets
57- mask = cols < n_cols
58- input_ptrs = row_input_ptr + cols
59- input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
60- x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" )
61- sum_squares += tl .sum (x * x , axis = 0 )
62-
63- return sum_squares
64-
65-
66- # apply normalization to each block of the row
67- @triton .jit
68- def apply_normalization (input_ptr , output_ptr , g_ptr , input_row_stride , output_row_stride , n_cols , norm_factor ,
69- BLOCK_SIZE , row_idx ):
70- col_offsets = tl .arange (0 , BLOCK_SIZE )
71- row_input_ptr = input_ptr + row_idx * input_row_stride
72- row_output_ptr = output_ptr + row_idx * output_row_stride
73-
74- for start in range (0 , n_cols , BLOCK_SIZE ):
75- cols = start + col_offsets
76- mask = cols < n_cols
77- input_ptrs = row_input_ptr + cols
78- input_ptrs = tl .multiple_of (input_ptrs , (16 , ))
79- g_ptrs = g_ptr + cols
80- output_ptrs = row_output_ptr + cols
81- x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" )
82- g = tl .load (g_ptrs , mask = mask , other = 0.0 )
83- rms_norm = x * norm_factor * g
84- tl .store (output_ptrs , rms_norm , mask = mask )
85-
86-
87- # Main kernel with both blocked and non-blocked versions based on BLOCK_SIZE
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
               BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
    """Persistent RMSNorm kernel: ``out[r, :] = x[r, :] * rsqrt(mean(x[r, :]**2) + epsilon) * g``.

    Each program starts at row ``tl.program_id(0)`` and advances by ``NUM_PRGMS``
    rows until ``n_rows`` is reached, so the launch grid is expected to contain
    ``NUM_PRGMS`` programs.

    Args:
        output_ptr: base pointer of the output; row ``r`` starts at
            ``output_ptr + r * output_row_stride``.
        input_ptr: base pointer of the input; row ``r`` starts at
            ``input_ptr + r * input_row_stride``.
        g_ptr: pointer to the per-column scale, indexed by column only.
        input_row_stride: element stride between input rows (assumed >= 0).
        output_row_stride: element stride between output rows (assumed >= 0).
        n_rows: number of rows to normalize.
        n_cols: number of columns per row.
        epsilon: value added to the mean square before ``rsqrt``.
        BLOCK_SIZE: number of columns handled per vector operation.
        USE_BLOCKED: when true, each row is walked in ``BLOCK_SIZE`` chunks;
            otherwise one masked ``BLOCK_SIZE``-wide access covers the whole
            row (which requires ``n_cols <= BLOCK_SIZE``).
        NUM_PRGMS: row step of the persistent loop.
    """
    row_start = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    tl.assume(input_row_stride >= 0)
    tl.assume(output_row_stride >= 0)

    if USE_BLOCKED:
        # Number of chunks that can be accessed without a mask. The last chunk
        # (full or partial) is peeled off and handled with a mask. This depends
        # only on kernel arguments, so it is computed once, outside the row loop.
        n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1

        # Persistent loop for rows
        for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1):
            row_input_ptr = input_ptr + row_idx * input_row_stride
            row_output_ptr = output_ptr + row_idx * output_row_stride

            # Pass 1: accumulate the row's sum of squares in float32.
            # NOTE(review): `x * x` is evaluated in the loaded element type before
            # being added to the float32 accumulator -- confirm that is acceptable
            # for reduced-precision inputs.
            sum_squares = tl.zeros([1], dtype=tl.float32)
            # `tl.range` (not builtin `range`) is required for `num_stages` to be
            # passed through to the compiler.
            for blk_idx in tl.range(0, n_cols_blks, num_stages=1):
                cols = blk_idx * BLOCK_SIZE + col_offsets
                input_ptrs = row_input_ptr + cols
                # Compiler hint only. NOTE(review): assumes callers pass rows whose
                # addresses satisfy this multiple-of-16 hint -- verify.
                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
                x = tl.load(input_ptrs)
                sum_squares += tl.sum(x * x, axis=0)

            # Peeled final chunk: masked so columns >= n_cols contribute 0.
            cols = n_cols_blks * BLOCK_SIZE + col_offsets
            mask = cols < n_cols
            input_ptrs = row_input_ptr + cols
            input_ptrs = tl.multiple_of(input_ptrs, (16, ))
            x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
            sum_squares += tl.sum(x * x, axis=0)

            # Compute normalization factor
            mean_square = sum_squares / n_cols
            norm_factor = tl.rsqrt(mean_square + epsilon)

            # Pass 2: re-read the row, scale it and write the output.
            for blk_idx in tl.range(0, n_cols_blks, num_stages=1):
                cols = blk_idx * BLOCK_SIZE + col_offsets
                input_ptrs = row_input_ptr + cols
                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
                x = tl.load(input_ptrs)
                g_ptrs = g_ptr + cols
                g = tl.load(g_ptrs)
                rms_norm = x * norm_factor * g
                output_ptrs = row_output_ptr + cols
                tl.store(output_ptrs, rms_norm)

            # Peeled final chunk of pass 2 (masked load and store).
            cols = n_cols_blks * BLOCK_SIZE + col_offsets
            mask = cols < n_cols
            input_ptrs = row_input_ptr + cols
            x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
            g_ptrs = g_ptr + cols
            g = tl.load(g_ptrs, mask=mask, other=0.0)
            rms_norm = x * norm_factor * g
            output_ptrs = row_output_ptr + cols
            tl.store(output_ptrs, rms_norm, mask=mask)

    else:
        # Whole row fits in one BLOCK_SIZE-wide access; the column mask is the
        # same for every row, so it is built once.
        mask = col_offsets < n_cols
        for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1):
            row_start_ptr = input_ptr + row_idx * input_row_stride
            input_ptrs = row_start_ptr + col_offsets
            input_ptrs = tl.multiple_of(input_ptrs, (16, ))
            row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
            g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)

            # rsqrt(mean(row^2) + epsilon); masked-off lanes were loaded as 0.
            row_norm = row * row
            row_norm = tl.sum(row_norm, axis=-1)
            row_norm = row_norm / n_cols
            row_norm = row_norm + epsilon
            row_norm = tl.rsqrt(row_norm)
            rms_norm = row * row_norm
            rms_norm = rms_norm * g

            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            output_ptrs = tl.multiple_of(output_ptrs, (16, ))
            tl.store(output_ptrs, rms_norm, mask=mask)
126+
127+
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):
    """Launch ``rms_kernel`` to RMS-normalize ``x`` into ``y`` and return ``y``.

    Args:
        x: input tensor; its ``stride(0)`` is passed as the input row stride.
        y: output tensor written by the kernel; its ``stride(0)`` is passed as
            the output row stride.
        g: per-column scale tensor.
        n_rows: number of rows to process.
        n_cols: number of columns per row.
        blk_size: value passed as the kernel's ``BLOCK_SIZE``.
        USE_BLOCKED: value passed as the kernel's ``USE_BLOCKED`` switch.
        NUM_PRGMS: number of programs launched; also passed to the kernel as
            its ``NUM_PRGMS`` row step.
        epsilon: value passed as the kernel's ``epsilon``.

    Returns:
        ``y``, the same tensor object that was passed in.
    """
    # The grid does not depend on any autotuned meta-parameter, so a plain
    # tuple is used instead of a callable.
    grid = (NUM_PRGMS, )
    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, blk_size, USE_BLOCKED, NUM_PRGMS)

    return y
138133
@@ -165,8 +160,10 @@ def test_rmsnorm(M, N):
165160 n_rows , n_cols = x .shape
166161 MAX_FUSED_SIZE = 65536 // x .element_size ()
167162 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
163+ USE_BLOCKED = n_cols > blk_size
164+ NUM_PRGMS = min (n_rows , get_num_sms ())
168165 g = torch .ones ((1 , N ), device = 'cuda' )
169- y_triton = triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size )
166+ y_triton = triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
170167
171168 y_torch = torch_rmsnorm (x , g )
172169
@@ -219,16 +216,20 @@ def benchmark(M, N, provider):
219216 n_rows , n_cols = x .shape
220217 MAX_FUSED_SIZE = 65536 // x .element_size ()
221218 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
219+ USE_BLOCKED = n_cols > blk_size
220+ NUM_PRGMS = min (n_rows , get_num_sms ())
222221 stream = torch .cuda .Stream ()
223222 torch .cuda .set_stream (stream )
224223 g = torch .ones ((1 , N ), device = 'cuda' )
225224 if provider == 'torch' :
226225 ms = triton .testing .do_bench (lambda : torch_rmsnorm (x , g ))
227226 if provider == 'triton' :
228- ms = triton .testing .do_bench (lambda : triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size ))
227+ ms = triton .testing .do_bench (
228+ lambda : triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS ))
229229 global verbose
230230 if verbose :
231231 print (f'SIZE: { N } Best tuning config: ({ rms_kernel .best_config } )' )
232+ print (f'time: { ms } ' )
232233 gbps = lambda ms : 2 * x .nelement () * x .element_size () * 1e-9 / (ms * 1e-3 )
233234 return gbps (ms )
234235
@@ -266,8 +267,10 @@ def main():
266267 n_rows , n_cols = x .shape
267268 MAX_FUSED_SIZE = 65536 // x .element_size ()
268269 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
270+ USE_BLOCKED = n_cols > blk_size
271+ NUM_PRGMS = min (n_rows , get_num_sms ())
269272 g = torch .ones ((1 , args .N_start ), device = 'cuda' )
270- triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size )
273+ triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
271274 else :
272275 verbose = args .v
273276 run_benchmark (args )
0 commit comments