Commit 8e81c14

change repeat to repeat_interleave (#792)
* change repeat to repeat_interleave
1 parent 4197e5e commit 8e81c14

5 files changed (+8, -7 lines)
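For context on the rename: `repeat_interleave` copies each element in place along an axis (NumPy's `np.repeat` semantics), which is distinct from end-to-end tiling. A minimal NumPy sketch of the two layouts, with NumPy standing in for MindSpore tensors purely for illustration:

```python
import numpy as np

x = np.array([1, 2, 3])

# repeat_interleave semantics: each element is copied in place.
print(np.repeat(x, 2))  # [1 1 2 2 3 3]

# Tiling semantics: the whole array is copied end to end.
print(np.tile(x, 2))    # [1 2 3 1 2 3]
```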

mindocr/models/backbones/mindcv_models/repmlp.py

Lines changed: 3 additions & 2 deletions
@@ -228,8 +228,9 @@ def local_inject(self):
         self.fc3.bias.data = fc3_bias
 
     def _convert_conv_to_fc(self, conv_kernel, conv_bias):
-        I = ops.eye(self.h * self.w).repeat(1, self.S).reshape(self.h * self.w, self.S, self.h, self.w)  # noqa: E741
-        fc_k = ops.Conv2D(I, conv_kernel, pad=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), group=self.S)
+        # noqa: E741
+        x = ops.eye(self.h * self.w).repeat_interleave(1, self.S).reshape(self.h * self.w, self.S, self.h, self.w)
+        fc_k = ops.Conv2D(x, conv_kernel, pad=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), group=self.S)
         fc_k = fc_k.reshape(self.h * self.w, self.S * self.h * self.w).t()
         fc_bias = conv_bias.repeat_interleave(self.h * self.w)
         return fc_k, fc_bias
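For background, `_convert_conv_to_fc` uses the standard impulse-response trick: pushing the rows of an identity matrix (one-hot images) through the convolution reads off an equivalent fully connected weight matrix, one column per input pixel. A hedged single-channel NumPy/SciPy sketch, with `scipy.signal.correlate2d` standing in for `ops.Conv2D` and the grouped-channel dimension `self.S` omitted; `h`, `w`, and the kernel are made-up illustration values:

```python
import numpy as np
from scipy.signal import correlate2d

h = w = 4
kernel = np.random.randn(3, 3)

# Each row of eye(h*w) is a one-hot "image" of shape (h, w).
eye = np.eye(h * w).reshape(h * w, h, w)

# Convolving each one-hot image yields the response at every output
# pixel to a unit impulse, i.e. one column of the FC matrix.
fc = np.stack([correlate2d(img, kernel, mode="same") for img in eye])
fc = fc.reshape(h * w, h * w).T  # (out_pixels, in_pixels)

# By linearity, conv(x) and the FC matrix agree on any input.
x = np.random.randn(h, w)
assert np.allclose(correlate2d(x, kernel, mode="same").ravel(), fc @ x.ravel())
```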

mindocr/models/heads/kie_relationextraction_head.py

Lines changed: 2 additions & 2 deletions
@@ -49,12 +49,12 @@ def __init__(self, hidden_size=768, hidden_dropout_prob=0.1, use_float16: bool =
     def construct(self, hidden_states, question, question_label, answer, answer_label):
         __, _, hidden_size = hidden_states.shape
         q_label_repr = self.entity_emb(question_label)
-        question = question.expand_dims(-1).repeat(hidden_size, -1)
+        question = question.expand_dims(-1).repeat_interleave(hidden_size, -1)
         tmp_hidden_states = ops.gather_d(hidden_states, 1, question)
         q_repr = ops.concat((tmp_hidden_states, q_label_repr), axis=-1)
 
         a_label_repr = self.entity_emb(answer_label)
-        answer = answer.expand_dims(-1).repeat(hidden_size, -1)
+        answer = answer.expand_dims(-1).repeat_interleave(hidden_size, -1)
         tmp_hidden_states = ops.gather_d(hidden_states, 1, answer)
         a_repr = ops.concat((tmp_hidden_states, a_label_repr), axis=-1)
 
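The changed lines broadcast a (B, K) index tensor across the hidden axis so that `ops.gather_d` can pull whole hidden vectors for the selected token positions. A rough NumPy sketch of the pattern, with `np.take_along_axis` standing in for `ops.gather_d` and `B`, `L`, `H`, `K` as made-up sizes:

```python
import numpy as np

B, L, H, K = 2, 5, 4, 3
hidden_states = np.random.randn(B, L, H)
question = np.random.randint(0, L, size=(B, K))  # token positions to pick

# (B, K) -> (B, K, 1) -> (B, K, H): the expand_dims + repeat_interleave
# step from the diff, so the index matches the gathered tensor's rank.
index = np.repeat(question[:, :, None], H, axis=-1)

picked = np.take_along_axis(hidden_states, index, axis=1)  # (B, K, H)
assert picked.shape == (B, K, H)
```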

mindocr/models/heads/rec_robustscanner_head.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def construct(self, query, key, value, h, w, valid_width_masks=None):
         logits_i = logits[i].squeeze(0)  # (c, h, w)
         logits_i = logits_i.view((-1, w))  # (c*h, w)
         ch = c * h
-        valid_width_mask = valid_width_mask.repeat(ch, axis=0)  # (c*h, w)
+        valid_width_mask = valid_width_mask.repeat_interleave(ch, 0)  # (c*h, w)
         valid_width_mask = ops.cast(valid_width_mask, ms.bool_)
         logits_i = ops.select(valid_width_mask, logits_i, float('-inf'))  # (c*h, w)
         logits_list.append(logits_i.view((c, h, w)))  # (c, h, w)
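The changed line expands a width-validity mask from a single row to (c*h, w) rows so it aligns with the flattened logits; invalid columns are then pushed to -inf. A rough NumPy sketch of the step, with `np.where` standing in for `ops.select` and all sizes made up:

```python
import numpy as np

c, h, w, valid_w = 2, 3, 5, 3
logits_i = np.random.randn(c * h, w)

# One row marking which width positions are real (not padding) ...
valid_width_mask = (np.arange(w) < valid_w)[None, :]           # (1, w)
# ... repeated to match every (channel, height) row of the logits.
valid_width_mask = np.repeat(valid_width_mask, c * h, axis=0)  # (c*h, w)

# Padded positions are forced to -inf before any later normalization.
logits_i = np.where(valid_width_mask, logits_i, float("-inf"))
```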

mindocr/models/heads/rec_sar_head.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def _2d_attention(self,
         attn_weight_i = attn_weight_i.transpose((0, 1, 3, 2))  # (T, h, c, w)
         attn_weight_i = attn_weight_i.view((-1, w))  # (T*h*c, w)
         Tch = T * h * c
-        valid_width_mask = valid_width_mask.repeat(Tch, axis=0)  # (T*h*c, w)
+        valid_width_mask = valid_width_mask.repeat_interleave(Tch, 0)  # (T*h*c, w)
         valid_width_mask = ops.cast(valid_width_mask, ms.bool_)
         attn_weight_i = ops.select(
             valid_width_mask, attn_weight_i.astype(ms.float32), float('-inf'))  # (T*h*c, w)
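Same masking pattern as the previous file. As a numeric aside on why -inf is the fill value: a downstream softmax over the width axis then assigns padded positions exactly zero attention weight, with no NaNs as long as at least one position stays finite. A small NumPy check:

```python
import numpy as np

row = np.array([1.0, 2.0, float("-inf"), float("-inf")])

# Numerically stable softmax: exp(-inf) underflows cleanly to 0.
weights = np.exp(row - row.max())
weights /= weights.sum()
print(weights)  # [0.269 0.731 0.    0.   ]
```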

mindocr/models/heads/rec_visionlan_head.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def __init__(self, n_dim: int = 512, N_max_character: int = 25, n_position: int
     def construct(self, enc_output):
         # enc_output: b,256,512
         reading_order = Tensor(np.arange(self.character_len), dtype=ms.int64)
-        reading_order = mnp.repeat(reading_order[None, ...], enc_output.shape[0], 0)  # (S,) -> (B, S)
+        reading_order = ops.repeat_interleave(reading_order[None, ...], enc_output.shape[0], 0)  # (S,) -> (B, S)
         reading_order = self.f0_embedding(reading_order)  # b,max_len,512
         # calculate attention
         t = self.w0(reading_order.transpose((0, 2, 1)))  # b,512,256
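This swap is a drop-in only because the repeated axis has size 1 ((1, S) -> (B, S)): for a size-1 axis, element-wise repetition and whole-array tiling produce identical results. A quick NumPy check of that equivalence, with `S` and `B` as made-up sizes:

```python
import numpy as np

S, B = 4, 3
reading_order = np.arange(S)[None, ...]  # (1, S)

a = np.repeat(reading_order, B, axis=0)  # repeat_interleave semantics
b = np.tile(reading_order, (B, 1))       # tiling semantics
assert (a == b).all() and a.shape == (B, S)
```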
