Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 509424e

Browse files
authored
[port] permute to new api (#172)
* [port] permute to new api * retrigger CI * fix errors
1 parent adeaa92 commit 509424e

File tree

2 files changed

+19
-23
lines changed

2 files changed

+19
-23
lines changed

functorch/csrc/BatchRulesViews.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,24 @@ std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
328328
return std::make_tuple(std::move(result), 0);
329329
}
330330

331+
// Batching rule for aten::permute under vmap.
//
// If `self` carries no batch dimension at this level, permute is applied
// directly and the (empty) bdim is passed through. Otherwise the batch
// dimension is moved to physical dim 0, kept fixed there, and each logical
// dim in `dims` is translated to its physical counterpart before permuting.
// Returns the permuted tensor together with its new batch-dim position (0).
std::tuple<Tensor, optional<int64_t>> permute_batching_rule(
    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef dims)
{
  if (!self_bdim.has_value()) {
    // Fast path: nothing is batched here, defer to the plain op.
    return std::make_tuple(self.permute(dims), self_bdim);
  }

  const auto physical = moveBatchDimToFront(self, self_bdim);

  // Build the physical permutation: batch dim stays first, followed by the
  // user's dims shifted past the batch dimension.
  VmapDimVector physical_dims;
  physical_dims.reserve(dims.size() + 1);
  physical_dims.push_back(0);
  for (const auto dim : dims) {
    physical_dims.push_back(getPhysicalDim(physical, self_bdim.has_value(), dim));
  }

  return std::make_tuple(physical.permute(physical_dims), 0);
}
348+
331349
std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
332350
const Tensor& grad_input, optional<int64_t> grad_input_bdim,
333351
IntArrayRef input_sizes, int64_t dim, int64_t index) {
@@ -405,6 +423,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
405423
VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule);
406424
VMAP_SUPPORT("_reshape_alias", _reshape_alias_batch_rule);
407425
VMAP_SUPPORT("roll", roll_batch_rule);
426+
VMAP_SUPPORT("permute", permute_batching_rule);
408427
VMAP_SUPPORT("diagonal", diagonal_batching_rule);
409428
VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
410429
VMAP_SUPPORT("select_backward", select_backward_batch_rule);

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -291,28 +291,6 @@ Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim
291291
return self_physical.getPhysicalToLogicalMap().apply(result);
292292
}
293293

294-
// Legacy (old-API) batching rule for aten::permute.
//
// When `self` does not participate in the current vmap level, the batched
// dispatch key is excluded and the plain permute runs. Otherwise the logical
// tensor is converted to its physical form, the requested logical dims are
// mapped to physical dims, every leading batch dim is kept in place, and the
// result is mapped back to a logical tensor.
Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  if (!participatesInCurrentLevel(self)) {
    // Skip the batching key entirely for tensors untouched by this level.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.permute(dims);
  }

  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto physical_dims = physical_view.getPhysicalDims(dims);

  // Full physical permutation: batch dims first (identity order), then the
  // translated user permutation.
  VmapDimVector full_permutation;
  full_permutation.reserve(physical_view.tensor().dim());
  for (int64_t bdim = 0; bdim < physical_view.numBatchDims(); bdim++) {
    full_permutation.push_back(bdim);
  }
  full_permutation.insert(
      full_permutation.end(),
      physical_dims.begin(),
      physical_dims.end());

  const auto permuted = physical_view.tensor().permute(full_permutation);
  return physical_view.getPhysicalToLogicalMap().apply(permuted);
}
315-
316294
static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
317295
return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
318296
}
@@ -963,7 +941,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
963941
// NB: static_cast because there's another variant of narrow. However, we don't
964942
// want to support the other variant yet bc it isn't documented...
965943
m.impl("numpy_T", native::numpy_T); // composite wrt autograd
966-
m.impl("permute", permute_batching_rule);
967944
m.impl("reshape_as", native::reshape_as); // composite wrt autograd
968945
m.impl("slice.Tensor", slice_batching_rule);
969946
m.impl("split.Tensor", split_batching_rule);

0 commit comments

Comments
 (0)