This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 14283d3

add scatter.reduce batching rule (#188)
* add scatter.reduce batching rule
* minor update
1 parent e0e6001 commit 14283d3

File tree

2 files changed: +45 −11 lines
  functorch/csrc/BatchRulesScatterOps.cpp
  test/test_vmap.py

functorch/csrc/BatchRulesScatterOps.cpp (45 additions, 10 deletions)
@@ -186,11 +186,15 @@ int64_t bdim_size(
   TORCH_INTERNAL_ASSERT(false);
 }
 
-std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
+namespace {
+
+template<typename Func, typename ...Args>
+std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
+    Func f,
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
-    const Scalar& value) {
+    const Scalar& value, Args... args) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
   auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
@@ -208,23 +212,21 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
 
-  auto result = at::scatter(self_, physical_dim, index_, value);
-  // result should have same rank as self
+  auto result = f(self_, physical_dim, index_, value, args...);
+  // result should have same shape as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
   return std::make_tuple(result, 0);
 }
 
-namespace {
-
-template <typename Func>
+template <typename Func, typename ...Args>
 inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
     Func f,
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
-    const Tensor& src, optional<int64_t> src_bdim) {
+    const Tensor& src, optional<int64_t> src_bdim, Args... args) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
   auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
@@ -248,8 +250,8 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
   src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
 
-  auto result = f(self_, physical_dim, index_, src_);
-  // result should have same rank as self
+  auto result = f(self_, physical_dim, index_, src_, args...);
+  // result should have same shape as self
   if (self_logical_rank == 0) {
     result = result.squeeze(-1);
   }
@@ -258,6 +260,15 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
 
 } // namespace
 
+std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+  return scatter_batch_rule(ATEN_FN2(scatter, value),
+                            self, self_bdim, dim, index, index_bdim, value);
+}
+
 std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
@@ -276,6 +287,28 @@ std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
     self, self_bdim, dim, index, index_bdim, src, src_bdim);
 }
 
+std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& src, optional<int64_t> src_bdim,
+    const c10::string_view reduce) {
+  using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Tensor&, const c10::string_view reduce);
+  return scatter_batch_rule(ATEN_FN2(scatter, reduce),
+                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
+}
+
+std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& src,
+    const c10::string_view reduce) {
+  using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Scalar&, const c10::string_view reduce);
+  return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
+                            self, self_bdim, dim, index, index_bdim, src, reduce);
+}
+
 std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
@@ -400,6 +433,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
   VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
   VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
+  VMAP_SUPPORT("scatter.reduce", scatter_reduce_batch_rule);
+  VMAP_SUPPORT("scatter.value_reduce", scatter_value_reduce_batch_rule);
   VMAP_SUPPORT("index_select", index_select_batch_rule);
 
 }
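With these rules registered, vmap dispatches the scatter.reduce and scatter.value_reduce overloads through the shared scatter_batch_rule template instead of failing. Below is a minimal usage sketch of what this enables, assuming a functorch build that includes this commit; the names and sizes (B, N, base, index, src) are illustrative only.

import torch
from functorch import vmap

B, N = 4, 5
base = torch.zeros(B, N)
index = torch.randint(0, N, (B, N))   # int64 indices, as scatter requires
src = torch.randn(B, N)

# Batched out-of-place scatter with a string reduction ("add" here).
# Each batch element is scattered independently along dim 0 of its slice.
out = vmap(lambda s, i, v: s.scatter(0, i, v, reduce="add"))(base, index, src)

# Reference: a per-sample loop should produce the same result.
expected = torch.stack([
    base[b].scatter(0, index[b], src[b], reduce="add") for b in range(B)
])
assert torch.allclose(out, expected)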

test/test_vmap.py (0 additions, 1 deletion)
@@ -3089,7 +3089,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('resize_as_'),
         xfail('resolve_conj'),
         xfail('resolve_neg'),
-        xfail('scatter'),
         xfail('take'),
         xfail('take_along_dim'),
         xfail('tensor_split'),
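Dropping xfail('scatter') means the scatter OpInfo samples now run for real under test_vmap_exhaustive rather than being expected to fail. For the scalar variant handled by the new scatter_value_reduce_batch_rule, a hedged sketch under the same assumptions as above (illustrative names and sizes):

import torch
from functorch import vmap

B, N = 4, 5
base = torch.ones(B, N)
index = torch.randint(0, N, (B, N))

# scatter.value_reduce: a Scalar value combined with a string reduction.
out = vmap(lambda s, i: s.scatter(0, i, 2.0, reduce="multiply"))(base, index)

expected = torch.stack([
    base[b].scatter(0, index[b], 2.0, reduce="multiply") for b in range(B)
])
assert torch.allclose(out, expected)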
