
Commit 827388d

manuelcandales authored and facebook-github-bot committed

Dtype compliance: select_scatter

Reviewed By: SS-JIA
Differential Revision: D48288080
fbshipit-source-id: a3fc4f24a05870cd28257baf643291433a816533

1 parent 8c77ef8

File tree

1 file changed: +53 -66 lines changed


kernels/portable/cpu/op_select_scatter.cpp

Lines changed: 53 additions & 66 deletions
@@ -27,7 +27,7 @@ namespace {
  * 3. dim and index values are valid given the input tensor
  */
 void check_select_scatter_args(
-    const Tensor& input,
+    const Tensor& in,
     const Tensor& src,
     int64_t dim,
     int64_t index,
@@ -37,56 +37,50 @@ void check_select_scatter_args(

   // The dim planed to be selected on shall exist in input
   ET_CHECK_MSG(
-      dim >= 0 && dim < input.dim(),
+      dim >= 0 && dim < in.dim(),
       "dim %" PRId64 " out of range [-%zd,%zd)",
       dim,
-      input.dim(),
-      input.dim());
+      in.dim(),
+      in.dim());

   // The index shall be valid in the given dimenson
   ET_CHECK_MSG(
-      index >= 0 && index < input.size(dim),
-      "index %" PRId64 " out of range [-%zd,%zd) at input.size( %" PRId64 ")",
+      index >= 0 && index < in.size(dim),
+      "index %" PRId64 " out of range [-%zd,%zd) at in.size( %" PRId64 ")",
       index,
-      input.size(dim),
-      input.size(dim),
+      in.size(dim),
+      in.size(dim),
       dim);

-  // All tensors should be same dtype
-  ET_CHECK_SAME_DTYPE3(input, output, src);
-
-  // The size of output tensor should be the same as the input
-  ET_CHECK_SAME_SHAPE2(input, output);
-
-  // The src.dim() shall be one lower than input.dim() since src needs to fit
+  // The src.dim() shall be one lower than in.dim() since src needs to fit
   // into the selected data on one dim of input
   // https://pytorch.org/docs/stable/generated/torch.select_scatter.html
   ET_CHECK_MSG(
-      input.dim() == src.dim() + 1,
-      "input.dim() %zd != src.dim() + 1 %zd",
-      input.dim(),
+      in.dim() == src.dim() + 1,
+      "in.dim() %zd != src.dim() + 1 %zd",
+      in.dim(),
       src.dim() + 1);

   // The size of src tensor should follow these rules:
-  // - src.size(i) shall equal to input.size(i) if i < dim,
-  // - src.size(i) shall equal to input.size(i+1) if i >= dim
+  // - src.size(i) shall equal to in.size(i) if i < dim,
+  // - src.size(i) shall equal to in.size(i+1) if i >= dim

-  for (ssize_t d = 0; d < input.dim() - 1; d++) {
+  for (ssize_t d = 0; d < in.dim() - 1; d++) {
     if (d < dim) {
       ET_CHECK_MSG(
-          input.size(d) == src.size(d),
-          "input.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
+          in.size(d) == src.size(d),
+          "in.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
           d,
-          input.size(d),
+          in.size(d),
           d,
           src.size(d),
           dim);
     } else {
       ET_CHECK_MSG(
-          input.size(d + 1) == src.size(d),
-          "input.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
+          in.size(d + 1) == src.size(d),
+          "in.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
           d + 1,
-          input.size(d + 1),
+          in.size(d + 1),
           d,
           src.size(d),
           dim);
@@ -100,67 +94,60 @@ void check_select_scatter_args(
 /// Tensor(a!) out) -> Tensor(a!)
 Tensor& select_scatter_out(
     RuntimeContext& ctx,
-    const Tensor& input,
+    const Tensor& in,
     const Tensor& src,
     int64_t dim,
     int64_t index,
     Tensor& out) {
-  // Avoid unused variable warning
   (void)ctx;

+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
+
   // Account for negative indices
   if (dim < 0) {
-    dim += input.dim();
+    dim += in.dim();
   }
   if (index < 0) {
-    index += input.size(dim);
+    index += in.size(dim);
   }

-  // Resize the tensor to the expected output size
-  Tensor::SizesType expected_output_size[16];
-  for (size_t i = 0; i < input.dim(); ++i) {
-    expected_output_size[i] = input.size(i);
-  }
-  auto error = resize_tensor(
-      out, {expected_output_size, static_cast<size_t>(input.dim())});
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-
   // Check args
-  check_select_scatter_args(input, src, dim, index, out);
+  check_select_scatter_args(in, src, dim, index, out);

-  // If the input is a empty tensor, no other operation could be done. We just
+  // If the input is an empty tensor, no other operation could be done. We just
   // return the output.
-  if (input.numel() == 0) {
+  if (in.numel() == 0) {
     return out;
   }

   // To start, copy the input into the output. Input will not be empty due to
   // the checks performed above.
-  memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes());
+  memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());

   // Strides to help with memory address arithmetic
-  size_t leading_dims = getLeadingDims(input, dim);
-  size_t trailing_stride = getTrailingDims(input, dim);
-
-  size_t dim_length = input.size(dim);
-
-  // Number of bytes to copy for each memcpy
-  size_t copy_nbytes = trailing_stride * src.element_size();
-
-  // Number of bytes to step forward to reach the next copy output location
-  size_t out_step_nbytes = dim_length * trailing_stride * out.element_size();
-
-  // Position data pointers at the starting point
-  size_t start_offset = index * trailing_stride * out.element_size();
-  char* out_data = out.mutable_data_ptr<char>() + start_offset;
-
-  const char* src_data = src.const_data_ptr<char>();
-
-  for (size_t step = 0; step < leading_dims; ++step) {
-    memcpy(out_data, src_data, copy_nbytes);
-    out_data += out_step_nbytes;
-    src_data += copy_nbytes;
-  }
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_stride = getTrailingDims(in, dim);
+  size_t start_offset = index * trailing_stride;
+  size_t out_step = in.size(dim) * trailing_stride;
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType src_type = src.scalar_type();
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
+      CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+      const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
+
+      for (size_t i = 0; i < leading_dims; ++i) {
+        for (size_t j = 0; j < trailing_stride; ++j) {
+          out_data[start_offset + i * out_step + j] =
+              convert<CTYPE, CTYPE_SRC>(src_data[i * trailing_stride + j]);
+        }
+      }
+    });
+  });

   return out;
 }
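Net effect of the change: the old kernel required in, src, and out to share a dtype (ET_CHECK_SAME_DTYPE3) and copied the selected slice with a raw memcpy, while the new kernel only requires in and out to match, dispatches over both dtypes with ET_SWITCH_REAL_TYPES_AND, and converts each src element to the output dtype, so a src tensor of a different real/bool dtype is now accepted. For intuition about the index arithmetic (leading_dims, trailing_stride, start_offset, out_step), here is a minimal standalone sketch that uses plain std::vector buffers in place of the ExecuTorch Tensor and ET_SWITCH machinery; the name select_scatter_sketch and the float-out/int-src pairing are illustrative assumptions, not part of the kernel.

// Standalone sketch of the new copy loop (not the ExecuTorch API).
// Writes a (leading_dims x trailing_stride) src slab into the slice of a
// contiguous out buffer at position `index` along `dim`, converting dtypes.
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename CTYPE, typename CTYPE_SRC>
void select_scatter_sketch(
    std::vector<CTYPE>& out,            // already holds a copy of the input
    const std::vector<CTYPE_SRC>& src,  // one slice worth of data
    size_t leading_dims,                // product of sizes before `dim`
    size_t dim_size,                    // size of the selected dimension
    size_t trailing_stride,             // product of sizes after `dim`
    size_t index) {                     // slice to overwrite along `dim`
  const size_t start_offset = index * trailing_stride;
  const size_t out_step = dim_size * trailing_stride;
  for (size_t i = 0; i < leading_dims; ++i) {
    for (size_t j = 0; j < trailing_stride; ++j) {
      // static_cast stands in for ExecuTorch's convert<CTYPE, CTYPE_SRC>.
      out[start_offset + i * out_step + j] =
          static_cast<CTYPE>(src[i * trailing_stride + j]);
    }
  }
}

int main() {
  // in: 2x3x2 float tensor of zeros; src: 2x2 int tensor of ones.
  // select_scatter(in, src, dim=1, index=2) overwrites slice 2 along dim 1.
  std::vector<float> out(2 * 3 * 2, 0.0f);
  std::vector<int> src(2 * 2, 1);
  select_scatter_sketch(out, src, /*leading_dims=*/2, /*dim_size=*/3,
                        /*trailing_stride=*/2, /*index=*/2);
  for (float v : out) {
    std::printf("%g ", v);  // prints: 0 0 0 0 1 1 0 0 0 0 1 1
  }
  std::printf("\n");
  return 0;
}

Each leading index i owns a contiguous block of dim_size * trailing_stride output elements, and within that block the selected slice begins at offset index * trailing_stride, which is exactly how the kernel positions its writes.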
