@@ -27,7 +27,7 @@ namespace {
  * 3. dim and index values are valid given the input tensor
  */
 void check_select_scatter_args(
-    const Tensor& input,
+    const Tensor& in,
     const Tensor& src,
     int64_t dim,
     int64_t index,
@@ -37,56 +37,50 @@ void check_select_scatter_args(
 
   // The dim planed to be selected on shall exist in input
   ET_CHECK_MSG(
-      dim >= 0 && dim < input.dim(),
+      dim >= 0 && dim < in.dim(),
       "dim %" PRId64 " out of range [-%zd,%zd)",
       dim,
-      input.dim(),
-      input.dim());
+      in.dim(),
+      in.dim());
 
   // The index shall be valid in the given dimenson
   ET_CHECK_MSG(
-      index >= 0 && index < input.size(dim),
-      "index %" PRId64 " out of range [-%zd,%zd) at input.size( %" PRId64 ")",
+      index >= 0 && index < in.size(dim),
+      "index %" PRId64 " out of range [-%zd,%zd) at in.size( %" PRId64 ")",
       index,
-      input.size(dim),
-      input.size(dim),
+      in.size(dim),
+      in.size(dim),
       dim);
 
-  // All tensors should be same dtype
-  ET_CHECK_SAME_DTYPE3(input, output, src);
-
-  // The size of output tensor should be the same as the input
-  ET_CHECK_SAME_SHAPE2(input, output);
-
-  // The src.dim() shall be one lower than input.dim() since src needs to fit
+  // The src.dim() shall be one lower than in.dim() since src needs to fit
   // into the selected data on one dim of input
   // https://pytorch.org/docs/stable/generated/torch.select_scatter.html
   ET_CHECK_MSG(
-      input.dim() == src.dim() + 1,
-      "input.dim() %zd != src.dim() + 1 %zd",
-      input.dim(),
+      in.dim() == src.dim() + 1,
+      "in.dim() %zd != src.dim() + 1 %zd",
+      in.dim(),
       src.dim() + 1);
 
   // The size of src tensor should follow these rules:
-  // - src.size(i) shall equal to input.size(i) if i < dim,
-  // - src.size(i) shall equal to input.size(i+1) if i >= dim
+  // - src.size(i) shall equal to in.size(i) if i < dim,
+  // - src.size(i) shall equal to in.size(i+1) if i >= dim
 
-  for (ssize_t d = 0; d < input.dim() - 1; d++) {
+  for (ssize_t d = 0; d < in.dim() - 1; d++) {
     if (d < dim) {
       ET_CHECK_MSG(
-          input.size(d) == src.size(d),
-          "input.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
+          in.size(d) == src.size(d),
+          "in.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
           d,
-          input.size(d),
+          in.size(d),
           d,
           src.size(d),
           dim);
     } else {
       ET_CHECK_MSG(
-          input.size(d + 1) == src.size(d),
-          "input.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
+          in.size(d + 1) == src.size(d),
+          "in.size(%zu) %zd != src.size(%zu) %zd | dim = %" PRId64 ")",
           d + 1,
-          input.size(d + 1),
+          in.size(d + 1),
           d,
           src.size(d),
           dim);
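
For reference, the per-dimension checks in this hunk amount to requiring that src have the input's shape with dimension dim removed, per the torch.select_scatter contract linked in the comment above. A minimal standalone sketch of that rule, using a hypothetical expected_src_sizes helper that is not part of this kernel:

#include <cstdint>
#include <vector>

// Illustration only: for select_scatter along `dim`, `src` must have the
// input's shape with that dimension dropped, i.e.
//   src.size(i) == in.size(i)     for i < dim
//   src.size(i) == in.size(i + 1) for i >= dim
std::vector<int64_t> expected_src_sizes(
    const std::vector<int64_t>& in_sizes, int64_t dim) {
  std::vector<int64_t> src_sizes;
  for (int64_t d = 0; d < static_cast<int64_t>(in_sizes.size()); ++d) {
    if (d != dim) {
      src_sizes.push_back(in_sizes[d]);
    }
  }
  return src_sizes;
}
// e.g. expected_src_sizes({2, 3, 4}, 1) == {2, 4}
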
@@ -100,67 +94,60 @@ void check_select_scatter_args(
 /// Tensor(a!) out) -> Tensor(a!)
 Tensor& select_scatter_out(
     RuntimeContext& ctx,
-    const Tensor& input,
+    const Tensor& in,
     const Tensor& src,
     int64_t dim,
     int64_t index,
     Tensor& out) {
-  // Avoid unused variable warning
   (void)ctx;
 
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
+
   // Account for negative indices
   if (dim < 0) {
-    dim += input.dim();
+    dim += in.dim();
   }
   if (index < 0) {
-    index += input.size(dim);
+    index += in.size(dim);
   }
 
-  // Resize the tensor to the expected output size
-  Tensor::SizesType expected_output_size[16];
-  for (size_t i = 0; i < input.dim(); ++i) {
-    expected_output_size[i] = input.size(i);
-  }
-  auto error = resize_tensor(
-      out, {expected_output_size, static_cast<size_t>(input.dim())});
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-
   // Check args
-  check_select_scatter_args(input, src, dim, index, out);
+  check_select_scatter_args(in, src, dim, index, out);
 
-  // If the input is a empty tensor, no other operation could be done. We just
+  // If the input is an empty tensor, no other operation could be done. We just
   // return the output.
-  if (input.numel() == 0) {
+  if (in.numel() == 0) {
     return out;
   }
 
   // To start, copy the input into the output. Input will not be empty due to
   // the checks performed above.
-  memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes());
+  memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
 
   // Strides to help with memory address arithmetic
-  size_t leading_dims = getLeadingDims(input, dim);
-  size_t trailing_stride = getTrailingDims(input, dim);
-
-  size_t dim_length = input.size(dim);
-
-  // Number of bytes to copy for each memcpy
-  size_t copy_nbytes = trailing_stride * src.element_size();
-
-  // Number of bytes to step forward to reach the next copy output location
-  size_t out_step_nbytes = dim_length * trailing_stride * out.element_size();
-
-  // Position data pointers at the starting point
-  size_t start_offset = index * trailing_stride * out.element_size();
-  char* out_data = out.mutable_data_ptr<char>() + start_offset;
-
-  const char* src_data = src.const_data_ptr<char>();
-
-  for (size_t step = 0; step < leading_dims; ++step) {
-    memcpy(out_data, src_data, copy_nbytes);
-    out_data += out_step_nbytes;
-    src_data += copy_nbytes;
-  }
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_stride = getTrailingDims(in, dim);
+  size_t start_offset = index * trailing_stride;
+  size_t out_step = in.size(dim) * trailing_stride;
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType src_type = src.scalar_type();
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
+      CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+      const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
+
+      for (size_t i = 0; i < leading_dims; ++i) {
+        for (size_t j = 0; j < trailing_stride; ++j) {
+          out_data[start_offset + i * out_step + j] =
+              convert<CTYPE, CTYPE_SRC>(src_data[i * trailing_stride + j]);
+        }
+      }
+    });
+  });
 
   return out;
 }
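
One way to sanity-check the new indexing scheme: for a contiguous tensor, getLeadingDims(in, dim) is the product of the sizes before dim and getTrailingDims(in, dim) the product of the sizes after it, so element (i, j) of the flattened src lands at flat offset index * trailing_stride + i * in.size(dim) * trailing_stride + j in out. A minimal standalone sketch with made-up sizes (assumed example of in = {2, 3, 4}, dim = 1, index = 2; plain C++, not ExecuTorch code):

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example: in has sizes {2, 3, 4}, dim = 1, index = 2,
  // so src has sizes {2, 4} and out has the same sizes as in.
  const size_t leading_dims = 2;     // product of sizes before dim
  const size_t trailing_stride = 4;  // product of sizes after dim
  const size_t dim_size = 3;         // in.size(dim)
  const size_t index = 2;

  const size_t start_offset = index * trailing_stride;  // 8
  const size_t out_step = dim_size * trailing_stride;   // 12

  float out[2 * 3 * 4] = {0.0f};
  float src[2 * 4] = {1, 2, 3, 4, 5, 6, 7, 8};

  // Same double loop as the kernel: each src row of `trailing_stride`
  // elements overwrites the slice out[i][index][:] in the flat buffer.
  for (size_t i = 0; i < leading_dims; ++i) {
    for (size_t j = 0; j < trailing_stride; ++j) {
      out[start_offset + i * out_step + j] = src[i * trailing_stride + j];
    }
  }

  // Flat positions 8..11 and 20..23 of `out` now hold the two src rows.
  std::printf("%g %g\n", out[8], out[20]);  // prints "1 5"
  return 0;
}

With these numbers, out positions 8-11 and 20-23 receive the two rows of src, which is the same set of positions the removed memcpy-based loop wrote bytewise; the element-typed loop additionally allows src and out to have different dtypes via convert.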