 
 from swift.llm import get_llm_model
 from .base import CommonSequenceParallel
-from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_and_entropies_grpo,
-                    _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
-                    get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
-                    split_by_mini_batches_grpo)
+from .utils import (GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
+                    _get_per_token_logps_and_entropies_grpo, _get_train_sampler_grpo, _prepare_inputs,
+                    _prepare_inputs_grpo, get_common_dataloader, get_per_token_logps, loss_scale_sp_func,
+                    old_policy_grpo, setup_compute_acc, split_by_mini_batches_grpo)
 
 assert version.parse(torch.__version__) >= version.parse('2.0.0')
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -233,6 +233,30 @@ def pre_forward_split_hook(_self, args, kwargs):
         if hasattr(base_model, '_update_causal_mask'):
             self.causal_mask_func = base_model._update_causal_mask
         base_model.register_forward_pre_hook(pre_forward_split_hook, with_kwargs=True)
+        base_model: torch.nn.Module
+
+        def moe_aux_loss_hook(module, args, kwargs, output):
+            router_logits = getattr(output, 'router_logits', None)
+            if router_logits is None:
+                return output
+
+            attention_mask = kwargs['attention_mask']
+            num_layers = len(router_logits)
+            sp_len = router_logits[0].shape[0]
+            if isinstance(router_logits, tuple):
+                compute_device = router_logits[0].device
+                router_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in router_logits], dim=0)
+            router_logits, _ = GatherLoss.apply(router_logits, None, self.sp_group)
+            router_logits = router_logits.reshape(self.sp_world_size, num_layers, sp_len,
+                                                  -1).transpose(0, 1).reshape(num_layers, self.sp_world_size * sp_len,
+                                                                              -1)
+            if attention_mask is not None:
+                router_logits = router_logits[:, :attention_mask.shape[1], :]
+            output['router_logits'] = tuple([logit.squeeze() for logit in router_logits.split(1, dim=0)])
+            return output
+
+        if model.model_info.is_moe_model:
+            base_model.register_forward_hook(moe_aux_loss_hook, with_kwargs=True)
         self.model_dtype = next(model.parameters()).dtype
         self.tokenizer = tokenizer
 
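The new `moe_aux_loss_hook` gathers each rank's router logits across the sequence-parallel group and stitches them back into full-sequence, per-layer tensors, so the MoE auxiliary (load-balancing) loss sees the whole sequence rather than one rank's shard. The sketch below walks through only the shape bookkeeping under dummy sizes; the distributed `GatherLoss.apply` is stood in for by a plain `torch.cat` over simulated ranks, and names such as `per_rank` and `full` are illustrative, not part of the codebase.

```python
import torch

# Dummy sequence-parallel sizes (assumptions for this sketch only).
sp_world_size, num_layers, sp_len, num_experts = 2, 3, 4, 8

# Per-rank view: each rank's base model returns a tuple of per-layer
# router logits of shape (sp_len, num_experts) for its sequence shard.
per_rank = [[torch.randn(sp_len, num_experts) for _ in range(num_layers)]
            for _ in range(sp_world_size)]

# Step 1 (per rank): concatenate the layer tuple -> (num_layers * sp_len, num_experts).
flat_per_rank = [torch.cat(layers, dim=0) for layers in per_rank]

# Step 2: stand-in for GatherLoss.apply over the sp group; ranks are stacked
# along dim 0 -> (sp_world_size * num_layers * sp_len, num_experts).
gathered = torch.cat(flat_per_rank, dim=0)

# Step 3: the reshape/transpose from the hook, recovering per-layer tensors
# that cover the full (sp_world_size * sp_len) sequence.
full = (gathered.reshape(sp_world_size, num_layers, sp_len, -1)
        .transpose(0, 1)
        .reshape(num_layers, sp_world_size * sp_len, -1))

# Layer 0 of the result equals the rank shards concatenated along the sequence dim.
expected_layer0 = torch.cat([per_rank[r][0] for r in range(sp_world_size)], dim=0)
assert torch.equal(full[0], expected_layer0)
```

The hook can read `kwargs['attention_mask']` because it is registered with `with_kwargs=True`, which PyTorch forward hooks support as of 2.0 (consistent with the `torch.__version__` assert above); the gathered logits are trimmed to the unpadded sequence length before being split back into a per-layer tuple.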