|
7 | 7 | from typing import Any, Callable, List |
8 | 8 |
|
9 | 9 | import torch |
| 10 | +from packaging import version |
10 | 11 |
|
11 | 12 | aten = torch.ops.aten |
12 | 13 |
|
@@ -188,131 +189,136 @@ def zero_flop_jit(*args): |
188 | 189 | return 0 |
189 | 190 |
|
190 | 191 |
|
191 | | -flop_mapping = { |
| 192 | +if version.parse(torch.__version__) >= version.parse('1.12.0'): |
| 193 | + flop_mapping = { |
192 | 194 | # gemm |
193 | | - aten.mm.default: matmul_flop_jit, |
194 | | - aten.matmul.default: matmul_flop_jit, |
195 | | - aten.addmm.default: addmm_flop_jit, |
196 | | - aten.bmm.default: bmm_flop_jit, |
| 195 | + aten.mm.default: matmul_flop_jit, |
| 196 | + aten.matmul.default: matmul_flop_jit, |
| 197 | + aten.addmm.default: addmm_flop_jit, |
| 198 | + aten.bmm.default: bmm_flop_jit, |
197 | 199 |
|
198 | 200 | # convolution |
199 | | - aten.convolution.default: conv_flop_jit, |
200 | | - aten._convolution.default: conv_flop_jit, |
201 | | - aten.convolution_backward.default: conv_backward_flop_jit, |
| 201 | + aten.convolution.default: conv_flop_jit, |
| 202 | + aten._convolution.default: conv_flop_jit, |
| 203 | + aten.convolution_backward.default: conv_backward_flop_jit, |
202 | 204 |
|
203 | 205 | # normalization |
204 | | - aten.native_batch_norm.default: batchnorm_flop_jit, |
205 | | - aten.native_batch_norm_backward.default: batchnorm_flop_jit, |
206 | | - aten.cudnn_batch_norm.default: batchnorm_flop_jit, |
207 | | - aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), |
208 | | - aten.native_layer_norm.default: norm_flop_counter(2, 0), |
209 | | - aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), |
| 206 | + aten.native_batch_norm.default: batchnorm_flop_jit, |
| 207 | + aten.native_batch_norm_backward.default: batchnorm_flop_jit, |
| 208 | + aten.cudnn_batch_norm.default: batchnorm_flop_jit, |
| 209 | + aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), |
| 210 | + aten.native_layer_norm.default: norm_flop_counter(2, 0), |
| 211 | + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), |
210 | 212 |
|
211 | 213 | # pooling |
212 | | - aten.avg_pool1d.default: elementwise_flop_counter(1, 0), |
213 | | - aten.avg_pool2d.default: elementwise_flop_counter(1, 0), |
214 | | - aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), |
215 | | - aten.avg_pool3d.default: elementwise_flop_counter(1, 0), |
216 | | - aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), |
217 | | - aten.max_pool1d.default: elementwise_flop_counter(1, 0), |
218 | | - aten.max_pool2d.default: elementwise_flop_counter(1, 0), |
219 | | - aten.max_pool3d.default: elementwise_flop_counter(1, 0), |
220 | | - aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), |
221 | | - aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), |
222 | | - aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), |
223 | | - aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), |
224 | | - aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), |
225 | | - aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), |
226 | | - aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), |
227 | | - aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), |
228 | | - aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), |
229 | | - aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), |
230 | | - aten.embedding.default: elementwise_flop_counter(1, 0), |
231 | | -} |
232 | | - |
233 | | -elementwise_flop_aten = [ |
| 214 | + aten.avg_pool1d.default: elementwise_flop_counter(1, 0), |
| 215 | + aten.avg_pool2d.default: elementwise_flop_counter(1, 0), |
| 216 | + aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), |
| 217 | + aten.avg_pool3d.default: elementwise_flop_counter(1, 0), |
| 218 | + aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), |
| 219 | + aten.max_pool1d.default: elementwise_flop_counter(1, 0), |
| 220 | + aten.max_pool2d.default: elementwise_flop_counter(1, 0), |
| 221 | + aten.max_pool3d.default: elementwise_flop_counter(1, 0), |
| 222 | + aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), |
| 223 | + aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), |
| 224 | + aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), |
| 225 | + aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), |
| 226 | + aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), |
| 227 | + aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), |
| 228 | + aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), |
| 229 | + aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), |
| 230 | + aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), |
| 231 | + aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), |
| 232 | + aten.embedding.default: elementwise_flop_counter(1, 0), |
| 233 | + } |
| 234 | + |
| 235 | + elementwise_flop_aten = [ |
234 | 236 | # basic op |
235 | | - aten.add.Tensor, |
236 | | - aten.add_.Tensor, |
237 | | - aten.div.Tensor, |
238 | | - aten.div_.Tensor, |
239 | | - aten.div.Scalar, |
240 | | - aten.div_.Scalar, |
241 | | - aten.mul.Tensor, |
242 | | - aten.mul.Scalar, |
243 | | - aten.mul_.Tensor, |
244 | | - aten.neg.default, |
245 | | - aten.pow.Tensor_Scalar, |
246 | | - aten.rsub.Scalar, |
247 | | - aten.sum.default, |
248 | | - aten.sum.dim_IntList, |
249 | | - aten.mean.dim, |
| 237 | + aten.add.Tensor, |
| 238 | + aten.add_.Tensor, |
| 239 | + aten.div.Tensor, |
| 240 | + aten.div_.Tensor, |
| 241 | + aten.div.Scalar, |
| 242 | + aten.div_.Scalar, |
| 243 | + aten.mul.Tensor, |
| 244 | + aten.mul.Scalar, |
| 245 | + aten.mul_.Tensor, |
| 246 | + aten.neg.default, |
| 247 | + aten.pow.Tensor_Scalar, |
| 248 | + aten.rsub.Scalar, |
| 249 | + aten.sum.default, |
| 250 | + aten.sum.dim_IntList, |
| 251 | + aten.mean.dim, |
250 | 252 |
|
251 | 253 | # activation op |
252 | | - aten.hardswish.default, |
253 | | - aten.hardswish_.default, |
254 | | - aten.hardswish_backward.default, |
255 | | - aten.hardtanh.default, |
256 | | - aten.hardtanh_.default, |
257 | | - aten.hardtanh_backward.default, |
258 | | - aten.hardsigmoid_backward.default, |
259 | | - aten.hardsigmoid.default, |
260 | | - aten.gelu.default, |
261 | | - aten.gelu_backward.default, |
262 | | - aten.silu.default, |
263 | | - aten.silu_.default, |
264 | | - aten.silu_backward.default, |
265 | | - aten.sigmoid.default, |
266 | | - aten.sigmoid_backward.default, |
267 | | - aten._softmax.default, |
268 | | - aten._softmax_backward_data.default, |
269 | | - aten.relu_.default, |
270 | | - aten.relu.default, |
271 | | - aten.tanh.default, |
272 | | - aten.tanh_backward.default, |
273 | | - aten.threshold_backward.default, |
| 254 | + aten.hardswish.default, |
| 255 | + aten.hardswish_.default, |
| 256 | + aten.hardswish_backward.default, |
| 257 | + aten.hardtanh.default, |
| 258 | + aten.hardtanh_.default, |
| 259 | + aten.hardtanh_backward.default, |
| 260 | + aten.hardsigmoid_backward.default, |
| 261 | + aten.hardsigmoid.default, |
| 262 | + aten.gelu.default, |
| 263 | + aten.gelu_backward.default, |
| 264 | + aten.silu.default, |
| 265 | + aten.silu_.default, |
| 266 | + aten.silu_backward.default, |
| 267 | + aten.sigmoid.default, |
| 268 | + aten.sigmoid_backward.default, |
| 269 | + aten._softmax.default, |
| 270 | + aten._softmax_backward_data.default, |
| 271 | + aten.relu_.default, |
| 272 | + aten.relu.default, |
| 273 | + aten.tanh.default, |
| 274 | + aten.tanh_backward.default, |
| 275 | + aten.threshold_backward.default, |
274 | 276 |
|
275 | 277 | # dropout |
276 | | - aten.native_dropout.default, |
277 | | - aten.native_dropout_backward.default, |
278 | | -] |
279 | | - |
280 | | -for op in elementwise_flop_aten: |
281 | | - flop_mapping[op] = elementwise_flop_counter(1, 0) |
282 | | - |
283 | | -# TODO: this will be removed in future |
284 | | -zero_flop_aten = [ |
285 | | - aten.as_strided.default, |
286 | | - aten.as_strided_.default, |
287 | | - aten.bernoulli_.float, |
288 | | - aten.cat.default, |
289 | | - aten.clone.default, |
290 | | - aten.copy_.default, |
291 | | - aten.detach.default, |
292 | | - aten.expand.default, |
293 | | - aten.empty_like.default, |
294 | | - aten.new_empty.default, |
295 | | - aten.new_empty_strided.default, |
296 | | - aten.ones_like.default, |
297 | | - aten._reshape_alias.default, |
298 | | - aten.select.int, |
299 | | - aten.select_backward.default, |
300 | | - aten.squeeze.dim, |
301 | | - aten.slice.Tensor, |
302 | | - aten.slice_backward.default, |
303 | | - aten.split.Tensor, |
304 | | - aten.permute.default, |
305 | | - aten.t.default, |
306 | | - aten.transpose.int, |
307 | | - aten._to_copy.default, |
308 | | - aten.unsqueeze.default, |
309 | | - aten.unbind.int, |
310 | | - aten._unsafe_view.default, |
311 | | - aten.view.default, |
312 | | - aten.where.self, |
313 | | - aten.zero_.default, |
314 | | - aten.zeros_like.default, |
315 | | -] |
316 | | - |
317 | | -for op in zero_flop_aten: |
318 | | - flop_mapping[op] = zero_flop_jit |
| 278 | + aten.native_dropout.default, |
| 279 | + aten.native_dropout_backward.default, |
| 280 | + ] |
| 281 | + for op in elementwise_flop_aten: |
| 282 | + flop_mapping[op] = elementwise_flop_counter(1, 0) |
| 283 | + |
| 284 | + # TODO: this will be removed in future |
| 285 | + zero_flop_aten = [ |
| 286 | + aten.as_strided.default, |
| 287 | + aten.as_strided_.default, |
| 288 | + aten.bernoulli_.float, |
| 289 | + aten.cat.default, |
| 290 | + aten.clone.default, |
| 291 | + aten.copy_.default, |
| 292 | + aten.detach.default, |
| 293 | + aten.expand.default, |
| 294 | + aten.empty_like.default, |
| 295 | + aten.new_empty.default, |
| 296 | + aten.new_empty_strided.default, |
| 297 | + aten.ones_like.default, |
| 298 | + aten._reshape_alias.default, |
| 299 | + aten.select.int, |
| 300 | + aten.select_backward.default, |
| 301 | + aten.squeeze.dim, |
| 302 | + aten.slice.Tensor, |
| 303 | + aten.slice_backward.default, |
| 304 | + aten.split.Tensor, |
| 305 | + aten.permute.default, |
| 306 | + aten.t.default, |
| 307 | + aten.transpose.int, |
| 308 | + aten._to_copy.default, |
| 309 | + aten.unsqueeze.default, |
| 310 | + aten.unbind.int, |
| 311 | + aten._unsafe_view.default, |
| 312 | + aten.view.default, |
| 313 | + aten.where.self, |
| 314 | + aten.zero_.default, |
| 315 | + aten.zeros_like.default, |
| 316 | + ] |
| 317 | + |
| 318 | + for op in zero_flop_aten: |
| 319 | + flop_mapping[op] = zero_flop_jit |
| 320 | + |
| 321 | +else: |
| 322 | + flop_mapping = {} |
| 323 | + elementwise_flop_aten = {} |
| 324 | + zero_flop_aten = {} |
0 commit comments