@@ -169,6 +169,7 @@ class IteratorCfg {
169169 const uint32_t pre_padding[]) { // number of virtual pixels added before each dimension
170170 for (uint32_t i = 0 ; i < iterRank; ++i) {
171171 int32_t dim = icfg.m_order [i];
172+ uint32_t tns_dim = tensor.get_dim (dim);
172173 m_order[i] = dim;
173174 m_count[i] = icfg.m_count [i];
174175 m_pos_inc[i] = icfg.m_pos_inc [i] * stride[dim];
@@ -181,9 +182,9 @@ class IteratorCfg {
181182 MLI_ASSERT (icfg.m_first_pos_inc [i] * stride[dim] >= pre_padding[dim]);
182183 m_first_pos_inc[i] = icfg.m_first_pos_inc [i] * stride[dim] - (int32_t )pre_padding[dim];
183184 m_last_pos_inc[i] = m_count[i] > 1 ? m_pos_inc[i] * (2 - m_count[i]) - m_first_pos_inc[i] : 0 ;
184- m_size[i] = ( icfg.m_size [i] - 1 ) * stride[dim] + effective_kernel_size[dim];
185- m_first_size[i] = ( icfg.m_first_size [i] - 1 ) * stride[dim] + effective_kernel_size[dim] - pre_padding[dim];
186- m_last_size[i] = tensor. get_dim (dim) + m_last_pos_inc[i];
185+ m_size[i] = MIN (( icfg.m_size [i] - 1 ) * stride[dim] + effective_kernel_size[dim], tns_dim) ;
186+ m_first_size[i] = MIN (( icfg.m_first_size [i] - 1 ) * stride[dim] + effective_kernel_size[dim] - pre_padding[dim], tns_dim) ;
187+ m_last_size[i] = tns_dim + m_last_pos_inc[i];
187188 m_diff_code[i] = CalcDiffCode (m_count[i], m_first_size[i], m_size[i], m_last_size[i]);
188189 }
189190 m_buf_tiles_num = icfg.get_buf_tiles_num ();
@@ -987,15 +988,18 @@ class TensorIterator {
987988 }
988989
989990
990-
991+ /**
992+ * @brief Returns the sub-tensor of the full tensor at the current iteration position.
993+ * @return Result of m_full_tensor.slice(pos, copysize) for the current tile.
994+ */
991995 Tensor<buf_T, tensorRank> GetSubTensor () {
992996 uint32_t pos[tensorRank];
993997 uint32_t copysize[tensorRank];
994- uint32_t r = 0 ;
995- for (r = 0 ; r < iterRank; r++) {
998+ // Default every dimension to offset 0 and the full tensor size, so that dimensions not covered by the iteration order (iterRank < tensorRank, or skipped below) are returned whole.
999+ for (uint32_t r = 0 ; r < tensorRank; ++r) pos[r] = 0 , copysize[r] = m_full_tensor.get_dim (r);
1000+ for (uint32_t r = 0 ; r < iterRank; ++r) {
9961001 int32_t dim = m_config.get_order (r);
9971002 if (dim == kSkipIterDim ) continue ;
998-
9991003 pos[dim] = (uint32_t )m_pos[r];
10001004 if (m_tile_idx[r] == m_config.get_count (r) - 1 ) { // Last tile along this iteration axis
10011005 copysize[dim] = m_config.get_last_size (r);
@@ -1004,7 +1008,6 @@ class TensorIterator {
10041008 } else { // Middle iteration
10051009 copysize[dim] = m_config.get_size (r);
10061010 }
1007- copysize[dim] = MIN (m_full_tensor.get_dim (dim) - pos[dim], copysize[dim]);
10081011 }
10091012 return m_full_tensor.slice (pos, copysize);
10101013 }
0 commit comments