Skip to content

Commit 9275994

Browse files
committed
Remove the unused n_rows parameter in the triton_softmax_kernel function
1 parent 33a29fe commit 9275994

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

code_size_comparison.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
plt.rcParams["axes.labelweight"] = "bold"
88

99
kernels = ("add", "softmax", "matmul", "conv2d", "attention")
10-
lines_of_code = {"Triton": (19, 26, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)}
10+
lines_of_code = {"Triton": (19, 25, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)}
1111

1212
x = np.arange(len(kernels))
1313
width = 0.4

softmax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def triton_softmax_kernel(
3333
output_ptr,
3434
input_row_stride,
3535
output_row_stride,
36-
n_rows,
3736
n_cols,
3837
BLOCK_SIZE: tl.constexpr,
3938
):
@@ -63,7 +62,6 @@ def triton_softmax(input):
6362
output,
6463
input.stride(0),
6564
output.stride(0),
66-
input.shape[0],
6765
input.shape[1],
6866
BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]),
6967
)

0 commit comments

Comments
 (0)