@@ -156,7 +156,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -175,7 +175,7 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
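A minimal sketch (not part of the diff) of the `Cache.get_seq_length()` call that replaces `get_usable_length()` above, assuming the standard `transformers` `DynamicCache`; shapes are illustrative.

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = v = torch.zeros(1, 2, 5, 8)      # [batch, num_heads, seq_len, head_dim]
cache.update(k, v, layer_idx=0)      # cache 5 key/value positions for layer 0
# get_seq_length(layer_idx) reports how many positions are already cached,
# which is what kv_seq_len is extended by in the new code path.
assert cache.get_seq_length(0) == 5
```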
@@ -238,15 +238,13 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
Optional[torch.FloatTensor],
Optional[torch.FloatTensor],
Optional[torch.Tensor],
Optional[Cache],
]:
residual = hidden_states

@@ -259,7 +257,6 @@ def forward(
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states

@@ -272,8 +269,6 @@ def forward(
if not output_attentions:
self_attn_weights = None

if not use_cache:
present_key_value = None
return hidden_states, self_attn_weights, present_key_value


@@ -317,9 +312,10 @@ def forward(
input_ids: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[
Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -335,7 +331,6 @@ def forward(
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
@@ -359,17 +354,25 @@ def forward(
inputs_embeds = self.embed_layer(input_ids)
seq_length = inputs_embeds.shape[1]

if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False

past_key_values_length = 0
use_legacy_cache = False

if use_cache:
if past_key_values is not None:
use_legacy_cache = not isinstance(past_key_values, Cache)
# Convert the legacy cache (a tuple of tuples) into an equivalent Cache object, kept for backward compatibility.
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
# Assume every layer caches the same sequence length
past_key_values_length = past_key_values.get_seq_length()

# With gradient checkpointing during training, caching is normally disabled, so simply drop the cache here.
if (
self.gradient_checkpointing
and self.training
and isinstance(past_key_values, Cache)
):
past_key_values = None
past_key_values_length = 0

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
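A short sketch (not part of the diff) of the legacy-tuple round trip used above, assuming the standard `transformers` `DynamicCache`; tensor shapes are made up for illustration.

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: a tuple over layers of (key, value), each [batch, heads, seq, head_dim].
legacy = tuple(
    (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)) for _ in range(3)
)
cache = DynamicCache.from_legacy_cache(legacy)
# get_seq_length() without a layer index assumes every layer caches the same length,
# matching the past_key_values_length computed in the forward pass above.
print(cache.get_seq_length())        # 4
print(len(cache.to_legacy_cache()))  # 3 layers, converted back for callers that expect tuples
```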
@@ -412,7 +415,6 @@ def forward(
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -421,15 +423,14 @@ def forward(
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

if use_cache:
if isinstance(past_key_values, Cache):
next_decoder_cache = layer_outputs[2]

hidden_states = self.norm(hidden_states)
@@ -438,7 +439,7 @@ def forward(
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
if isinstance(past_key_values, Cache):
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
@@ -484,12 +485,13 @@ def forward(
input_ids: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[
Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
loss_masks: Optional[torch.FloatTensor] = None,
mask_y: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -525,7 +527,6 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -604,16 +605,9 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
if isinstance(past_key_values, DynamicCache):
past_length = past_key_values.seen_tokens
else:
past_length = cache_length

max_cache_length = past_key_values.get_max_length()
past_length = past_key_values.get_seq_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
past_length = past_key_values[0][0].shape[2]

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -622,26 +616,13 @@ def forward(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
input_ids = input_ids[
:,
-(attention_mask.shape[1] - past_length)
* self.config.input_token_len :,
]
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
input_ids = input_ids[:, past_length * self.config.input_token_len :]
# 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + (input_ids.shape[1] // self.config.input_token_len)
> max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -662,7 +643,6 @@ def prepare_inputs_for_generation(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"revin": revin,
"num_samples": num_samples,
@@ -83,7 +83,7 @@ def generate(
outputs = (outputs * stdev) + means
return outputs

def _greedy_search(
def _sample(
self,
input_ids: torch.Tensor,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -269,7 +269,7 @@ def _greedy_search(
horizon_length = next_tokens.shape[-1] // self.config.input_token_len

past_key_values = model_kwargs.get("past_key_values")
if past_key_values is None:
if past_key_values is None or generate_results is None:
generate_results = next_tokens
else:
generate_results = torch.cat([generate_results, next_tokens], dim=-1)
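A small sketch (not from this PR) of the accumulation guard above: while `generate_results` is still unset, the fresh tokens start the result; afterwards each step's tokens are concatenated.

```python
import torch

generate_results = None
for step_tokens in (torch.tensor([[1, 2]]), torch.tensor([[3, 4]])):
    if generate_results is None:
        generate_results = step_tokens                               # first decoding step
    else:
        generate_results = torch.cat([generate_results, step_tokens], dim=-1)
print(generate_results)                                              # tensor([[1, 2, 3, 4]])
```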
@@ -328,9 +328,13 @@ def _update_model_kwargs_for_generation(
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if "past_key_values" in outputs:
model_kwargs["past_key_values"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past_key_values"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past_key_values"] = outputs.past_buckets_states

if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
