-
Notifications
You must be signed in to change notification settings - Fork 151
Fix Phi long context issue #1504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
helena-intel
wants to merge
10
commits into
huggingface:main
Choose a base branch
from
helena-intel:ea/lonrope_exp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 8 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
4440904
test longrope phi4
eaidova 72cd3c8
update prepare_inputs_for_generation
eaidova 8aa5978
change condition
eaidova 19feb0b
Merge branch 'main' into ea/lonrope_exp
IlyasMoutawwakil 822664a
Merge branch 'main' into ea/lonrope_exp
helena-intel 4426e18
Merge remote-tracking branch 'origin/main' into ea/lonrope_exp
helena-intel 9f0394a
Merge remote-tracking branch 'origin/main' into ea/lonrope_exp
helena-intel c8adca6
Skip longrope for phi_moe for now
helena-intel 75af74c
Remove commented out code
helena-intel 8185565
Merge remote-tracking branch 'upstream/main' into ea/lonrope_exp
helena-intel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1493,15 +1493,54 @@ def _phi3_self_attn_sdpa_forward( | |
| return attn_output, None, past_key_value | ||
|
|
||
|
|
||
| # @torch.jit.script | ||
|
||
| def select_ext_factor( | ||
| seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor | ||
| ): | ||
| return torch.where( | ||
| seq_len <= max_pos_embeddings, short_factor, long_factor | ||
| ) # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings) | ||
|
|
||
|
|
||
| def long_rope(self, x, position_ids, seq_len=None): | ||
| seq_len = torch.max(position_ids) + 1 | ||
| original_max_position_embeddings = ( | ||
| self.original_max_position_embeddings | ||
| if hasattr(self, "original_max_positional_embeddings") | ||
| else self.config.original_max_position_embeddings | ||
| ) | ||
| max_position_embeddings = ( | ||
| self.max_position_embeddings | ||
| if hasattr(self, "max_position_embeddings") | ||
| else self.config.max_position_embeddings | ||
| ) | ||
| inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq) | ||
|
|
||
| inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) | ||
| position_ids_expanded = position_ids[:, None, :].float() | ||
|
|
||
| # Force float32 since bfloat16 loses precision on long contexts | ||
| # See https://github.com/huggingface/transformers/pull/29285 | ||
| device_type = x.device.type | ||
| device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" | ||
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||
| emb = torch.cat((freqs, freqs), dim=-1) | ||
|
|
||
| scale = max_position_embeddings / original_max_position_embeddings | ||
| if scale <= 1.0: | ||
| scaling_factor = 1.0 | ||
| else: | ||
| scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings)) | ||
| cos = emb.cos() * scaling_factor | ||
| sin = emb.sin() * scaling_factor | ||
| return cos, sin | ||
|
|
||
|
|
||
| class Phi3ModelPatcher(OVDecoderModelPatcher): | ||
| def __enter__(self): | ||
| super().__enter__() | ||
|
|
||
| # currently, long RoPE can not be traced for long context support, disable it for avoid potential accuracy issues | ||
| if self._model.config.max_position_embeddings != getattr( | ||
| self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings | ||
| ): | ||
| self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings | ||
|
|
||
| if is_transformers_version("<", "4.48.0"): | ||
| self._model.model._orig_forward = self._model.model.forward | ||
|
|
@@ -1529,6 +1568,23 @@ def __enter__(self): | |
| rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) | ||
| ) | ||
|
|
||
| if ( | ||
| hasattr(self._model.model, "rotary_emb") | ||
| and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope" | ||
| ): | ||
| long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn( | ||
| self._model.config, | ||
| torch.device("cpu"), | ||
| seq_len=self._model.config.original_max_position_embeddings + 1, | ||
| ) | ||
| self._model.model.rotary_emb.long_inv_freq = long_inv_freq | ||
| self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward | ||
| self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb) | ||
| elif self._model.config.max_position_embeddings != getattr( | ||
| self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings | ||
| ): | ||
| self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| super().__exit__(exc_type, exc_value, traceback) | ||
| if hasattr(self._model.model, "_orig_forward"): | ||
|
|
@@ -1538,6 +1594,8 @@ def __exit__(self, exc_type, exc_value, traceback): | |
| for layer in self._model.model.layers: | ||
| if hasattr(layer.self_attn, "_orig_forward"): | ||
| layer.self_attn.forward = layer.self_attn._orig_forward | ||
| if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"): | ||
| self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward | ||
|
|
||
|
|
||
| # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756 | ||
|
|
@@ -1590,13 +1648,47 @@ def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torc | |
|
|
||
| class PhiMoEModelPatcher(Phi3ModelPatcher): | ||
| def __enter__(self): | ||
| super().__enter__() | ||
| # Call OVDecoderModelPatcher.__enter__() directly to skip Phi3ModelPatcher's longrope logic | ||
| # PhiMoE has a different rotary embedding structure, longrope is not yet supported | ||
| OVDecoderModelPatcher.__enter__(self) | ||
|
|
||
| if is_transformers_version("<", "4.48.0"): | ||
| self._model.model._orig_forward = self._model.model.forward | ||
| self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model) | ||
|
|
||
| # init inv_freq for torchscript tracing for PhiMoE | ||
| for layer in self._model.model.layers: | ||
| if ( | ||
| is_torch_version(">=", "2.1.0") | ||
| and is_transformers_version("<", "4.48.0") | ||
| or not getattr(self._model, "_supports_sdpa", False) | ||
| ): | ||
| orig_self_attn_fwd = layer.self_attn.forward | ||
| layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn) | ||
| layer.self_attn._orig_forward = orig_self_attn_fwd | ||
|
|
||
| if ( | ||
| hasattr(layer.self_attn, "rotary_emb") | ||
| and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None | ||
| ): | ||
| rotary_emb = layer.self_attn.rotary_emb | ||
| layer.self_attn.rotary_emb.inv_freq = 1.0 / ( | ||
| rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) | ||
| ) | ||
|
|
||
| # Apply MoE-specific patching | ||
| for layer in self._model.model.layers: | ||
| layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward | ||
| layer.block_sparse_moe.forward = types.MethodType( | ||
| _phi_moe_sparse_moe_block_forward, layer.block_sparse_moe | ||
| ) | ||
|
|
||
| # For PhiMoE, reset max_position_embeddings if it was extended (skip longrope support for now) | ||
| if self._model.config.max_position_embeddings != getattr( | ||
| self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings | ||
| ): | ||
| self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| super().__exit__(exc_type, exc_value, traceback) | ||
| for layer in self._model.model.layers: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -358,10 +358,10 @@ def _export( | |
| variant=variant, | ||
| ) | ||
|
|
||
| if config.model_type == "phi3" and config.max_position_embeddings != getattr( | ||
| config, "original_max_position_embeddings", config.max_position_embeddings | ||
| ): | ||
| config.max_position_embeddings = config.original_max_position_embeddings | ||
| # if config.model_type == "phi3" and config.max_position_embeddings != getattr( | ||
| # config, "original_max_position_embeddings", config.max_position_embeddings | ||
| # ): | ||
| # config.max_position_embeddings = config.original_max_position_embeddings | ||
|
||
|
|
||
| return cls._from_pretrained( | ||
| model_id=save_dir_path, | ||
|
|
@@ -870,6 +870,8 @@ def _from_pretrained( | |
| init_cls = OVBloomForCausalLM | ||
| elif model_type == "gpt_bigcode": | ||
| init_cls = OVGPTBigCodeForCausalLM | ||
| elif model_type == "phi3": | ||
| init_cls = OVPhi3ForCausalLM | ||
| elif model_type in SSM_MODELS: | ||
| init_cls = OVModelWithMambaForCausalLM | ||
| else: | ||
|
|
@@ -950,6 +952,47 @@ def _from_pretrained( | |
| return causal_model | ||
|
|
||
|
|
||
| class OVPhi3ForCausalLM(OVModelForCausalLM): | ||
| def prepare_inputs_for_generation( | ||
| self, | ||
| input_ids, | ||
| past_key_values=None, | ||
| attention_mask=None, | ||
| inputs_embeds=None, | ||
| cache_position=None, | ||
| position_ids=None, | ||
| use_cache=True, | ||
| logits_to_keep=None, | ||
| **kwargs, | ||
| ): | ||
| # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the | ||
| # process | ||
|
|
||
| # When the first time input length reached long and short factor switching point, enforce re-compute cache | ||
| # It will cause downside of slower at this single token position, however, better than current failure. | ||
| if ( | ||
| past_key_values | ||
| and self.config.rope_scaling | ||
| and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 | ||
| ): | ||
| past_length = cache_position[0] | ||
| if past_length <= self.config.original_max_position_embeddings: | ||
| past_key_values = None | ||
|
|
||
| model_inputs = super().prepare_inputs_for_generation( | ||
| input_ids=input_ids, | ||
| past_key_values=past_key_values, | ||
| attention_mask=attention_mask, | ||
| inputs_embeds=inputs_embeds, | ||
| cache_position=cache_position, | ||
| position_ids=position_ids, | ||
| use_cache=use_cache, | ||
| logits_to_keep=logits_to_keep, | ||
| **kwargs, | ||
| ) | ||
| return model_inputs | ||
|
|
||
|
|
||
| class OVBloomForCausalLM(OVModelForCausalLM): | ||
| # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it makes sense to add the test with long prompt. Is this issue reproduced on tiny-model?