Skip to content

Commit 4212e80

Browse files
committed
leave only torch.xpu.empty_cache()
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 0579541 commit 4212e80

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python/tutorials/02-fused-softmax.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
129129
tg_slm_sizes = [2**i for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]] # TODO: Get from properties
130130

131131

132-
def softmax(x, y):
132+
def softmax(x):
133133

134134
def occupancy(num_warps, size_smem):
135135

@@ -155,6 +155,9 @@ def allocated_slm_size(size_smem):
155155
# way so you don't have to come up with manual heuristics yourself.
156156
num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4)))
157157

158+
# Allocate output
159+
y = torch.empty_like(x)
160+
158161
# pre-compile kernel to get register usage and compute thread occupancy.
159162
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
160163
threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
@@ -186,8 +189,7 @@ def allocated_slm_size(size_smem):
186189

187190
torch.manual_seed(0)
188191
x = torch.randn(1823, 781, device=DEVICE)
189-
y = torch.empty_like(x)
190-
y_triton = softmax(x, y)
192+
y_triton = softmax(x)
191193
y_torch = torch.softmax(x, axis=1)
192194
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
193195

@@ -224,7 +226,7 @@ def benchmark(M, N, provider):
224226
if provider == 'torch':
225227
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
226228
if provider == 'triton':
227-
ms = triton.testing.do_bench(lambda: softmax(x, y))
229+
ms = triton.testing.do_bench(lambda: softmax(x))
228230
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
229231
torch.xpu.empty_cache()
230232
return gbps(ms)

0 commit comments

Comments
 (0)