Commit e7c435c

Enable backward mode for softmax operator (#528)
1 parent 2beb393 commit e7c435c

File tree: 1 file changed, +169 -30 lines

tritonbench/operators/softmax/operator.py

Lines changed: 169 additions & 30 deletions
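The diff below wraps the existing Triton forward kernel in a `torch.autograd.Function` and adds a Triton kernel for the backward pass. For `y = softmax(x)` along the last dimension, the backward pass computes `dx = y * (dy - sum(dy * y, dim=-1, keepdim=True))`. The following is a minimal, commit-independent PyTorch check of that formula against autograd (plain `torch` only; nothing here is part of the commit):

```python
import torch

# y = softmax(x) along the last dim; dy is an arbitrary upstream gradient.
x = torch.randn(4, 8, dtype=torch.float32, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)

# Analytic softmax backward: the same formula softmax_bwd_kernel implements.
dx_manual = y * (dy - (dy * y).sum(dim=-1, keepdim=True))

# Autograd reference.
(dx_autograd,) = torch.autograd.grad(y, x, grad_outputs=dy)
torch.testing.assert_close(dx_manual, dx_autograd)
```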
@@ -12,6 +12,7 @@
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    Mode,
     register_benchmark,
     register_metric,
     register_x_val,
@@ -41,21 +42,9 @@ def parse_op_args(args: List[str]):
     return parser.parse_args(args)
 
 
-class Operator(BenchmarkOperator):
-    DEFAULT_PRECISION = "fp16"
-    FWD_ONLY = True
-    is_compute_bound = False
-
-    def __init__(
-        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
-    ):
-        super().__init__(tb_args, extra_args)
-        args = parse_op_args(self.extra_args)
-        self.M = args.M
-        self.N = args.N
-
-    @register_benchmark()
-    def triton_softmax(self, x):
+class TritonSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
         n_rows, n_cols = x.shape
         # The block size is the smallest power of two greater than the number of columns in `x`
         BLOCK_SIZE = triton.next_power_of_2(n_cols)
@@ -71,21 +60,43 @@ def triton_softmax(self, x):
         # Allocate output
         y = torch.empty_like(x)
 
-        # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
-        # f the input matrix
-        def _inner():
-            Operator.softmax_kernel[(n_rows,)](
-                y,
-                x,
-                x.stride(0),
-                y.stride(0),
-                n_cols,
-                num_warps=num_warps,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-            return y
+        # Enqueue kernel
+        Operator.softmax_kernel[(n_rows,)](
+            y,
+            x,
+            x.stride(0),
+            y.stride(0),
+            n_cols,
+            num_warps=num_warps,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        ctx.save_for_backward(y)
+        return y
 
-        return _inner
+    @staticmethod
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        return Operator.softmax_bwd_triton(grad_output, y)
+
+
+triton_softmax_fn = TritonSoftmax.apply
+
+
+class Operator(BenchmarkOperator):
+    DEFAULT_PRECISION = "fp16"
+    is_compute_bound = False
+
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        args = parse_op_args(self.extra_args)
+        self.M = args.M
+        self.N = args.N
+
+    @register_benchmark()
+    def triton_softmax(self, x):
+        return lambda: triton_softmax_fn(x)
 
     @triton.jit
     def softmax_kernel(
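The `TritonSoftmax` wrapper above makes the benchmarked op differentiable, so backward mode can drive it through autograd. A hypothetical smoke test (not part of the commit), assuming the module is importable as shown and a CUDA device is available:

```python
import torch

# Assumed import path, taken from the file modified in this commit.
from tritonbench.operators.softmax.operator import triton_softmax_fn

x = torch.randn(1024, 781, dtype=torch.float16, device="cuda", requires_grad=True)
y = triton_softmax_fn(x)
y.backward(torch.ones_like(y))  # runs softmax_bwd_triton via TritonSoftmax.backward

# Compare forward output and input gradient against eager PyTorch.
x_ref = x.detach().clone().requires_grad_(True)
y_ref = torch.softmax(x_ref, dim=-1)
y_ref.backward(torch.ones_like(y_ref))
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)
```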
@@ -117,6 +128,125 @@ def softmax_kernel(
         output_ptrs = output_row_start_ptr + col_offsets
         tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
 
+    @triton.jit
+    def softmax_bwd_kernel(
+        softmax_output,
+        grad_output,
+        grad_input,
+        grad_input_stride_0,
+        grad_input_stride_1,
+        grad_output_stride_0,
+        grad_output_stride_1,
+        softmax_output_stride_0,
+        softmax_output_stride_1,
+        m,
+        n,
+        BLOCK_SIZE_0: tl.constexpr,
+        BLOCK_SIZE_1: tl.constexpr,
+        BLOCK_SIZE_2: tl.constexpr,
+    ):
+        pid_0 = tl.program_id(0)
+        offset_0 = pid_0 * BLOCK_SIZE_0
+        indices_0 = (offset_0 + tl.arange(0, BLOCK_SIZE_0)).to(tl.int32)
+        mask_0 = indices_0 < m
+        sum_per_row = tl.full([BLOCK_SIZE_0], 0.0, tl.float32)
+        for offset_1 in tl.range(0, n.to(tl.int32), BLOCK_SIZE_1):
+            indices_1 = offset_1 + tl.arange(0, BLOCK_SIZE_1).to(tl.int32)
+            mask_1 = indices_1 < n
+            sum_per_row_copy = sum_per_row
+            sum_per_row_copy_0 = sum_per_row_copy
+            load = tl.load(
+                softmax_output
+                + (
+                    indices_0[:, None] * softmax_output_stride_0
+                    + indices_1[None, :] * softmax_output_stride_1
+                ),
+                mask_0[:, None] & mask_1[None, :],
+                other=0,
+            )
+            load_1 = tl.load(
+                grad_output
+                + (
+                    indices_0[:, None] * grad_output_stride_0
+                    + indices_1[None, :] * grad_output_stride_1
+                ),
+                mask_0[:, None] & mask_1[None, :],
+                other=0,
+            )
+            v_0 = load * load_1
+            sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
+            v_1 = tl.cast(sum_1, tl.float32)
+            sum_per_row = sum_per_row_copy_0 + v_1
+        for offset_2 in tl.range(0, n.to(tl.int32), BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < n
+            sum_per_row_copy_1 = sum_per_row
+            sum_per_row_copy_1_0 = sum_per_row_copy_1
+            load_2 = tl.load(
+                softmax_output
+                + (
+                    indices_0[:, None] * softmax_output_stride_0
+                    + indices_2[None, :] * softmax_output_stride_1
+                ),
+                mask_0[:, None] & mask_2[None, :],
+                other=0,
+            )
+            load_3 = tl.load(
+                grad_output
+                + (
+                    indices_0[:, None] * grad_output_stride_0
+                    + indices_2[None, :] * grad_output_stride_1
+                ),
+                mask_0[:, None] & mask_2[None, :],
+                other=0,
+            )
+            subscript = sum_per_row_copy_1_0[:, None]
+            v_3 = tl.cast(load_3, tl.float32)
+            v_4 = v_3 - subscript
+            v_5 = tl.cast(load_2, tl.float32)
+            v_6 = v_5 * v_4
+            v_7 = tl.cast(v_6, tl.float16)
+            tl.store(
+                grad_input
+                + (
+                    indices_0[:, None] * grad_input_stride_0
+                    + indices_2[None, :] * grad_input_stride_1
+                ),
+                v_7,
+                mask_0[:, None] & mask_2[None, :],
+            )
+
+    @staticmethod
+    def softmax_bwd_triton(grad_output, softmax_output):
+        """
+        Helion generated triton kernel for softmax backward pass
+        PR: https://github.com/pytorch/helion/pull/744
+        """
+        m, n = grad_output.size()
+        grad_input = torch.empty_like(grad_output)
+
+        BLOCK_SIZE_0 = min(32, triton.next_power_of_2(m))
+        BLOCK_SIZE_1 = triton.next_power_of_2(n)
+        BLOCK_SIZE_2 = BLOCK_SIZE_1
+
+        Operator.softmax_bwd_kernel[(triton.cdiv(m, BLOCK_SIZE_0),)](
+            softmax_output,
+            grad_output,
+            grad_input,
+            grad_input.stride(0),
+            grad_input.stride(1),
+            grad_output.stride(0),
+            grad_output.stride(1),
+            softmax_output.stride(0),
+            softmax_output.stride(1),
+            m,
+            n,
+            BLOCK_SIZE_0,
+            BLOCK_SIZE_1,
+            BLOCK_SIZE_2,
+        )
+        return grad_input
+
     @register_benchmark(baseline=True)
     def naive_softmax(self, x):
         """Compute row-wise softmax of X using native pytorch."""
@@ -153,8 +283,17 @@ def get_input_iter(self):
         if additional_shapes:
             shapes.extend(additional_shapes)
 
+        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
+
         for M, N in shapes:
-            yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)
+            yield (
+                torch.randn(
+                    [M, N],
+                    dtype=self.dtype,
+                    device=self.device,
+                    requires_grad=requires_grad,
+                ),
+            )
 
     @register_x_val(label="(M, N)")
     def get_x_val(self, example_inputs):
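The input generator now marks tensors with `requires_grad=True` except in `FWD_NO_GRAD` mode, so backward benchmarking can record an autograd graph through the op. A minimal sketch of why that flag matters (plain PyTorch, independent of tritonbench):

```python
import torch

x_no_grad = torch.randn(4, 8)  # FWD_NO_GRAD-style input: no graph is recorded
y = torch.softmax(x_no_grad, dim=-1)
print(y.requires_grad)  # False -> y.sum().backward() would raise

x_grad = torch.randn(4, 8, requires_grad=True)  # input produced for grad-enabled modes
torch.softmax(x_grad, dim=-1).sum().backward()
print(x_grad.grad.shape)  # torch.Size([4, 8])
```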
