Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 076302c

Browse files
authored
add scatter_add batch rule (#182)
* add scatter_add batch rule * update test_ops * retrigger CI * address review
1 parent 71a446a commit 076302c

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

functorch/csrc/BatchRulesScatterOps.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,11 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
226226
return std::make_tuple(result, 0);
227227
}
228228

229-
std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
229+
namespace {
230+
231+
template <typename Func>
232+
inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
233+
Func f,
230234
const Tensor& self, optional<int64_t> self_bdim,
231235
int64_t dim,
232236
const Tensor& index, optional<int64_t> index_bdim,
@@ -254,14 +258,34 @@ std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
254258
src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
255259
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
256260

257-
auto result = at::scatter(self_, physical_dim, index_, src_);
261+
auto result = f(self_, physical_dim, index_, src_);
258262
// result should have same shape as self
259263
if (self_logical_rank == 0) {
260264
result = result.squeeze(-1);
261265
}
262266
return std::make_tuple(result, 0);
263267
}
264268

269+
} // namespace
270+
271+
std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
272+
const Tensor& self, optional<int64_t> self_bdim,
273+
int64_t dim,
274+
const Tensor& index, optional<int64_t> index_bdim,
275+
const Tensor& src, optional<int64_t> src_bdim) {
276+
return scatter_batch_rule(ATEN_FN2(scatter, src),
277+
self, self_bdim, dim, index, index_bdim, src, src_bdim);
278+
}
279+
280+
std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
281+
const Tensor& self, optional<int64_t> self_bdim,
282+
int64_t dim,
283+
const Tensor& index, optional<int64_t> index_bdim,
284+
const Tensor& src, optional<int64_t> src_bdim) {
285+
return scatter_batch_rule(ATEN_FN(scatter_add),
286+
self, self_bdim, dim, index, index_bdim, src, src_bdim);
287+
}
288+
265289
std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
266290
const Tensor& self, optional<int64_t> self_bdim,
267291
int64_t dim,
@@ -336,6 +360,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
336360
VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
337361
VMAP_SUPPORT("scatter.value", scatter_value_batch_rule);
338362
VMAP_SUPPORT("scatter.src", scatter_src_batch_rule);
363+
VMAP_SUPPORT("scatter_add", scatter_add_batch_rule);
339364
}
340365

341366
}}

test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ def test_vmapvjp(self, device, dtype, op):
458458
xfail('quantile'),
459459
xfail('renorm'),
460460
xfail('repeat_interleave'),
461-
xfail('scatter_add'),
462461
xfail('solve'),
463462
xfail('sort'),
464463
xfail('symeig'),

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3083,7 +3083,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30833083
xfail('resolve_conj'),
30843084
xfail('resolve_neg'),
30853085
xfail('scatter'),
3086-
xfail('scatter_add'),
30873086
xfail('take'),
30883087
xfail('take_along_dim'),
30893088
xfail('tensor_split'),

0 commit comments

Comments (0)