Commit 79d31a7

convert movedim to new api (#211)
* convert movedim to new api
* move to view batching rules
* remove automatically unbatched
* remove unused
* remove merge typo
1 parent 582b71a commit 79d31a7

4 files changed: +26 −15 lines

functorch/csrc/BatchRulesHelper.cpp

Lines changed: 16 additions & 0 deletions
@@ -52,6 +52,22 @@ int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical
   return wrapped_dim;
 }

+VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
+  // NB: assumes the batch dim is at the front of the tensor
+  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
+  auto rank = rankWithoutBatchDim(tensor, bdim);
+  VmapDimVector result;
+  result.reserve(logical_dims.size());
+  for (auto d : logical_dims){
+    if (has_batch_dim) {
+      result.push_back(maybe_wrap_dim(d, rank)+1);
+    } else {
+      result.push_back(maybe_wrap_dim(d, rank));
+    }
+  }
+  return result;
+}
+
 Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
   if (!has_bdim) {
     return tensor;

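For context (this note and sketch are not part of the commit): the new getPhysicalDims helper translates the user-visible "logical" dims into dims on the physical tensor, which carries the vmap batch dim at the front. Below is a minimal Python sketch of that mapping, with illustrative names and error checking omitted; rank here means the rank excluding the batch dim:

def get_physical_dims(logical_dims, rank, has_batch_dim):
    # Wrap negative dims against the rank without the batch dim (what
    # maybe_wrap_dim does), then shift by +1 when a batch dim sits at the
    # front of the physical tensor.
    physical = []
    for d in logical_dims:
        wrapped = d if d >= 0 else d + rank
        physical.append(wrapped + 1 if has_batch_dim else wrapped)
    return physical

# A logical 3-d tensor batched at physical dim 0: logical dims [0, -1]
# map to physical dims [1, 3].
print(get_physical_dims([0, -1], rank=3, has_batch_dim=True))  # [1, 3]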
functorch/csrc/BatchRulesHelper.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_
 int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
 optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
+VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);

 void vmapIncompatibleInplaceError(const char* schema_name);

functorch/csrc/BatchRulesViews.cpp

Lines changed: 9 additions & 1 deletion
@@ -446,6 +446,13 @@ std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
   return std::make_tuple(result, 0);
 }

+std::tuple<Tensor, optional<int64_t>> movedim_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef source, IntArrayRef destination) {
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto source_ = getPhysicalDims(self_, self_bdim.has_value(), source);
+  auto destination_ = getPhysicalDims(self_, self_bdim.has_value(), destination);
+  return std::make_tuple(self_.movedim(source_, destination_), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("diag", diag_batch_rule);
   VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -471,6 +478,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("view", view_batching_rule);
   VMAP_SUPPORT("expand", expand_batch_rule);
   VMAP_SUPPORT("unfold", unfold_batch_rule);
-}
+  VMAP_SUPPORT("movedim.intlist", movedim_batch_rule);
+}

 }}

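A hedged usage sketch of what this batch rule enables (not part of the commit): vmapping torch.movedim should match applying movedim to each example in the batch. The snippet assumes a functorch build that includes this change; shapes and names are illustrative:

import torch
from functorch import vmap

x = torch.randn(4, 2, 3, 5)  # a batch of 4 tensors of shape (2, 3, 5)

# Move dims (0, 2) of each per-example tensor to positions (2, 0).
out = vmap(lambda t: torch.movedim(t, (0, 2), (2, 0)))(x)
expected = torch.stack([torch.movedim(x[i], (0, 2), (2, 0)) for i in range(4)])

assert torch.allclose(out, expected)
print(out.shape)  # torch.Size([4, 5, 3, 2])

Under the hood, the rule above moves the batch dim to the front, shifts source and destination by one via getPhysicalDims, and calls movedim on the physical tensor.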
functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 14 deletions
@@ -308,18 +308,6 @@ Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes,
   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
 }

-Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
-  if (!participatesInCurrentLevel(self)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::movedim(self, source, destination);
-  }
-  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  auto source_physical = self_physical.getPhysicalDims(source);
-  auto destination_physical = self_physical.getPhysicalDims(destination);
-  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
-  return self_physical.getPhysicalToLogicalMap().apply(result);
-}
-
 std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -884,8 +872,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // m.impl("chunk", chunk_batching_rule);
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
-  m.impl("movedim.intlist", movedim_batching_rule);
-  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
   // NB: static_cast because there's another variant of narrow. However, we don't
   // want to support the other variant yet bc it isn't documented...
   m.impl("numpy_T", native::numpy_T); // composite wrt autograd
