Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 0b319e5

Browse files
authored
[port] view to new api (#164)
* [port] view to new api
* handle no batch dim case
* forward the correct bdim
* refactor non bdim code
* address review
1 parent dc82a73 commit 0b319e5

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

functorch/csrc/BatchRulesViews.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,18 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
372372
return std::make_tuple(std::move(result), 0);
373373
}
374374

375+
// Batch rule for aten::view under vmap: the logical `size` the user asked
// for is missing the (physical) batch dimension, so prepend it before
// delegating to Tensor::view.
//
// Preconditions: `self` carries a batch dim (asserted below); the no-bdim
// case is handled by the caller-side plumbing.
std::tuple<Tensor, optional<int64_t>> view_batching_rule(
    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto batched = moveBatchDimToFront(self, self_bdim);
  // Physical target shape = [batch_size] ++ size.
  VmapDimVector physical_size;
  physical_size.reserve(size.size() + 1);
  physical_size.push_back(batched.size(0));  // batch size passes through untouched
  physical_size.insert(physical_size.end(), size.begin(), size.end());
  // Result's batch dim is at the front, hence bdim index 0.
  return std::make_tuple(batched.view(physical_size), 0);
}
386+
375387
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
376388
const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
377389
{
@@ -428,6 +440,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
428440
VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
429441
VMAP_SUPPORT("select_backward", select_backward_batch_rule);
430442
VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
443+
VMAP_SUPPORT("view", view_batching_rule);
431444
VMAP_SUPPORT("expand", expand_batch_rule);
432445
}
433446

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -407,17 +407,6 @@ Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format)
407407
return physical_view.getPhysicalToLogicalMap().apply(result);
408408
}
409409

410-
// Legacy vmap fallback rule for aten::view: translate the logical view
// request into a physical one by mapping the requested shape through the
// current vmap level's batch dims.
Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  // A tensor not batched at the current level is handled eagerly, with the
  // batched dispatch key excluded so we don't re-enter this rule.
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.view(size);
  }
  const auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  // getPhysicalShape prepends the batch sizes to the logical `size`.
  const auto physical_size = physical_view.getPhysicalShape(size);
  auto viewed = physical_view.tensor().view(physical_size);
  return physical_view.getPhysicalToLogicalMap().apply(viewed);
}
420-
421410
Tensor view_as_complex_batching_rule(const Tensor& self) {
422411
if (!participatesInCurrentLevel(self)) {
423412
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -952,7 +941,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
952941
m.impl("unbind.int", unbind_batching_rule);
953942
m.impl("unfold", unfold_batching_rule);
954943
m.impl("unsqueeze_", unsqueeze__batching_rule);
955-
m.impl("view", view_batching_rule);
956944
m.impl("view_as", native::view_as); // composite wrt autograd
957945

958946
m.impl("addmm", addmm_batching_rule);

0 commit comments

Comments
 (0)