@@ -186,11 +186,15 @@ int64_t bdim_size(
   TORCH_INTERNAL_ASSERT(false);
 }
 
-std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
+namespace {
+
+template <typename Func, typename ...Args>
+std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
+    Func f,
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
-    const Scalar& value) {
+    const Scalar& value, Args... args) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
   auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
@@ -208,23 +212,21 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
 
-  auto result = at::scatter(self_, physical_dim, index_, value);
-  // result should have same rank as self
+  auto result = f(self_, physical_dim, index_, value, args...);
+  // result should have same shape as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
   return std::make_tuple(result, 0);
 }
 
-namespace {
-
-template <typename Func>
+template <typename Func, typename ...Args>
 inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
     Func f,
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
-    const Tensor& src, optional<int64_t> src_bdim) {
+    const Tensor& src, optional<int64_t> src_bdim, Args... args) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
   auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
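Editor's note, not part of the diff: the point of the `Args... args` refactor above is that any trailing arguments are forwarded unchanged to whatever scatter overload is passed in as `f`, so a single template body can serve scatter.value, scatter.src, scatter.reduce, and scatter.value_reduce. Below is a minimal, self-contained sketch of that forwarding pattern, using hypothetical names (`apply_with_extras`, `add3`, `negate`) rather than anything from the repository.

#include <utility>

// Hypothetical helper mirroring the pattern above: a fixed leading argument,
// plus a variadic tail that is passed through to the callable untouched.
template <typename Func, typename... Args>
auto apply_with_extras(Func f, int fixed, Args... args) {
  return f(fixed, args...);
}

int add3(int a, int b, int c) { return a + b + c; }
int negate(int a) { return -a; }

int main() {
  // The same template serves callables with different trailing signatures,
  // just as the refactored scatter_batch_rule serves several scatter overloads.
  int x = apply_with_extras(add3, 1, 2, 3);  // forwards two extra arguments
  int y = apply_with_extras(negate, 5);      // forwards none
  return (x == 6 && y == -5) ? 0 : 1;
}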
@@ -248,8 +250,8 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
   src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
 
-  auto result = f(self_, physical_dim, index_, src_);
-  // result should have same rank as self
+  auto result = f(self_, physical_dim, index_, src_, args...);
+  // result should have same shape as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
@@ -258,6 +260,15 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
 
 } // namespace
 
+std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+  return scatter_batch_rule(ATEN_FN2(scatter, value),
+                            self, self_bdim, dim, index, index_bdim, value);
+}
+
 std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
@@ -276,6 +287,28 @@ std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
                             self, self_bdim, dim, index, index_bdim, src, src_bdim);
 }
 
+std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& src, optional<int64_t> src_bdim,
+    const c10::string_view reduce) {
+  using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Tensor&, const c10::string_view reduce);
+  return scatter_batch_rule(ATEN_FN2(scatter, reduce),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
+}
+
+std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& src,
+    const c10::string_view reduce) {
+  using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Scalar&, const c10::string_view reduce);
+  return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
+                            self, self_bdim, dim, index, index_bdim, src, reduce);
+}
+
 std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
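Editor's note, not part of the diff: the two new batch rules forward their extra `reduce` string to the ATen scatter overloads that take a reduction argument. A hedged sketch of those underlying calls on plain, unbatched tensors follows, assuming the standard scatter.reduce / scatter.value_reduce overloads and reduce strings such as "add" and "multiply"; index and shapes here are illustrative only.

#include <ATen/ATen.h>

int main() {
  at::Tensor self  = at::zeros({2, 4});
  at::Tensor index = at::zeros({2, 4}, at::kLong);  // all indices 0, valid along dim 1
  at::Tensor src   = at::ones({2, 4});

  // scatter.reduce: Tensor src plus a string reduction.
  at::Tensor a = at::scatter(self, /*dim=*/1, index, src, "add");

  // scatter.value_reduce: Scalar value plus the same reduction argument.
  at::Tensor b = at::scatter(self, /*dim=*/1, index, /*value=*/2.0, "multiply");

  return (a.defined() && b.defined()) ? 0 : 1;
}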
@@ -400,6 +433,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
   VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
   VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
+  VMAP_SUPPORT("scatter.reduce", scatter_reduce_batch_rule);
+  VMAP_SUPPORT("scatter.value_reduce", scatter_value_reduce_batch_rule);
   VMAP_SUPPORT("index_select", index_select_batch_rule);
 
 }