Skip to content

Commit 2f5def2

Browse files
fix tile compute error. test=develop (#6059)
1 parent bf16d66 commit 2f5def2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lite/kernels/host/tile_compute.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ void TileCompute<T, PType>::Run() {
6767
tmp_dst_tensor.Resize(out_dims);
6868
auto tmp_src = tmp_src_tensor.mutable_data<T>();
6969
auto tmp_dst = tmp_dst_tensor.mutable_data<T>();
70-
for (int i = 0; i < in->dims().production(); i++) {
70+
for (int i = 0; i < in_dims.production(); i++) {
7171
tmp_src[i] = in_data[i];
7272
}
73+
7374
for (int i = bcast_dims.size() - 1; i >= 0; i--) {
7475
if (bcast_dims[i] > 1) {
75-
for (int m = 0; m < in_stride[i]; m++) {
76+
int num = in_stride[1] / in_stride[i + 1];
77+
int dst_stride = in_stride[i + 1] * bcast_dims[i + 1];
78+
for (int m = 0; m < num; m++) {
7679
for (int j = 0; j < bcast_dims[i]; j++) {
77-
int dst_stride = in_stride[i + 1] * bcast_dims[i + 1];
7880
std::memcpy(tmp_dst + j * dst_stride + m * bcast_dims[i] * dst_stride,
7981
tmp_src + m * in_stride[i + 1],
8082
dst_stride * sizeof(T));

0 commit comments

Comments
 (0)