@@ -129,7 +129,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 tg_slm_sizes = [2**i for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]  # TODO: Get from properties


-def softmax(x, y):
+def softmax(x):

     def occupancy(num_warps, size_smem):

@@ -155,6 +155,9 @@ def allocated_slm_size(size_smem):
     # way so you don't have to come up with manual heuristics yourself.
     num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4)))

+    # Allocate output
+    y = torch.empty_like(x)
+
     # pre-compile kernel to get register usage and compute thread occupancy.
     kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
                                    threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
@@ -186,8 +189,7 @@ def allocated_slm_size(size_smem):

 torch.manual_seed(0)
 x = torch.randn(1823, 781, device=DEVICE)
-y = torch.empty_like(x)
-y_triton = softmax(x, y)
+y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

@@ -224,7 +226,7 @@ def benchmark(M, N, provider):
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
-        ms = triton.testing.do_bench(lambda: softmax(x, y))
+        ms = triton.testing.do_bench(lambda: softmax(x))
     gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
     torch.xpu.empty_cache()
     return gbps(ms)
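A quick arithmetic sketch of the warp heuristic kept as context in the second hunk, num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4))). The values WARP_SIZE = 32 and max_num_warps = 8 below are assumed for illustration only; the tutorial obtains them from the XPU device properties and its occupancy helper at runtime.

# Illustration only: WARP_SIZE and max_num_warps are assumed values here,
# not the ones the tutorial queries from the device.
WARP_SIZE = 32
max_num_warps = 8

def pick_num_warps(block_size):
    # Roughly one warp per WARP_SIZE * 4 elements, clamped to [1, max_num_warps].
    return min(max_num_warps, max(1, block_size // (WARP_SIZE * 4)))

for block_size in (128, 512, 1024, 4096):
    print(block_size, pick_num_warps(block_size))
# 128 -> 1, 512 -> 4, 1024 -> 8, 4096 -> 8 (clamped by max_num_warps)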