Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 609f257

Browse files
authored
Added index_select batching rule (#183)
* Added index_select batching rule Description: - Added index_select batching rule - Updated tests Note: index_select_backward can not be done due to its composite nature and `index_add_` usage * Attempt to fix failing tests * Fixed issue with tests * Fixed nit comment
1 parent bacb094 commit 609f257

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

functorch/csrc/BatchRulesScatterOps.cpp

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
219219
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
220220

221221
auto result = at::scatter(self_, physical_dim, index_, value);
222-
// result should have same shape as self
222+
// result should have same rank as self
223223
if (self_logical_rank == 0) {
224224
result = result.squeeze(-1);
225225
}
@@ -259,7 +259,7 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
259259
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
260260

261261
auto result = f(self_, physical_dim, index_, src_);
262-
// result should have same shape as self
262+
// result should have same rank as self
263263
if (self_logical_rank == 0) {
264264
result = result.squeeze(-1);
265265
}
@@ -309,7 +309,7 @@ std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
309309
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
310310

311311
auto result = at::gather(self_, physical_dim, index_, sparse_grad);
312-
// result should have same shape as index
312+
// result should have same rank as index
313313
if (index_logical_rank == 0) {
314314
result = result.squeeze(-1);
315315
}
@@ -346,7 +346,56 @@ std::tuple<Tensor,optional<int64_t>> gather_backward_batch_rule(
346346

347347
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
348348
auto result = at::gather_backward(grad_, self_, physical_dim, index_, sparse_grad);
349-
// result should has same shape as self
349+
// result should has same rank as self
350+
if (self_logical_rank == 0) {
351+
result = result.squeeze(-1);
352+
}
353+
return std::make_tuple(result, 0);
354+
}
355+
356+
std::tuple<Tensor, optional<int64_t>> index_select_batch_rule(
357+
const Tensor& self, optional<int64_t> self_bdim,
358+
int64_t dim,
359+
const Tensor& index, optional<int64_t> index_bdim) {
360+
361+
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
362+
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
363+
auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
364+
365+
auto self_ = moveBatchDimToFront(self, self_bdim);
366+
auto index_ = moveBatchDimToFront(index, index_bdim);
367+
368+
if (self_logical_rank == 0) {
369+
self_ = self_.unsqueeze(-1);
370+
}
371+
if (index_logical_rank == 0) {
372+
index_ = index_.unsqueeze(-1);
373+
}
374+
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
375+
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
376+
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
377+
378+
if (index_.dim() < self_.dim()) {
379+
// setup new_index_shape as [BS, 1, ..., le, ..., 1]
380+
// to reshape index_
381+
auto le = index_.size(1); // get non-batch size of index tensor
382+
{
383+
VmapDimVector new_index_shape(self_.dim(), 1);
384+
new_index_shape[0] = self_.size(0); // set up batch size
385+
new_index_shape[physical_dim] = le;
386+
index_ = index_.reshape(new_index_shape);
387+
}
388+
// Now apply expand to index_
389+
{
390+
auto self_shape = self_.sizes();
391+
VmapDimVector new_index_shape = {self_shape.begin(), self_shape.end()};
392+
new_index_shape[physical_dim] = le;
393+
index_ = index_.expand(new_index_shape);
394+
}
395+
}
396+
397+
auto result = at::gather(self_, physical_dim, index_);
398+
// result should have same rank as self
350399
if (self_logical_rank == 0) {
351400
result = result.squeeze(-1);
352401
}
@@ -361,6 +410,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
361410
VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
362411
VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
363412
VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
413+
VMAP_SUPPORT("index_select", index_select_batch_rule);
414+
364415
}
365416

366417
}}

test/test_vmap.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,7 +2994,7 @@ class TestVmapOperatorsOpInfo(TestCase):
29942994
xfail('block_diag'),
29952995
xfail('nn.functional.dropout'),
29962996
2997-
# entries in here need don't work and need to be fixed.
2997+
# entries in here don't work and need to be fixed.
29982998
# Each one of these is a bug
29992999
xfail('unfold'),
30003000
xfail('svd', device_type='cuda'),
@@ -3041,7 +3041,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30413041
xfail('index_copy'),
30423042
xfail('index_fill'),
30433043
xfail('index_put'),
3044-
xfail('index_select'),
30453044
xfail('isin'),
30463045
xfail('linalg.cholesky'),
30473046
xfail('linalg.eigvals'),

0 commit comments

Comments
 (0)