@@ -18,14 +18,59 @@ using ScalarType = exec_aten::ScalarType;

 namespace {

+void check_arguments(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& src,
+    Tensor& out) {
+  ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
+  ET_CHECK_SAME_DTYPE2(self, src);
+  ET_CHECK_MSG(
+      index.scalar_type() == ScalarType::Long,
+      "Expected dtype int64 for index");
+  ET_CHECK_MSG(
+      dim >= -self.dim() && dim < self.dim(),
+      "dim %" PRId64 " >= -self.dim() && dim %" PRId64 " < self.dim() %zd",
+      dim,
+      dim,
+      self.dim());
+  ET_CHECK_MSG(
+      self.dim() == src.dim() && self.dim() == index.dim(),
+      "self, index and src should have the same number of dimensions.");
+  dim = dim < 0 ? dim + self.dim() : dim;
+  for (size_t d = 0; d < self.dim(); ++d) {
+    ET_CHECK_MSG(
+        index.size(d) <= src.size(d),
+        "size of dimension %zd of index should not exceed the size of that dimension of src",
+        d);
+    if (d != dim) {
+      ET_CHECK_MSG(
+          index.size(d) <= self.size(d),
+          "size of dimension %zd of index should not exceed the size of that dimension of self when dimension %zd != dim %zd",
+          d,
+          d,
+          (size_t)dim);
+    }
+  }
+  const long* index_data = index.const_data_ptr<long>();
+  for (size_t i = 0; i < index.numel(); ++i) {
+    ET_CHECK_MSG(
+        index_data[i] < self.size(dim),
+        "Index is out of bounds for dimension %zd with size %zd",
+        (size_t)dim,
+        self.size(dim));
+  }
+}
+
 /**
  * Recursively add input_data to output_data in the fashion of scatter.
  */
-template <typename CTYPE_DATA>
+template <typename CTYPE>
 void scatter_add_helper(
-    CTYPE_DATA* src_data,
-    long* index_data,
-    CTYPE_DATA* output_data,
+    const CTYPE* src_data,
+    const long* index_data,
+    CTYPE* out_data,
     const Tensor& src,
     const Tensor& index,
     Tensor& out,
@@ -35,15 +80,14 @@ void scatter_add_helper(
   // the last dimension, copy data
   if (current_dim == index.dim() - 1) {
     size_t trailing_dims = getTrailingDims(out, dim);
-    CTYPE_DATA* output_data_base =
-        output_data - (size_t)dim_offset * trailing_dims;
+    CTYPE* out_data_base = out_data - (size_t)dim_offset * trailing_dims;
     for (size_t i = 0; i < index.size(current_dim); ++i) {
-      output_data = output_data_base + (size_t)index_data[i] * trailing_dims;
+      out_data = out_data_base + (size_t)index_data[i] * trailing_dims;
       // if dim is the last dimension, there is no need to traverse further
       if (dim == current_dim) {
-        *output_data += src_data[i];
+        *out_data += src_data[i];
       } else {
-        output_data[i] += src_data[i];
+        out_data[i] += src_data[i];
       }
     }
     return;
@@ -54,60 +98,27 @@ void scatter_add_helper(
   size_t current_dim_offset = 0;
   // recursively set data for the next dimension
   for (size_t i = 0; i < index.size(current_dim); ++i) {
-    scatter_add_helper<CTYPE_DATA>(
+    scatter_add_helper<CTYPE>(
         src_data,
         index_data,
-        output_data,
+        out_data,
         src,
         index,
         out,
         dim,
         current_dim + 1,
         current_dim_offset);
     src_data += trailing_dims_src;
-    output_data += trailing_dims_out;
+    out_data += trailing_dims_out;
     index_data += trailing_dims_index;
     if (current_dim == dim) {
       current_dim_offset += 1;
     }
   }
 }

-template <typename CTYPE_DATA>
-void scatter_add(
-    const Tensor& self,
-    int64_t dim,
-    const Tensor& index,
-    const Tensor& src,
-    Tensor& out) {
-  long* index_data = index.mutable_data_ptr<long>();
-  for (size_t i = 0; i < index.numel(); ++i) {
-    ET_CHECK_MSG(
-        index_data[i] < self.size(dim),
-        "Index is out of bounds for dimension %zd with size %zd",
-        (size_t)dim,
-        self.size(dim));
-  }
-  CTYPE_DATA* self_data = self.mutable_data_ptr<CTYPE_DATA>();
-  CTYPE_DATA* src_data = src.mutable_data_ptr<CTYPE_DATA>();
-  CTYPE_DATA* out_data = out.mutable_data_ptr<CTYPE_DATA>();
-  memcpy(out_data, self_data, self.nbytes());
-  scatter_add_helper<CTYPE_DATA>(
-      src_data, index_data, out_data, src, index, out, dim, 0, 0);
-}
-
 } // namespace

-/**
- * Adds all values from the tensor src into self at the indices specified in
- * the index tensor, in a similar fashion as scatter(). For each value in src,
- * it is added to an index in self, which is specified by its index in src for
- * dimension != dim and by the corresponding value in index for dimension = dim.
- *
- * Assumes tensors self, src, and out have the same dtype, which may be any
- * real type (Byte, Char, Short, Int, Long, Float, Double), and that the index
- * tensor is of Long (int64) type.
- */
 Tensor& scatter_add_out(
     RuntimeContext& ctx,
     const Tensor& self,
@@ -116,48 +127,21 @@ Tensor& scatter_add_out(
     const Tensor& src,
     Tensor& out) {
   (void)ctx;
-  ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
-  ET_CHECK_SAME_DTYPE2(self, src);
-  ET_CHECK_MSG(
-      index.scalar_type() == ScalarType::Long,
-      "Expected dtype int64 for index");
-  ET_CHECK_MSG(
-      dim >= -self.dim() && dim < self.dim(),
-      "dim %" PRId64 " >= -self.dim() && dim %" PRId64 " < self.dim() %zd",
-      dim,
-      dim,
-      self.dim());
-  ET_CHECK_MSG(
-      self.dim() == src.dim() && self.dim() == index.dim(),
-      "self, index and src should have the same number of dimensions.");
-  dim = dim < 0 ? dim + self.dim() : dim;
-  for (size_t d = 0; d < self.dim(); ++d) {
-    ET_CHECK_MSG(
-        index.size(d) <= src.size(d),
-        "size of dimension %zd of index should not exceed the size of that dimension of src",
-        d);
-    if (d != dim) {
-      ET_CHECK_MSG(
-          index.size(d) <= self.size(d),
-          "size of dimension %zd of index should not exceed the size of that dimension of self when dimension %zd != dim %zd",
-          d,
-          d,
-          (size_t)dim);
-    }
-  }

-#define SCATTER_ADD(ctype, dtype)                   \
-  case ScalarType::dtype:                           \
-    scatter_add<ctype>(self, dim, index, src, out); \
-    break;
+  check_arguments(self, dim, index, src, out);
+
+  // Normalize a negative dim here as well; check_arguments validates the
+  // range but takes dim by value, so its normalization does not propagate.
+  dim = dim < 0 ? dim + self.dim() : dim;

-  switch (self.scalar_type()) {
-    ET_FORALL_REAL_TYPES(SCATTER_ADD)
-    default:
-      ET_CHECK_MSG(false, "Unhandled input dtype %hhd", self.scalar_type());
-  }
+  ScalarType self_type = self.scalar_type();
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, self_type, ctx, __func__, CTYPE, [&]() {
+    const CTYPE* self_data = self.const_data_ptr<CTYPE>();
+    const long* index_data = index.const_data_ptr<long>();
+    const CTYPE* src_data = src.const_data_ptr<CTYPE>();
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+    memcpy(out_data, self_data, self.nbytes());
+    scatter_add_helper<CTYPE>(
+        src_data, index_data, out_data, src, index, out, dim, 0, 0);
+  });

-#undef SCATTER_ADD

   return out;
 }
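For reference, a worked example of the semantics described by the removed doc comment, written as a hypothetical standalone C++ sketch over raw arrays rather than ExecuTorch tensors (shapes and values invented for illustration). Note that the new `ET_SWITCH_REAL_TYPES_AND(Bool, ...)` dispatch also covers Bool, which the old `ET_FORALL_REAL_TYPES` switch did not.

```cpp
#include <cstdio>

int main() {
  // dim = 0: out starts as a copy of self, then
  // out[index[r][c]][c] += src[r][c].
  const int rows = 2, cols = 3;
  int self[rows][cols] = {{1, 2, 3}, {4, 5, 6}};
  long index[1][cols] = {{1, 0, 1}};
  int src[1][cols] = {{10, 20, 30}};

  int out[rows][cols];
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      out[r][c] = self[r][c];
  for (int c = 0; c < cols; ++c)
    out[index[0][c]][c] += src[0][c];

  for (int r = 0; r < rows; ++r)
    std::printf("%d %d %d\n", out[r][0], out[r][1], out[r][2]);
  // prints:
  // 1 22 3
  // 14 5 36
}
```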