
Commit de47d22

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 74f038f commit de47d22

File tree

1 file changed: 1 addition, 9 deletions

colossalai/shardformer/modeling/qwen2.py

Lines changed: 1 addition & 9 deletions
@@ -9,7 +9,6 @@
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-from transformers.cache_utils import DynamicCache
 
 try:
     from transformers.modeling_attn_mask_utils import (
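The removed `from transformers.cache_utils import DynamicCache` line is an import that is no longer referenced anywhere in the module, which is exactly the kind of issue a pre-commit auto-fixer flags. As a rough, hypothetical illustration (not the actual hook used here, which is more thorough), an AST pass can report import bindings that never appear as names in the rest of the file:

```python
# Hypothetical sketch of an unused-import check; it parses a module and reports
# import bindings that are never referenced as names anywhere else in the file.
import ast


def unused_imports(source: str) -> list[str]:
    tree = ast.parse(source)
    imported: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import a.b` binds `a`; `from x import y as z` binds `z`.
                imported.add(alias.asname or alias.name.split(".")[0])
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return sorted(imported - used)


sample = "from transformers.cache_utils import DynamicCache\nvalue = 1\n"
print(unused_imports(sample))  # ['DynamicCache']
```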
@@ -210,8 +209,7 @@ def qwen2_model_forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
+            past_key_values[idx] if past_key_values is not None else None
 
             if idx - start_idx < num_ckpt_layers:
                 layer_outputs = self._gradient_checkpointing_func(
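The single added line above is the only non-whitespace edit in the commit: the auto-fixer (commonly autoflake with --remove-unused-variables; which hook produced it here is an assumption) drops the unused `past_key_value` binding but conservatively keeps the right-hand side as a bare expression. Since plain indexing has no side effects, the surviving line is effectively a no-op, as this self-contained sketch shows:

```python
# Hypothetical stand-ins so the sketch runs on its own; in the real code the
# list is the per-layer KV cache iterated inside qwen2_model_forward.
past_key_values = [("k0", "v0"), ("k1", "v1")]
idx = 1

# Before the auto-fix: assigned but never read again, so the binding is unused.
past_key_value = past_key_values[idx] if past_key_values is not None else None

# After the auto-fix: the binding is gone, but the expression is kept in case
# evaluating it had side effects.  Indexing a list does not, so this statement
# does nothing and could be deleted outright in a follow-up.
past_key_values[idx] if past_key_values is not None else None
```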
@@ -523,7 +521,6 @@ def forward(
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -649,7 +646,6 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -669,11 +665,9 @@ def forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-
         # embed positions
         hidden_states = inputs_embeds
 
-
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@@ -693,7 +687,6 @@ def forward(
                 sliding_window=self.config.sliding_window,
             )
 
-
         if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
                 logger.warning_once(
@@ -746,7 +739,6 @@ def forward(
 
             hidden_states = layer_outputs[0]
 
-
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
 