Skip to content

Commit a2f2285

Browse files
authored
Fix 02-fused-softmax tutorial on BMG (#4383)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 30bdfce commit a2f2285

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/tutorials/02-fused-softmax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
126126
max_num_resident_warps = NUM_SM * warps_per_sm
127127
kernels = {}
128128
# Possible SLM allocation sizes in kB
129-
tg_slm_sizes = [i * 2**i for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]] # TODO: Get from properties
129+
tg_slm_sizes = [i * 2**10 for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]] # TODO: Get from properties
130130

131131

132132
def softmax(x):
@@ -228,6 +228,7 @@ def benchmark(M, N, provider):
228228
if provider == 'triton':
229229
ms = triton.testing.do_bench(lambda: softmax(x))
230230
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
231+
torch.xpu.empty_cache()
231232
return gbps(ms)
232233

233234

0 commit comments

Comments
 (0)