@@ -226,7 +226,11 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
   return std::make_tuple(result, 0);
 }
 
-std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
+namespace {
+
+template <typename Func>
+inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
+    Func f,
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
@@ -254,14 +258,34 @@ std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
   src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
 
-  auto result = at::scatter(self_, physical_dim, index_, src_);
+  auto result = f(self_, physical_dim, index_, src_);
   // result should have same shape as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
   return std::make_tuple(result, 0);
 }
 
+} // namespace
+
+std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& src, optional<int64_t> src_bdim) {
+  return scatter_batch_rule(ATEN_FN2(scatter, src),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
+}
+
+std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& src, optional<int64_t> src_bdim) {
+  return scatter_batch_rule(ATEN_FN(scatter_add),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
+}
+
 std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
@@ -336,6 +360,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
   VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
   VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
+  VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
 }
 
 }}
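For illustration only (not part of the diff): with the scatter_add batch rule registered above, functorch's vmap can batch calls to Tensor.scatter_add through this rule rather than falling back to a slower path. The sketch below assumes the functorch-era Python package (`from functorch import vmap`); newer PyTorch releases expose the same entry point as `torch.vmap`.

import torch
from functorch import vmap  # assumption: functorch-era import; torch.vmap in newer releases

B = 3
self_ = torch.zeros(B, 5)
index = torch.randint(0, 5, (B, 4))   # int64 indices into dim 0 of each per-example self
src = torch.randn(B, 4)

# Each per-example call scatter-adds a 1-D src into a 1-D self along dim 0;
# the registered batch rule lets vmap handle the batched call directly.
out = vmap(lambda s, i, v: s.scatter_add(0, i, v))(self_, index, src)
assert out.shape == (B, 5)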