from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
+    Mode,
    register_benchmark,
    register_metric,
    register_x_val,
@@ -41,21 +42,9 @@ def parse_op_args(args: List[str]):
    return parser.parse_args(args)


-class Operator(BenchmarkOperator):
-    DEFAULT_PRECISION = "fp16"
-    FWD_ONLY = True
-    is_compute_bound = False
-
-    def __init__(
-        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
-    ):
-        super().__init__(tb_args, extra_args)
-        args = parse_op_args(self.extra_args)
-        self.M = args.M
-        self.N = args.N
-
-    @register_benchmark()
-    def triton_softmax(self, x):
+class TritonSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
        n_rows, n_cols = x.shape
        # The block size is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
@@ -71,21 +60,43 @@ def triton_softmax(self, x):
        # Allocate output
        y = torch.empty_like(x)

-        # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
-        # f the input matrix
-        def _inner():
-            Operator.softmax_kernel[(n_rows,)](
-                y,
-                x,
-                x.stride(0),
-                y.stride(0),
-                n_cols,
-                num_warps=num_warps,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-            return y
+        # Enqueue kernel
+        Operator.softmax_kernel[(n_rows,)](
+            y,
+            x,
+            x.stride(0),
+            y.stride(0),
+            n_cols,
+            num_warps=num_warps,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        ctx.save_for_backward(y)
+        return y

-        return _inner
+    @staticmethod
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        return Operator.softmax_bwd_triton(grad_output, y)
+
+
+triton_softmax_fn = TritonSoftmax.apply
+
+
+class Operator(BenchmarkOperator):
+    DEFAULT_PRECISION = "fp16"
+    is_compute_bound = False
+
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        args = parse_op_args(self.extra_args)
+        self.M = args.M
+        self.N = args.N
+
+    @register_benchmark()
+    def triton_softmax(self, x):
+        return lambda: triton_softmax_fn(x)

    @triton.jit
    def softmax_kernel(
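Editor's note: the new TritonSoftmax.apply path can be exercised directly, outside the benchmark harness. A minimal sketch follows (not part of the diff); it assumes this module is importable as `tritonbench.operators.softmax.operator` (adjust the path to wherever the file actually lives) and that a CUDA device is available.

```python
import torch

# Hypothetical import path for illustration; adjust to this file's actual location.
from tritonbench.operators.softmax.operator import triton_softmax_fn

x = torch.randn(1024, 2048, dtype=torch.float16, device="cuda", requires_grad=True)
dy = torch.randn(1024, 2048, dtype=torch.float16, device="cuda")

# Forward launches softmax_kernel; backward dispatches to softmax_bwd_triton.
y = triton_softmax_fn(x)
y.backward(dy)

# Compare against eager PyTorch softmax on an independent leaf tensor.
x_ref = x.detach().clone().requires_grad_(True)
torch.softmax(x_ref, dim=-1).backward(dy)
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)
```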
@@ -117,6 +128,125 @@ def softmax_kernel(
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

+    @triton.jit
+    def softmax_bwd_kernel(
+        softmax_output,
+        grad_output,
+        grad_input,
+        grad_input_stride_0,
+        grad_input_stride_1,
+        grad_output_stride_0,
+        grad_output_stride_1,
+        softmax_output_stride_0,
+        softmax_output_stride_1,
+        m,
+        n,
+        BLOCK_SIZE_0: tl.constexpr,
+        BLOCK_SIZE_1: tl.constexpr,
+        BLOCK_SIZE_2: tl.constexpr,
+    ):
+        pid_0 = tl.program_id(0)
+        offset_0 = pid_0 * BLOCK_SIZE_0
+        indices_0 = (offset_0 + tl.arange(0, BLOCK_SIZE_0)).to(tl.int32)
+        mask_0 = indices_0 < m
+        sum_per_row = tl.full([BLOCK_SIZE_0], 0.0, tl.float32)
+        for offset_1 in tl.range(0, n.to(tl.int32), BLOCK_SIZE_1):
+            indices_1 = offset_1 + tl.arange(0, BLOCK_SIZE_1).to(tl.int32)
+            mask_1 = indices_1 < n
+            sum_per_row_copy = sum_per_row
+            sum_per_row_copy_0 = sum_per_row_copy
+            load = tl.load(
+                softmax_output
+                + (
+                    indices_0[:, None] * softmax_output_stride_0
+                    + indices_1[None, :] * softmax_output_stride_1
+                ),
+                mask_0[:, None] & mask_1[None, :],
+                other=0,
+            )
+            load_1 = tl.load(
+                grad_output
+                + (
+                    indices_0[:, None] * grad_output_stride_0
+                    + indices_1[None, :] * grad_output_stride_1
+                ),
+                mask_0[:, None] & mask_1[None, :],
+                other=0,
+            )
+            v_0 = load * load_1
+            sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
+            v_1 = tl.cast(sum_1, tl.float32)
+            sum_per_row = sum_per_row_copy_0 + v_1
+        for offset_2 in tl.range(0, n.to(tl.int32), BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < n
+            sum_per_row_copy_1 = sum_per_row
+            sum_per_row_copy_1_0 = sum_per_row_copy_1
+            load_2 = tl.load(
+                softmax_output
+                + (
+                    indices_0[:, None] * softmax_output_stride_0
+                    + indices_2[None, :] * softmax_output_stride_1
+                ),
+                mask_0[:, None] & mask_2[None, :],
+                other=0,
+            )
+            load_3 = tl.load(
+                grad_output
+                + (
+                    indices_0[:, None] * grad_output_stride_0
+                    + indices_2[None, :] * grad_output_stride_1
+                ),
+                mask_0[:, None] & mask_2[None, :],
+                other=0,
+            )
+            subscript = sum_per_row_copy_1_0[:, None]
+            v_3 = tl.cast(load_3, tl.float32)
+            v_4 = v_3 - subscript
+            v_5 = tl.cast(load_2, tl.float32)
+            v_6 = v_5 * v_4
+            v_7 = tl.cast(v_6, tl.float16)
+            tl.store(
+                grad_input
+                + (
+                    indices_0[:, None] * grad_input_stride_0
+                    + indices_2[None, :] * grad_input_stride_1
+                ),
+                v_7,
+                mask_0[:, None] & mask_2[None, :],
+            )
+
+    @staticmethod
+    def softmax_bwd_triton(grad_output, softmax_output):
+        """
+        Helion-generated Triton kernel for the softmax backward pass.
+        PR: https://github.com/pytorch/helion/pull/744
+        """
+        m, n = grad_output.size()
+        grad_input = torch.empty_like(grad_output)
+
+        BLOCK_SIZE_0 = min(32, triton.next_power_of_2(m))
+        BLOCK_SIZE_1 = triton.next_power_of_2(n)
+        BLOCK_SIZE_2 = BLOCK_SIZE_1
+
+        Operator.softmax_bwd_kernel[(triton.cdiv(m, BLOCK_SIZE_0),)](
+            softmax_output,
+            grad_output,
+            grad_input,
+            grad_input.stride(0),
+            grad_input.stride(1),
+            grad_output.stride(0),
+            grad_output.stride(1),
+            softmax_output.stride(0),
+            softmax_output.stride(1),
+            m,
+            n,
+            BLOCK_SIZE_0,
+            BLOCK_SIZE_1,
+            BLOCK_SIZE_2,
+        )
+        return grad_input
+
    @register_benchmark(baseline=True)
    def naive_softmax(self, x):
        """Compute row-wise softmax of X using native pytorch."""
@@ -153,8 +283,17 @@ def get_input_iter(self):
        if additional_shapes:
            shapes.extend(additional_shapes)

+        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
+
        for M, N in shapes:
-            yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)
+            yield (
+                torch.randn(
+                    [M, N],
+                    dtype=self.dtype,
+                    device=self.device,
+                    requires_grad=requires_grad,
+                ),
+            )

    @register_x_val(label="(M, N)")
    def get_x_val(self, example_inputs):
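Editor's note on the get_input_iter change: inputs are now created with requires_grad=True in every mode except FWD_NO_GRAD, where autograd should not record a graph around the forward measurement. A small stand-alone illustration of the difference, using plain torch.softmax rather than the Triton kernels:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)

# With a requires_grad input, the forward records a graph, so backward can be timed.
y = torch.softmax(x, dim=-1)
print(y.grad_fn is not None)  # True

# FWD_NO_GRAD-style measurement: no graph is recorded around the forward.
with torch.no_grad():
    y_no_grad = torch.softmax(x, dim=-1)
print(y_no_grad.grad_fn is None)  # True
```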