@@ -291,28 +291,6 @@ Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim
291
291
return self_physical.getPhysicalToLogicalMap ().apply (result);
292
292
}
293
293
294
- Tensor permute_batching_rule (const Tensor& self, IntArrayRef dims) {
295
- if (!participatesInCurrentLevel (self)) {
296
- c10::impl::ExcludeDispatchKeyGuard guard (kBatchedKey );
297
- return self.permute (dims);
298
- }
299
-
300
- auto self_physical = MultiBatchVmapTransform::logicalToPhysical (self);
301
- auto dims_physical = self_physical.getPhysicalDims (dims);
302
-
303
- VmapDimVector all_dims_physical;
304
- all_dims_physical.reserve (self_physical.tensor ().dim ());
305
- for (int64_t bdim = 0 ; bdim < self_physical.numBatchDims (); bdim++) {
306
- all_dims_physical.push_back (bdim);
307
- }
308
- all_dims_physical.insert (
309
- all_dims_physical.end (),
310
- dims_physical.begin (),
311
- dims_physical.end ());
312
- auto result = self_physical.tensor ().permute (all_dims_physical);
313
- return self_physical.getPhysicalToLogicalMap ().apply (result);
314
- }
315
-
316
294
static int64_t getGradInputPhysicalDim (int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
317
295
return maybe_wrap_dim (dim, input_sizes.size ()) + num_batch_dims;
318
296
}
@@ -963,7 +941,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
963
941
// NB: static_cast because there's another variant of narrow. However, we don't
964
942
// want to support the other variant yet bc it isn't documented...
965
943
m.impl (" numpy_T" , native::numpy_T); // composite wrt autograd
966
- m.impl (" permute" , permute_batching_rule);
967
944
m.impl (" reshape_as" , native::reshape_as); // composite wrt autograd
968
945
m.impl (" slice.Tensor" , slice_batching_rule);
969
946
m.impl (" split.Tensor" , split_batching_rule);
0 commit comments