@@ -39,8 +39,8 @@ void check_input_args(
3939 // Input and output tensors should be the same shape
4040 ET_CHECK_SAME_SHAPE2 (input, output);
4141
42- // All tensors should have the same dtype
43- ET_CHECK_SAME_DTYPE3 (input, src , output);
42+ // Input and output tensors should have the same dtype
43+ ET_CHECK_SAME_DTYPE2 (input, output);
4444
4545 // The input.dim() shall be equal to src.dim()
4646 ET_CHECK_MSG (
@@ -158,28 +158,36 @@ Tensor& slice_scatter_out(
158158 check_input_args (input, src, dim, num_values, step, out);
159159
160160 size_t dim_length = input.size (dim);
161-
162161 size_t leading_dims = getLeadingDims (input, dim);
163162 size_t trailing_dims = getTrailingDims (input, dim);
164163
165- size_t length_per_step = trailing_dims * input.element_size ();
166-
167- const char * in_data = input.const_data_ptr <char >();
168- const char * src_data = src.const_data_ptr <char >();
169-
170- char * out_data = out.mutable_data_ptr <char >();
171-
172164 // To start, copy the input into the output
173- memcpy (out_data, in_data, input.nbytes ());
174-
175- for (int i = 0 ; i < leading_dims; i++) {
176- char * dst = out_data + (i * dim_length + start) * length_per_step;
177- for (int j = 0 ; j < num_values; j++) {
178- memcpy (dst, src_data, length_per_step);
179- src_data += length_per_step;
180- dst += step * length_per_step;
181- }
182- }
165+ memcpy (out.mutable_data_ptr (), input.const_data_ptr (), input.nbytes ());
166+
167+ ScalarType in_type = input.scalar_type ();
168+ ScalarType src_type = src.scalar_type ();
169+
170+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE, [&]() {
171+ ET_SWITCH_REAL_TYPES_AND (Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
172+ CTYPE* out_data = out.mutable_data_ptr <CTYPE>();
173+ const CTYPE_SRC* src_data = src.const_data_ptr <CTYPE_SRC>();
174+
175+ size_t src_offset = 0 ;
176+
177+ for (int i = 0 ; i < leading_dims; i++) {
178+ size_t out_offset = (i * dim_length + start) * trailing_dims;
179+ for (int j = 0 ; j < num_values; j++) {
180+ for (size_t k = 0 ; k < trailing_dims; ++k) {
181+ out_data[out_offset + k] =
182+ convert<CTYPE, CTYPE_SRC>(src_data[src_offset + k]);
183+ }
184+ src_offset += trailing_dims;
185+ out_offset += step * trailing_dims;
186+ }
187+ }
188+ });
189+ });
190+
183191 return out;
184192}
185193
0 commit comments