```diff
@@ -19,7 +19,6 @@
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -59,25 +58,32 @@ def forward(
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
-      return super().forward(tokens, input_pos, kv_cache)
+      return super().forward(
+          tokens, input_pos, kv_cache, mask, export_config=export_config
+      )
 
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+    rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
-    embeds_len = input_embeds.shape[1]
     if mask is None:
+      embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
       mask[:, embeds_len:] = float("-inf")
 
     return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
     )
 
```
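The second hunk swaps the module-level `rotary_pos_emb.build_rope` helper for a `build_rope` callable looked up on the model config. Below is a minimal sketch of that indirection, assuming `build_rope` is simply a function-valued field on the config; the default builder and the config layout here are illustrative stand-ins, not the library's actual definitions.

```python
import dataclasses
from typing import Callable, Tuple

import torch


def default_build_rope(
    pos: torch.Tensor, n_elem: int, base: int
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Stand-in for a generic RoPE builder returning (cos, sin) tables."""
  # Inverse frequencies for each rotated pair of dimensions.
  freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # Outer product of positions and frequencies: (seq_len, n_elem // 2).
  angles = pos.float().unsqueeze(-1) * freqs
  return torch.cos(angles), torch.sin(angles)


@dataclasses.dataclass
class ModelConfigSketch:
  # Hypothetical config: models with nonstandard RoPE (scaled frequencies,
  # per-model bases, etc.) can swap in their own builder without the decoder
  # code above having to change.
  build_rope: Callable = default_build_rope


config = ModelConfigSketch()
cos, sin = config.build_rope(torch.arange(1, 5), 64, 10_000)
print(cos.shape)  # torch.Size([4, 32])
```

Routing the call through the config keeps the decoder's `forward` generic: it asks its own config for RoPE tables instead of hard-coding one rotary implementation.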
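The mask branch replaces the usual diagonal causal mask because the leading `embeds_len` positions of `input_embeds` are image embeddings, which the prefix attends to bidirectionally. Here is a standalone sketch of the same construction, with toy sizes standing in for `input_embeds.shape[1]` and `self.config.kv_cache_max`:

```python
import torch

embeds_len = 4    # stand-in for input_embeds.shape[1] (image + text prefix)
kv_cache_max = 8  # stand-in for self.config.kv_cache_max

# Zeros let every prefix query attend to every prefix KV slot
# (bidirectional); -inf blocks the cache slots beyond the prefix.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")

print(mask)
# tensor([[0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf]])
```

Each row is one query position in the prefix; since no column within the first `embeds_len` slots is masked, the image tokens see each other fully, unlike in a causal lower-triangular mask.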