@@ -39,8 +39,8 @@ void check_input_args(
39
39
// Input and output tensors should be the same shape
40
40
ET_CHECK_SAME_SHAPE2 (input, output);
41
41
42
- // All tensors should have the same dtype
43
- ET_CHECK_SAME_DTYPE3 (input, src , output);
42
+ // Input and output tensors should have the same shape
43
+ ET_CHECK_SAME_DTYPE2 (input, output);
44
44
45
45
// The input.dim() shall equal to src.dim()
46
46
ET_CHECK_MSG (
@@ -158,28 +158,36 @@ Tensor& slice_scatter_out(
158
158
check_input_args (input, src, dim, num_values, step, out);
159
159
160
160
size_t dim_length = input.size (dim);
161
-
162
161
size_t leading_dims = getLeadingDims (input, dim);
163
162
size_t trailing_dims = getTrailingDims (input, dim);
164
163
165
- size_t length_per_step = trailing_dims * input.element_size ();
166
-
167
- const char * in_data = input.const_data_ptr <char >();
168
- const char * src_data = src.const_data_ptr <char >();
169
-
170
- char * out_data = out.mutable_data_ptr <char >();
171
-
172
164
// To start, copy the input into the output
173
- memcpy (out_data, in_data, input.nbytes ());
174
-
175
- for (int i = 0 ; i < leading_dims; i++) {
176
- char * dst = out_data + (i * dim_length + start) * length_per_step;
177
- for (int j = 0 ; j < num_values; j++) {
178
- memcpy (dst, src_data, length_per_step);
179
- src_data += length_per_step;
180
- dst += step * length_per_step;
181
- }
182
- }
165
+ memcpy (out.mutable_data_ptr (), input.const_data_ptr (), input.nbytes ());
166
+
167
+ ScalarType in_type = input.scalar_type ();
168
+ ScalarType src_type = src.scalar_type ();
169
+
170
+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE, [&]() {
171
+ ET_SWITCH_REAL_TYPES_AND (Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
172
+ CTYPE* out_data = out.mutable_data_ptr <CTYPE>();
173
+ const CTYPE_SRC* src_data = src.const_data_ptr <CTYPE_SRC>();
174
+
175
+ size_t src_offset = 0 ;
176
+
177
+ for (int i = 0 ; i < leading_dims; i++) {
178
+ size_t out_offset = (i * dim_length + start) * trailing_dims;
179
+ for (int j = 0 ; j < num_values; j++) {
180
+ for (size_t k = 0 ; k < trailing_dims; ++k) {
181
+ out_data[out_offset + k] =
182
+ convert<CTYPE, CTYPE_SRC>(src_data[src_offset + k]);
183
+ }
184
+ src_offset += trailing_dims;
185
+ out_offset += step * trailing_dims;
186
+ }
187
+ }
188
+ });
189
+ });
190
+
183
191
return out;
184
192
}
185
193
0 commit comments