@@ -219,7 +219,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

   auto result = at::scatter(self_, physical_dim, index_, value);
-  // result should have same shape as self
+  // result should have same rank as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
@@ -259,7 +259,7 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

   auto result = f(self_, physical_dim, index_, src_);
-  // result should have same shape as self
+  // result should have same rank as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
@@ -309,7 +309,7 @@ std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

   auto result = at::gather(self_, physical_dim, index_, sparse_grad);
-  // result should have same shape as index
+  // result should have same rank as index
   if (index_logical_rank == 0) {
     result = result.squeeze(-1);
   }
@@ -346,7 +346,56 @@ std::tuple<Tensor,optional<int64_t>> gather_backward_batch_rule(

   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
   auto result = at::gather_backward(grad_, self_, physical_dim, index_, sparse_grad);
-  // result should has same shape as self
+  // result should have same rank as self
+  if (self_logical_rank == 0) {
+    result = result.squeeze(-1);
+  }
+  return std::make_tuple(result, 0);
+}
+
+std::tuple<Tensor, optional<int64_t>> index_select_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim) {
+
+  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
+  auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
+
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto index_ = moveBatchDimToFront(index, index_bdim);
+
+  if (self_logical_rank == 0) {
+    self_ = self_.unsqueeze(-1);
+  }
+  if (index_logical_rank == 0) {
+    index_ = index_.unsqueeze(-1);
+  }
+  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
+  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
+
+  if (index_.dim() < self_.dim()) {
+    // set up new_index_shape as [BS, 1, ..., le, ..., 1]
+    // to reshape index_
+    auto le = index_.size(1);  // get non-batch size of index tensor
+    {
+      VmapDimVector new_index_shape(self_.dim(), 1);
+      new_index_shape[0] = self_.size(0);  // set up batch size
+      new_index_shape[physical_dim] = le;
+      index_ = index_.reshape(new_index_shape);
+    }
+    // Now apply expand to index_
+    {
+      auto self_shape = self_.sizes();
+      VmapDimVector new_index_shape = {self_shape.begin(), self_shape.end()};
+      new_index_shape[physical_dim] = le;
+      index_ = index_.expand(new_index_shape);
+    }
+  }
+
+  auto result = at::gather(self_, physical_dim, index_);
+  // result should have same rank as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
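Note: the new index_select_batch_rule falls back to at::gather after broadcasting the per-example index across the non-indexed dimensions of self, which is what the reshape/expand of index_ above accomplishes. A minimal standalone sketch of that trick, assuming libtorch is available; the sizes B/N/M/K, the dim choice, and the main() harness are illustrative and not part of the patch:

// Sketch only: batched index_select over logical dim 0 of a per-example
// [N, M] tensor, expressed as one gather along the physical (batched) dim,
// mirroring the reshape-then-expand of index_ in index_select_batch_rule.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  const int64_t B = 4, N = 5, M = 3, K = 2;
  auto self  = torch::randn({B, N, M});               // batched self: per-example [N, M]
  auto index = torch::randint(N, {B, K}, torch::kLong); // batched index: per-example [K]

  // Reshape index to [B, K, 1], then expand to [B, K, M] -- the analogue of
  // building new_index_shape and calling reshape()/expand() in the rule above.
  auto idx = index.reshape({B, K, 1}).expand({B, K, M});

  // Batched index_select == gather along the physical dim (here dim 1).
  auto via_gather = torch::gather(self, /*dim=*/1, idx);

  // Reference: per-example index_select in a plain loop.
  std::vector<torch::Tensor> per_example;
  for (int64_t b = 0; b < B; ++b) {
    per_example.push_back(self[b].index_select(0, index[b]));
  }
  auto reference = torch::stack(per_example);

  std::cout << std::boolalpha << torch::allclose(via_gather, reference) << "\n";
  return 0;
}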
@@ -361,6 +410,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
   VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
   VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
+  VMAP_SUPPORT("index_select", index_select_batch_rule);
+
 }

 }}