Skip to content

Commit 97d6b65

Browse files
manuelcandales and facebook-github-bot
authored and committed
Dtype compliance: scatter_add
Reviewed By: kirklandsign Differential Revision: D48318689 fbshipit-source-id: f89bc317ac552c19a5a7dfdd90f7feaca943c07e
1 parent a82d03c commit 97d6b65

File tree

1 file changed

+68
-84
lines changed

1 file changed

+68
-84
lines changed

kernels/portable/cpu/op_scatter_add.cpp

Lines changed: 68 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,59 @@ using ScalarType = exec_aten::ScalarType;
1818

1919
namespace {
2020

21+
void check_arguments(
22+
const Tensor& self,
23+
int64_t dim,
24+
const Tensor& index,
25+
const Tensor& src,
26+
Tensor& out) {
27+
ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
28+
ET_CHECK_SAME_DTYPE2(self, src);
29+
ET_CHECK_MSG(
30+
index.scalar_type() == ScalarType::Long,
31+
"Expected dypte int64 for index");
32+
ET_CHECK_MSG(
33+
dim >= -self.dim() && dim < self.dim(),
34+
"dim %" PRId64 " >= 0 && dim %" PRId64 " < self.dim() %zd",
35+
dim,
36+
dim,
37+
self.dim());
38+
ET_CHECK_MSG(
39+
self.dim() == src.dim() && self.dim() == index.dim(),
40+
"self, index and src should have same number of dimensions.");
41+
dim = dim < 0 ? dim + self.dim() : dim;
42+
for (size_t d = 0; d < self.dim(); ++d) {
43+
ET_CHECK_MSG(
44+
index.size(d) <= src.size(d),
45+
"size of dimension %zd of index should be smaller than the size of that dimension of src",
46+
d);
47+
if (d != dim) {
48+
ET_CHECK_MSG(
49+
index.size(d) <= self.size(d),
50+
"size of dimension %zd of index should be smaller than the size of that dimension of self if dimension %zd != dim %zd",
51+
d,
52+
d,
53+
(size_t)dim);
54+
}
55+
}
56+
const long* index_data = index.const_data_ptr<long>();
57+
for (size_t i = 0; i < index.numel(); ++i) {
58+
ET_CHECK_MSG(
59+
index_data[i] < self.size(dim),
60+
"Index is out of bounds for dimension %zd with size %zd",
61+
(size_t)dim,
62+
self.size(dim));
63+
}
64+
}
65+
2166
/**
2267
* Add input_data to output_data in the fashion of scatter recursively
2368
*/
24-
template <typename CTYPE_DATA>
69+
template <typename CTYPE>
2570
void scatter_add_helper(
26-
CTYPE_DATA* src_data,
27-
long* index_data,
28-
CTYPE_DATA* output_data,
71+
const CTYPE* src_data,
72+
const long* index_data,
73+
CTYPE* out_data,
2974
const Tensor& src,
3075
const Tensor& index,
3176
Tensor& out,
@@ -35,15 +80,14 @@ void scatter_add_helper(
3580
// the last dimension, copy data
3681
if (current_dim == index.dim() - 1) {
3782
size_t trailing_dims = getTrailingDims(out, dim);
38-
CTYPE_DATA* output_data_base =
39-
output_data - (size_t)dim_offset * trailing_dims;
83+
CTYPE* out_data_base = out_data - (size_t)dim_offset * trailing_dims;
4084
for (size_t i = 0; i < index.size(current_dim); ++i) {
41-
output_data = output_data_base + (size_t)index_data[i] * trailing_dims;
85+
out_data = out_data_base + (size_t)index_data[i] * trailing_dims;
4286
// if dim is the last dimension, do not need to traverse again
4387
if (dim == current_dim) {
44-
*output_data += src_data[i];
88+
*out_data += src_data[i];
4589
} else {
46-
output_data[i] += src_data[i];
90+
out_data[i] += src_data[i];
4791
}
4892
}
4993
return;
@@ -54,60 +98,27 @@ void scatter_add_helper(
5498
size_t current_dim_offset = 0;
5599
// recursively set data for the next dimension
56100
for (size_t i = 0; i < index.size(current_dim); ++i) {
57-
scatter_add_helper<CTYPE_DATA>(
101+
scatter_add_helper<CTYPE>(
58102
src_data,
59103
index_data,
60-
output_data,
104+
out_data,
61105
src,
62106
index,
63107
out,
64108
dim,
65109
current_dim + 1,
66110
current_dim_offset);
67111
src_data += trailing_dims_src;
68-
output_data += trailing_dims_out;
112+
out_data += trailing_dims_out;
69113
index_data += trailing_dims_index;
70114
if (current_dim == dim) {
71115
current_dim_offset += 1;
72116
}
73117
}
74118
}
75119

76-
template <typename CTYPE_DATA>
77-
void scatter_add(
78-
const Tensor& self,
79-
int64_t dim,
80-
const Tensor& index,
81-
const Tensor& src,
82-
Tensor& out) {
83-
long* index_data = index.mutable_data_ptr<long>();
84-
for (size_t i = 0; i < index.numel(); ++i) {
85-
ET_CHECK_MSG(
86-
index_data[i] < self.size(dim),
87-
"Index is out of bounds for dimension %zd with size %zd",
88-
(size_t)dim,
89-
self.size(dim));
90-
}
91-
CTYPE_DATA* self_data = self.mutable_data_ptr<CTYPE_DATA>();
92-
CTYPE_DATA* src_data = src.mutable_data_ptr<CTYPE_DATA>();
93-
CTYPE_DATA* out_data = out.mutable_data_ptr<CTYPE_DATA>();
94-
memcpy(out_data, self_data, self.nbytes());
95-
scatter_add_helper<CTYPE_DATA>(
96-
src_data, index_data, out_data, src, index, out, dim, 0, 0);
97-
}
98-
99120
} // namespace
100121

101-
/**
102-
* Adds all values from the tensor src into self at the indices specified in the
103-
* index tensor in a similar fashion as scatter(). For each value in src, it is
104-
* added to an index in self which is specified by its index in src for
105-
* dimension != dim and by the corresponding value in index for dimension = dim.
106-
*
107-
* Assume tensor self, src, out have the same dtype, and shall be in any real
108-
* types (Byte, Char, Short, Int, Long, Float, Double), and index tensor shall
109-
* be in Long (int64) type.
110-
*/
111122
Tensor& scatter_add_out(
112123
RuntimeContext& ctx,
113124
const Tensor& self,
@@ -116,48 +127,21 @@ Tensor& scatter_add_out(
116127
const Tensor& src,
117128
Tensor& out) {
118129
(void)ctx;
119-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
120-
ET_CHECK_SAME_DTYPE2(self, src);
121-
ET_CHECK_MSG(
122-
index.scalar_type() == ScalarType::Long,
123-
"Expected dypte int64 for index");
124-
ET_CHECK_MSG(
125-
dim >= -self.dim() && dim < self.dim(),
126-
"dim %" PRId64 " >= 0 && dim %" PRId64 " < self.dim() %zd",
127-
dim,
128-
dim,
129-
self.dim());
130-
ET_CHECK_MSG(
131-
self.dim() == src.dim() && self.dim() == index.dim(),
132-
"self, index and src should have same number of dimensions.");
133-
dim = dim < 0 ? dim + self.dim() : dim;
134-
for (size_t d = 0; d < self.dim(); ++d) {
135-
ET_CHECK_MSG(
136-
index.size(d) <= src.size(d),
137-
"size of dimension %zd of index should be smaller than the size of that dimension of src",
138-
d);
139-
if (d != dim) {
140-
ET_CHECK_MSG(
141-
index.size(d) <= self.size(d),
142-
"size of dimension %zd of index should be smaller than the size of that dimension of self if dimension %zd != dim %zd",
143-
d,
144-
d,
145-
(size_t)dim);
146-
}
147-
}
148130

149-
#define SCATTER_ADD(ctype, dtype) \
150-
case ScalarType::dtype: \
151-
scatter_add<ctype>(self, dim, index, src, out); \
152-
break;
131+
check_arguments(self, dim, index, src, out);
153132

154-
switch (self.scalar_type()) {
155-
ET_FORALL_REAL_TYPES(SCATTER_ADD)
156-
default:
157-
ET_CHECK_MSG(false, "Unhandled input dtype %hhd", self.scalar_type());
158-
}
133+
ScalarType self_type = self.scalar_type();
134+
135+
ET_SWITCH_REAL_TYPES_AND(Bool, self_type, ctx, __func__, CTYPE, [&]() {
136+
const CTYPE* self_data = self.const_data_ptr<CTYPE>();
137+
const long* index_data = index.const_data_ptr<long>();
138+
const CTYPE* src_data = src.const_data_ptr<CTYPE>();
139+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
140+
memcpy(out_data, self_data, self.nbytes());
141+
scatter_add_helper<CTYPE>(
142+
src_data, index_data, out_data, src, index, out, dim, 0, 0);
143+
});
159144

160-
#undef SCATTER_ADD
161145
return out;
162146
}
163147

0 commit comments

Comments
 (0)