@@ -3,8 +3,10 @@
 import sys
 from datetime import datetime
 
+import megatron.core
 import torch
 import torch.nn.functional as F
+from packaging import version
 
 from swift.llm import git_clone_github
 from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run
@@ -334,8 +336,13 @@ def forward(
         # Adjust key, value for inference
         # ===================================================
         # rotary_pos_emb = None
-        query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_context, query, key, value, rotary_pos_emb=None)
+        megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+        if megatron_core_013:
+            query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
+                inference_context, query, key, value, rotary_pos_emb=None)
+        else:
+            query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
+                inference_context, query, key, value, rotary_pos_emb=None)
 
         # TODO: Currently, TE can only accept contiguous tensors for MLA
         query = query.contiguous()
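The gate compares against '0.13.0rc0' rather than '0.13.0' so that 0.13 release candidates also take the new six-tuple unpacking path; in PEP 440 ordering, pre-releases sort below the final release. Below is a minimal, standalone sketch of the same version-gating pattern, assuming only that the packaging library is installed; returns_six_tuple is a hypothetical helper name, and the return-shape behavior it describes is taken from the diff above.

# Sketch of the version gate used in the patch, assuming `packaging` is
# installed. `returns_six_tuple` is a hypothetical helper for illustration.
from packaging import version

def returns_six_tuple(megatron_core_version: str) -> bool:
    # Per the diff: megatron.core >= 0.13 returns one extra element from
    # _adjust_key_value_for_inference, so callers must unpack six values.
    # Comparing against the rc0 pre-release lets 0.13 release candidates
    # take the new path too, since 0.13.0rc0 < 0.13.0 in PEP 440 ordering.
    return version.parse(megatron_core_version) >= version.parse('0.13.0rc0')

assert returns_six_tuple('0.13.0rc1')   # 0.13 pre-releases use the new API
assert returns_six_tuple('0.13.1')      # later 0.13.x releases as well
assert not returns_six_tuple('0.12.1')  # older cores still return five values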