Skip to content

Commit d6b0238

Browse files
authored
Update 02-fused-softmax.py (#7473)
docs: Align softmax tutorial benchmark with description The documentation for the softmax tutorial states that it compares three implementations: the Triton kernel, torch.softmax, and a naive_softmax implementation. However, the accompanying code example only benchmarks the Triton and Torch versions, omitting the naive implementation. This discrepancy can cause confusion for users following the tutorial. This commit updates the benchmark code to include the `native_softmax` comparison, ensuring the code accurately reflects the tutorial's description. The changes include: - Adding 'native_softmax' to the `line_vals` list. - Adding "Native Softmax" to the `line_names` list. - Adding the corresponding logic branch to the `benchmark` function. Here is my code result: <img width="571" height="432" alt="image" src="https://github.com/user-attachments/assets/bf6dc821-2ee4-4acc-8002-0cc2188d3497" />
1 parent cd6d25f commit d6b0238

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

python/tutorials/02-fused-softmax.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,9 @@ def softmax(x):
205205
x_names=['N'], # argument names to use as an x-axis for the plot
206206
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
207207
line_arg='provider', # argument name whose value corresponds to a different line in the plot
208-
line_vals=['triton', 'torch'], # possible values for `line_arg``
209-
line_names=[
210-
"Triton",
211-
"Torch",
212-
], # label name for the lines
213-
styles=[('blue', '-'), ('green', '-')], # line styles
208+
line_vals=['triton', 'torch', 'naive_softmax'], # possible values for `line_arg``
209+
line_names=["Triton", "Torch", "Naive Softmax"], # label name for the lines
210+
styles=[('blue', '-'), ('green', '-'), ('red', '-')], # line styles
214211
ylabel="GB/s", # label name for the y-axis
215212
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
216213
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
@@ -223,6 +220,8 @@ def benchmark(M, N, provider):
223220
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
224221
if provider == 'triton':
225222
ms = triton.testing.do_bench(lambda: softmax(x))
223+
if provider == 'naive_softmax':
224+
ms = triton.testing.do_bench(lambda: naive_softmax(x))
226225
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
227226
return gbps(ms)
228227

0 commit comments

Comments
 (0)