Skip to content

Commit b26da3a

Browse files
dkozlovessayed
authored and committed
Extended GetSubTensor() method for the case tensorRank>iterRank
1 parent 7773abb commit b26da3a

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

include/mli_iterator.hpp

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,7 @@ class IteratorCfg {
169169
const uint32_t pre_padding[]) { // number of virtual pixels added before each dimension
170170
for (uint32_t i = 0; i < iterRank; ++i) {
171171
int32_t dim = icfg.m_order[i];
172+
uint32_t tns_dim = tensor.get_dim(dim);
172173
m_order[i] = dim;
173174
m_count[i] = icfg.m_count[i];
174175
m_pos_inc[i] = icfg.m_pos_inc[i] * stride[dim];
@@ -181,9 +182,9 @@ class IteratorCfg {
181182
MLI_ASSERT(icfg.m_first_pos_inc[i] * stride[dim] >= pre_padding[dim]);
182183
m_first_pos_inc[i] = icfg.m_first_pos_inc[i] * stride[dim] - (int32_t)pre_padding[dim];
183184
m_last_pos_inc[i] = m_count[i] > 1 ? m_pos_inc[i] * (2 - m_count[i]) - m_first_pos_inc[i] : 0;
184-
m_size[i] = (icfg.m_size[i] - 1) * stride[dim] + effective_kernel_size[dim];
185-
m_first_size[i] = (icfg.m_first_size[i] - 1) * stride[dim] + effective_kernel_size[dim] - pre_padding[dim];
186-
m_last_size[i] = tensor.get_dim(dim) + m_last_pos_inc[i];
185+
m_size[i] = MIN((icfg.m_size[i] - 1) * stride[dim] + effective_kernel_size[dim], tns_dim);
186+
m_first_size[i] = MIN((icfg.m_first_size[i] - 1) * stride[dim] + effective_kernel_size[dim] - pre_padding[dim], tns_dim);
187+
m_last_size[i] = tns_dim + m_last_pos_inc[i];
187188
m_diff_code[i] = CalcDiffCode(m_count[i], m_first_size[i], m_size[i], m_last_size[i]);
188189
}
189190
m_buf_tiles_num = icfg.get_buf_tiles_num();
@@ -987,15 +988,18 @@ class TensorIterator {
987988
}
988989

989990

990-
991+
/**
992+
* @brief Returns the subtensor of the full tensor in current iteration position
993+
*
994+
*/
991995
Tensor<buf_T, tensorRank> GetSubTensor() {
992996
uint32_t pos[tensorRank];
993997
uint32_t copysize[tensorRank];
994-
uint32_t r = 0;
995-
for (r = 0; r < iterRank; r++) {
998+
// If the iterRank is less than the tensorRank we need to return the full size of non-iterable dimension
999+
for (uint32_t r = 0; r < tensorRank; ++r) pos[r] = 0, copysize[r] = m_full_tensor.get_dim(r);
1000+
for (uint32_t r = 0; r < iterRank; ++r) {
9961001
int32_t dim = m_config.get_order(r);
9971002
if (dim == kSkipIterDim) continue;
998-
9991003
pos[dim] = (uint32_t)m_pos[r];
10001004
if (m_tile_idx[r] == m_config.get_count(r) - 1) { // Last iteration
10011005
copysize[dim] = m_config.get_last_size(r);
@@ -1004,7 +1008,6 @@ class TensorIterator {
10041008
} else { // Middle iteration
10051009
copysize[dim] = m_config.get_size(r);
10061010
}
1007-
copysize[dim] = MIN(m_full_tensor.get_dim(dim) - pos[dim], copysize[dim]);
10081011
}
10091012
return m_full_tensor.slice(pos, copysize);
10101013
}

0 commit comments

Comments
 (0)