Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 04d5ff1

Browse files
authored
[port] expand to new api (#161)
* [port] expand to new api
* update code
* update code
* fix incorrect merge
* retrigger CI
1 parent 69a6b51 commit 04d5ff1

File tree

2 files changed

+35
-44
lines changed

2 files changed

+35
-44
lines changed

functorch/csrc/BatchRulesViews.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,40 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
354354
return std::make_tuple(std::move(result), 0);
355355
}
356356

357+
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
{
  // Batch rule for aten::expand under the new functorch API.
  // `self` is the physical tensor whose batch dimension is described by
  // `self_bdim`; `size` is the logical (per-example) target shape.
  // NOTE(review): the rank arithmetic below assumes a batch dim is present
  // (self.dim() >= 1) — presumably guaranteed by how batch rules are invoked.
  const auto physical_rank = self.dim();
  // One physical dim is the batch dim, so the logical rank is physical_rank - 1.
  TORCH_CHECK(static_cast<uint64_t>(physical_rank - 1) <= size.size(),
      "expand: the number of sizes provided (", size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(physical_rank - 1), ")");

  const auto batched = moveBatchDimToFront(self, self_bdim);
  const auto batched_sizes = batched.sizes();
  const auto bdim_size = batched_sizes[0];

  // Physical target shape: [B, *size].
  c10::SmallBuffer<int64_t, 5> physical_size(size.size() + 1);
  physical_size[0] = bdim_size;
  std::copy(size.cbegin(), size.cend(), physical_size.begin() + 1);

  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  const auto num_new_dims = size.size() - (physical_rank - 1);
  VmapDimVector padded_shape(physical_size.size(), /*init_value*/1);
  padded_shape[0] = bdim_size;
  // Place the existing logical dims after the batch dim and the run of 1s.
  std::copy(batched_sizes.cbegin() + 1, batched_sizes.cend(),
            padded_shape.begin() + 1 + num_new_dims);

  return std::make_tuple(batched.view(padded_shape).expand(physical_size, implicit), 0);
}
390+
357391
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
358392
VMAP_SUPPORT("diag", diag_batch_rule);
359393
VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -375,6 +409,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
375409
VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
376410
VMAP_SUPPORT("select_backward", select_backward_batch_rule);
377411
VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
412+
VMAP_SUPPORT("expand", expand_batch_rule);
378413
}
379414

380415
}}

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -132,49 +132,6 @@ bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
132132
return true;
133133
}
134134

135-
Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  // Legacy batching rule for aten::expand (old BatchedTensor API).
  // `size` is the logical target shape requested by the user.
  if (!participatesInCurrentLevel(self)) {
    // `self` is not batched at the current vmap level: dispatch below the
    // batching key and expand as usual.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.expand(size, implicit);
  }

  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto physical_rank = physical_view.tensor().dim();

  TORCH_CHECK((uint64_t)physical_rank <= physical_size.size(),
      "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (",
      /*logical dim*/self.dim(), ")");

  if ((uint64_t)physical_rank == physical_size.size()) {
    // Same rank: a plain physical expand suffices.
    auto expanded = physical_view.tensor().expand(physical_size, implicit);
    return physical_view.getPhysicalToLogicalMap().apply(expanded);
  }

  TORCH_INTERNAL_ASSERT((uint64_t)physical_rank < physical_size.size());
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto current_sizes = physical_view.tensor().sizes();
  auto num_new_dims = physical_size.size() - physical_rank;
  VmapDimVector padded_shape(physical_size.size(), 1);
  // Keep the leading batch dims where they are...
  std::copy(current_sizes.begin(),
            current_sizes.begin() + physical_view.numBatchDims(),
            padded_shape.begin());
  // ...and shift the logical dims right past the newly inserted 1s.
  std::copy(current_sizes.begin() + physical_view.numBatchDims(),
            current_sizes.end(),
            padded_shape.begin() + physical_view.numBatchDims() + num_new_dims);
  auto expanded = physical_view.tensor().view(padded_shape).expand(physical_size, implicit);
  return physical_view.getPhysicalToLogicalMap().apply(expanded);
}
177-
178135
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
179136
if (!participatesInCurrentLevel(self)) {
180137
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1001,7 +958,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
1001958
// m.impl("chunk", chunk_batching_rule);
1002959
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
1003960
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
1004-
m.impl("expand", expand_batching_rule);
1005961
m.impl("movedim.intlist", movedim_batching_rule);
1006962
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
1007963
// NB: static_cast because there's another variant of narrow. However, we don't

0 commit comments

Comments
 (0)