Skip to content

Commit e3cbeed

Browse files
Add op: scatter.src_out
Differential Revision: D62143589 Pull Request resolved: #5037
1 parent 35d0f59 commit e3cbeed

File tree

6 files changed

+389
-7
lines changed

6 files changed

+389
-7
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@
323323

324324
- op: scalar_tensor.out
325325

326+
- op: scatter.src_out
327+
326328
- op: scatter.value_out
327329

328330
- op: scatter_add.out

kernels/portable/cpu/op_scatter.cpp

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,46 @@ using ScalarType = exec_aten::ScalarType;
2323

2424
namespace {
2525

26+
template <typename CTYPE>
27+
void scatter_src_helper(
28+
const Tensor& in,
29+
int64_t dim,
30+
const Tensor& index,
31+
const Tensor& src,
32+
Tensor& out) {
33+
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
34+
const long* index_data = index.const_data_ptr<long>();
35+
const CTYPE* src_data = src.const_data_ptr<CTYPE>();
36+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
37+
38+
memcpy(out_data, in_data, in.nbytes());
39+
40+
if (dim < 0) {
41+
dim += nonzero_dim(in);
42+
}
43+
44+
for (size_t ix = 0; ix < index.numel(); ++ix) {
45+
// @lint-ignore CLANGTIDY facebook-hte-CArray
46+
size_t ix_coord[kTensorDimensionLimit];
47+
indexToCoordinate(index, ix, ix_coord);
48+
49+
size_t src_ix = coordinateToIndex(src, ix_coord);
50+
51+
// @lint-ignore CLANGTIDY facebook-hte-CArray
52+
size_t out_coord[kTensorDimensionLimit];
53+
for (size_t i = 0; i < out.dim(); ++i) {
54+
if (i == dim) {
55+
out_coord[i] = index_data[ix];
56+
} else {
57+
out_coord[i] = ix_coord[i];
58+
}
59+
}
60+
size_t out_ix = coordinateToIndex(out, out_coord);
61+
62+
out_data[out_ix] = src_data[src_ix];
63+
}
64+
}
65+
2666
template <typename CTYPE, typename CTYPE_VAL>
2767
void scatter_value_helper(
2868
const Tensor& in,
@@ -36,15 +76,16 @@ void scatter_value_helper(
3676

3777
memcpy(out_data, in_data, in.nbytes());
3878

39-
if (index.dim() == 0) {
40-
out_data[index_data[0]] = static_cast<CTYPE>(val);
41-
return;
79+
if (dim < 0) {
80+
dim += nonzero_dim(in);
4281
}
4382

4483
for (size_t ix = 0; ix < index.numel(); ++ix) {
84+
// @lint-ignore CLANGTIDY facebook-hte-CArray
4585
size_t ix_coord[kTensorDimensionLimit];
4686
indexToCoordinate(index, ix, ix_coord);
4787

88+
// @lint-ignore CLANGTIDY facebook-hte-CArray
4889
size_t out_coord[kTensorDimensionLimit];
4990
for (size_t i = 0; i < out.dim(); ++i) {
5091
if (i == dim) {
@@ -61,6 +102,36 @@ void scatter_value_helper(
61102

62103
} // namespace
63104

105+
Tensor& scatter_src_out(
106+
RuntimeContext& context,
107+
const Tensor& in,
108+
int64_t dim,
109+
const Tensor& index,
110+
const Tensor& src,
111+
Tensor& out) {
112+
(void)context;
113+
114+
ET_KERNEL_CHECK(
115+
context,
116+
check_scatter_src_args(in, dim, index, src, out),
117+
InvalidArgument,
118+
out);
119+
120+
ET_KERNEL_CHECK(
121+
context,
122+
resize_tensor(out, in.sizes()) == Error::Ok,
123+
InvalidArgument,
124+
out);
125+
126+
constexpr auto name = "scatter.src_out";
127+
128+
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
129+
scatter_src_helper<CTYPE>(in, dim, index, src, out);
130+
});
131+
132+
return out;
133+
}
134+
64135
Tensor& scatter_value_out(
65136
RuntimeContext& ctx,
66137
const Tensor& in,
@@ -79,10 +150,6 @@ Tensor& scatter_value_out(
79150
ET_KERNEL_CHECK(
80151
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
81152

82-
if (dim < 0) {
83-
dim += nonzero_dim(in);
84-
}
85-
86153
ScalarType val_type = utils::get_scalar_dtype(value);
87154

88155
constexpr auto name = "scatter.value_out";

kernels/portable/cpu/util/index_util.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,15 @@ bool check_scatter_add_args(
191191
return true;
192192
}
193193

194+
bool check_scatter_src_args(
195+
const Tensor& self,
196+
int64_t dim,
197+
const Tensor& index,
198+
const Tensor& src,
199+
Tensor& out) {
200+
return check_scatter_add_args(self, dim, index, src, out);
201+
}
202+
194203
bool check_scatter_value_args(
195204
const Tensor& self,
196205
int64_t dim,

kernels/portable/cpu/util/index_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ bool check_scatter_add_args(
4343
const Tensor& src,
4444
Tensor& out);
4545

46+
bool check_scatter_src_args(
47+
const Tensor& self,
48+
int64_t dim,
49+
const Tensor& index,
50+
const Tensor& src,
51+
Tensor& out);
52+
4653
bool check_scatter_value_args(
4754
const Tensor& self,
4855
int64_t dim,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,11 @@
742742
- arg_meta: null
743743
kernel_name: torch::executor::scalar_tensor_out
744744

745+
- op: scatter.src_out
746+
kernels:
747+
- arg_meta: null
748+
kernel_name: torch::executor::scatter_src_out
749+
745750
- op: scatter.value_out
746751
kernels:
747752
- arg_meta: null

0 commit comments

Comments
 (0)