@@ -23,6 +23,46 @@ using ScalarType = exec_aten::ScalarType;
23
23
24
24
namespace {
25
25
26
+ template <typename CTYPE>
27
+ void scatter_src_helper (
28
+ const Tensor& in,
29
+ int64_t dim,
30
+ const Tensor& index,
31
+ const Tensor& src,
32
+ Tensor& out) {
33
+ const CTYPE* in_data = in.const_data_ptr <CTYPE>();
34
+ const long * index_data = index.const_data_ptr <long >();
35
+ const CTYPE* src_data = src.const_data_ptr <CTYPE>();
36
+ CTYPE* out_data = out.mutable_data_ptr <CTYPE>();
37
+
38
+ memcpy (out_data, in_data, in.nbytes ());
39
+
40
+ if (dim < 0 ) {
41
+ dim += nonzero_dim (in);
42
+ }
43
+
44
+ for (size_t ix = 0 ; ix < index.numel (); ++ix) {
45
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
46
+ size_t ix_coord[kTensorDimensionLimit ];
47
+ indexToCoordinate (index, ix, ix_coord);
48
+
49
+ size_t src_ix = coordinateToIndex (src, ix_coord);
50
+
51
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
52
+ size_t out_coord[kTensorDimensionLimit ];
53
+ for (size_t i = 0 ; i < out.dim (); ++i) {
54
+ if (i == dim) {
55
+ out_coord[i] = index_data[ix];
56
+ } else {
57
+ out_coord[i] = ix_coord[i];
58
+ }
59
+ }
60
+ size_t out_ix = coordinateToIndex (out, out_coord);
61
+
62
+ out_data[out_ix] = src_data[src_ix];
63
+ }
64
+ }
65
+
26
66
template <typename CTYPE, typename CTYPE_VAL>
27
67
void scatter_value_helper (
28
68
const Tensor& in,
@@ -36,15 +76,16 @@ void scatter_value_helper(
36
76
37
77
memcpy (out_data, in_data, in.nbytes ());
38
78
39
- if (index.dim () == 0 ) {
40
- out_data[index_data[0 ]] = static_cast <CTYPE>(val);
41
- return ;
79
+ if (dim < 0 ) {
80
+ dim += nonzero_dim (in);
42
81
}
43
82
44
83
for (size_t ix = 0 ; ix < index.numel (); ++ix) {
84
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
45
85
size_t ix_coord[kTensorDimensionLimit ];
46
86
indexToCoordinate (index, ix, ix_coord);
47
87
88
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
48
89
size_t out_coord[kTensorDimensionLimit ];
49
90
for (size_t i = 0 ; i < out.dim (); ++i) {
50
91
if (i == dim) {
@@ -61,6 +102,36 @@ void scatter_value_helper(
61
102
62
103
} // namespace
63
104
105
+ Tensor& scatter_src_out (
106
+ RuntimeContext& context,
107
+ const Tensor& in,
108
+ int64_t dim,
109
+ const Tensor& index,
110
+ const Tensor& src,
111
+ Tensor& out) {
112
+ (void )context;
113
+
114
+ ET_KERNEL_CHECK (
115
+ context,
116
+ check_scatter_src_args (in, dim, index, src, out),
117
+ InvalidArgument,
118
+ out);
119
+
120
+ ET_KERNEL_CHECK (
121
+ context,
122
+ resize_tensor (out, in.sizes ()) == Error::Ok,
123
+ InvalidArgument,
124
+ out);
125
+
126
+ constexpr auto name = " scatter.src_out" ;
127
+
128
+ ET_SWITCH_REALHB_TYPES (in.scalar_type (), ctx, name, CTYPE, [&]() {
129
+ scatter_src_helper<CTYPE>(in, dim, index, src, out);
130
+ });
131
+
132
+ return out;
133
+ }
134
+
64
135
Tensor& scatter_value_out (
65
136
RuntimeContext& ctx,
66
137
const Tensor& in,
@@ -79,10 +150,6 @@ Tensor& scatter_value_out(
79
150
ET_KERNEL_CHECK (
80
151
ctx, resize_tensor (out, in.sizes ()) == Error::Ok, InvalidArgument, out);
81
152
82
- if (dim < 0 ) {
83
- dim += nonzero_dim (in);
84
- }
85
-
86
153
ScalarType val_type = utils::get_scalar_dtype (value);
87
154
88
155
constexpr auto name = " scatter.value_out" ;
0 commit comments