```diff
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -115,23 +114,17 @@ def forward(
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = self.config.build_rope(
-        input_pos=input_pos,
-        n_elem=n_elem,
-        base=attn_config.rotary_base,
-        head_dim=attn_config.head_dim,
-        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
-    )
+    rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)

     if mask is None:
       mask = self.mask_cache.index_select(2, input_pos)
       mask = mask[:, :, :, : self.config.kv_cache_max]

-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
     )

-  def forward_with_embeds(
+  def _forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
```
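In short, this hunk drops the direct `rotary_position_embedding` import, calls the config-supplied `config.build_rope` positionally with just `(input_pos, n_elem, base)` (no separate `head_dim` argument), and renames `forward_with_embeds` to `_forward_with_embeds` so external callers go through `forward`. For orientation only, a minimal RoPE builder matching that three-argument shape might look like the sketch below; the name `example_build_rope` and the cos/sin layout are assumptions for illustration, not the library's actual `build_rope` implementation.

```python
from typing import Tuple

import torch


def example_build_rope(
    input_pos: torch.Tensor,   # (seq_len,) token positions
    n_elem: int,               # rotated channels, i.e. rotary_percentage * head_dim
    base: float = 10_000.0,    # rotary base frequency
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Hypothetical stand-in for the config-supplied build_rope callback."""
  # Inverse frequencies over the even channels, as in standard RoPE.
  inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # Angle per (position, frequency) pair: shape (seq_len, n_elem // 2).
  angles = torch.outer(input_pos.float(), inv_freq)
  return torch.cos(angles), torch.sin(angles)


# Usage with dummy values; in the model, n_elem comes from the attention config.
cos, sin = example_build_rope(torch.arange(0, 8), n_elem=64, base=10_000.0)
```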