Commit 971472c

Fix 02-fused-softmax tutorial on BMG
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 5cb786d commit 971472c

1 file changed
python/tutorials/02-fused-softmax.py

Lines changed: 5 additions & 7 deletions
@@ -126,10 +126,10 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 max_num_resident_warps = NUM_SM * warps_per_sm
 kernels = {}
 # Possible SLM allocation sizes in kB
-tg_slm_sizes = [i * 2**i for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]  # TODO: Get from properties
+tg_slm_sizes = [2**i for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]  # TODO: Get from properties
 
 
-def softmax(x):
+def softmax(x, y):
 
     def occupancy(num_warps, size_smem):
 
@@ -155,9 +155,6 @@ def allocated_slm_size(size_smem):
     # way so you don't have to come up with manual heuristics yourself.
     num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4)))
 
-    # Allocate output
-    y = torch.empty_like(x)
-
     # pre-compile kernel to get register usage and compute thread occupancy.
     kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
                                    threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
@@ -189,7 +186,8 @@ def allocated_slm_size(size_smem):
 
 torch.manual_seed(0)
 x = torch.randn(1823, 781, device=DEVICE)
-y_triton = softmax(x)
+y = torch.empty_like(x)
+y_triton = softmax(x, y)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
 
@@ -226,7 +224,7 @@ def benchmark(M, N, provider):
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
-        ms = triton.testing.do_bench(lambda: softmax(x))
+        ms = triton.testing.do_bench(lambda: softmax(x, y))
     gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms)

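Taken together, the hunks move output allocation out of the softmax wrapper: the caller now allocates y once and passes it in, and both the correctness check and the benchmark use the new softmax(x, y) signature. A minimal usage sketch of the resulting calling convention (DEVICE, softmax and triton are assumed to be defined as in the tutorial; the shapes simply mirror the test above):

    import torch
    import triton

    x = torch.randn(1823, 781, device=DEVICE)
    y = torch.empty_like(x)    # output buffer is now owned by the caller

    y_triton = softmax(x, y)   # kernel writes the row-wise softmax of x into y
    assert torch.allclose(y_triton, torch.softmax(x, axis=1))

    # The same y can be reused across calls, as the benchmark above does:
    ms = triton.testing.do_bench(lambda: softmax(x, y))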