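### 🐛 Describe the bug

AlexNet fp16 training with the channels-last memory format and batch size 4096 spends a large share of its time in max pooling.

Forward (`max_pool2d_with_indices`):

- input `[4096, 64, 55, 55]`, strides `[193600, 1, 3520, 64]`: 12 ms
- input `[4096, 192, 27, 27]`: 9 ms
- input `[4096, 256, 13, 13]`: 3 ms

Backward: the corresponding `MaxPool2dBackwardDeterministicKernelFunctor<c10::Half, true>` kernels take 80 ms in total, several times the combined forward time.

End to end, AlexNet takes 470 ms. Max pooling accounts for 104 ms of that (12 + 9 + 3 = 24 ms forward plus 80 ms backward), and convolution for 330 ms.

### Versions

main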
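### Additional context

A small stdlib-only sketch for sanity-checking the figures in this report. It confirms that the reported strides are the NHWC layout that `torch.channels_last` produces for an NCHW shape, and it computes the max-pool share of the end-to-end time. The timing values are copied from the report. The helper name `channels_last_strides` is hypothetical and not a PyTorch API.

```python
# Sketch: check the channels-last strides and the time breakdown.
# All timings (ms) are copied from the report; nothing is measured here.

def channels_last_strides(n, c, h, w):
    """Strides (in elements) of an NCHW-shaped tensor stored as NHWC.

    Hypothetical helper: mirrors the layout of torch.channels_last.
    """
    return (c * h * w, 1, w * c, c)

# Forward max_pool2d_with_indices timings, keyed by input shape.
forward_ms = {
    (4096, 64, 55, 55): 12,
    (4096, 192, 27, 27): 9,
    (4096, 256, 13, 13): 3,
}
backward_ms = 80  # MaxPool2dBackwardDeterministicKernelFunctor, all layers
e2e_ms = 470
conv_ms = 330

forward_total_ms = sum(forward_ms.values())
maxpool_ms = forward_total_ms + backward_ms

if __name__ == "__main__":
    print(channels_last_strides(4096, 64, 55, 55))  # (193600, 1, 3520, 64)
    print(maxpool_ms)                                # 104
    print(f"max pool share: {maxpool_ms / e2e_ms:.1%}")   # 22.1%
    print(f"conv share: {conv_ms / e2e_ms:.1%}")          # 70.2%
    print(f"backward/forward: {backward_ms / forward_total_ms:.2f}x")  # 3.33x
```

The backward-to-forward ratio is the main signal here: the backward pass costs roughly 3.3 times the forward pass for the same three pooling layers.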