@@ -386,10 +386,8 @@ def prepare_inputs(
386 386         inputs = {}
387 387         if not self.stateful:
388 388             if past_key_values is not None:
389     -                if (
390     -                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
391     -                    or self.config.model_type == "falcon"
392     -                    and self.config.new_decoder_architecture
    389 +                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
    390 +                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
393 391                 ):
394 392                     if self._pkv_precision == Type.bf16:
395 393                         # numpy does not support bf16, pretending f16, should change to bf16
@@ -499,10 +497,8 @@ def forward(
499 497         if self.use_cache:
500 498             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
501 499             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
502     -            if (
503     -                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
504     -                or self.config.model_type == "falcon"
505     -                and self.config.new_decoder_architecture
    500 +            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
    501 +                self.config.model_type == "falcon" and self.config.new_decoder_architecture
506 502             ):
507 503                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
508 504                 past_key_values = tuple(
@@ -559,10 +555,8 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
559 555         if indicies.shape[0] != 1:
560 556             logits = logits[indicies]
561 557             if past_key_values and not self.stateful:
562     -                if (
563     -                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
564     -                    or self.config.model_type == "falcon"
565     -                    and self.config.new_decoder_architecture
    558 +                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
    559 +                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
566 560                 ):
567 561                     past_key_values = tuple(
568 562                         tuple(
@@ -581,7 +575,7 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
581 575                     if self.next_beam_idx is not None
582 576                     else np.arange(batch_size, dtype=int)[indicies]
583 577                 )
584     -                self._second_iter_beam_search = True
    578 +            self._second_iter_beam_search = True
585 579         return logits, past_key_values
586 580
587 581     def _deduplicate_inputs(self, model_inputs: Dict):
@@ -692,7 +686,7 @@ def _reorder_cache(
692 686             self._second_iter_beam_search = False
693 687             return past_key_values
694 688         else:
695     -            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
    689 +            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
696 690                 self.config.model_type == "falcon" and self.config.new_decoder_architecture
697 691             ):
698 692                 return tuple(
0 commit comments