|
28 | 28 | from transformers import PreTrainedModel, TrainerCallback |
29 | 29 | from transformers.trainer import Trainer |
30 | 30 | from trl import GRPOTrainer as HFGRPOTrainer |
31 | | -from trl.extras.profiling import profiling_decorator |
| 31 | +from trl.extras.profiling import profiling_context, profiling_decorator |
32 | 32 | from trl.models import prepare_deepspeed |
33 | 33 | from trl.trainer.callbacks import SyncRefModelCallback |
34 | | -from trl.trainer.grpo_trainer import nanmax, nanmin |
| 34 | +from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd |
35 | 35 |
|
36 | 36 | from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device |
37 | 37 | from swift.llm.model.utils import get_llm_model |
@@ -873,19 +873,30 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te |
873 | 873 | completions = [example['messages'][-1]['content'] for example in inputs] |
874 | 874 | rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) |
875 | 875 |
|
876 | | - for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)): |
877 | | - # reward model |
878 | | - if isinstance(reward_func, nn.Module): |
879 | | - rewards_per_func[:, i] = reward_model_plugin(inputs=inputs) |
880 | | - # reward function |
881 | | - else: |
882 | | - # Repeat all input columns (but "messages" and "completion") to match the number of generations |
883 | | - reward_kwargs = RowPreprocessor.rows_to_batched(inputs) |
884 | | - output_reward_func = reward_func(completions, **reward_kwargs) |
| 876 | + for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( |
| 877 | + zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): |
| 878 | + with profiling_context(self, reward_func_name): |
| 879 | + # reward model |
| 880 | + if isinstance(reward_func, nn.Module): |
| 881 | + output_reward_func = reward_model_plugin(inputs=inputs) |
| 882 | + # reward function |
| 883 | + else: |
| 884 | + # Repeat all input columns (except "messages" and "completion") to match the number of generations
| 885 | + reward_kwargs = RowPreprocessor.rows_to_batched(inputs) |
| 886 | + output_reward_func = reward_func(completions, **reward_kwargs) |
| 887 | + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] |
885 | 888 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) |
886 | 889 |
|
| 890 | + # If all reward functions return None for a given row, issue a detailed warning |
| 891 | + if torch.isnan(rewards_per_func).all(dim=1).any(): |
| 892 | + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] |
| 893 | + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} |
| 894 | + row_reward_kwargs['completion'] = completions[nan_row_idx] |
| 895 | + logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. ' |
| 896 | + 'Please ensure that at least one reward function returns a valid reward.') |
| 897 | + |
887 | 898 | total_rewards_per_func = gather(rewards_per_func) |
888 | | - total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) |
| 899 | + total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) |
889 | 900 |
|
890 | 901 | return total_rewards_per_func, total_rewards, completions |
891 | 902 |
|
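Note on the `.sum` → `.nansum` change above: reward functions that return `None` are now recorded as `torch.nan`, so a plain weighted `.sum` would poison every total that has even one missing reward, while `.nansum` treats the missing entries as zero and scores each sample only by the functions that actually applied. The all-NaN row check then catches samples no function scored at all. A minimal, self-contained sketch of this behavior with toy numbers (not the trainer's actual tensors):

```python
import torch

# Two samples, three reward functions. NaN marks "this reward function
# returned None for this sample".
rewards_per_func = torch.tensor([
    [0.5, float('nan'), 1.0],
    [float('nan'), float('nan'), float('nan')],  # no function produced a reward
])
reward_weights = torch.tensor([1.0, 0.5, 2.0])

# .sum() would propagate NaN into every total; .nansum() treats NaN as 0.
total_rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
print(total_rewards)  # tensor([2.5000, 0.0000])

# Rows where *every* function returned None trigger the warning path,
# mirroring the torch.isnan(...).all(dim=1) check in the hunk above.
all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)
print(all_nan_rows)  # tensor([False,  True])
```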
@@ -1027,10 +1038,11 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) |
1027 | 1038 |
|
1028 | 1039 | self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) |
1029 | 1040 |
|
| 1041 | + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) |
1030 | 1042 | for i, reward_func_name in enumerate(self.reward_func_names): |
1031 | | - mean_rewards = rewards_per_func[:, i].mean().item() |
| 1043 | + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() |
1032 | 1044 | self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards) |
1033 | | - std_rewards = rewards_per_func[:, i].std().item() |
| 1045 | + std_rewards = nanstd(rewards_per_func[:, i]).item() |
1034 | 1046 | self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) |
1035 | 1047 |
|
1036 | 1048 | # Log overall reward stats |
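PyTorch ships `torch.nanmean` but no built-in NaN-aware standard deviation, which is presumably why `nanstd` is imported from `trl.trainer.grpo_trainer` in the import hunk above. As a rough illustration only (TRL's own implementation may differ), a NaN-ignoring std over the valid entries could look like this:

```python
import torch

def nanstd_sketch(x: torch.Tensor) -> torch.Tensor:
    """NaN-ignoring standard deviation over all elements.

    Illustrative sketch, not TRL's `nanstd`: it just shows the idea of
    computing an unbiased std over the non-NaN entries.
    """
    mask = ~torch.isnan(x)
    n = mask.sum()
    mean = torch.nanmean(x)
    # Unbiased (n - 1) variance over the valid entries, matching torch.std's default.
    var = torch.nansum((x - mean) ** 2) / (n - 1).clamp(min=1)
    return torch.sqrt(var)

vals = torch.tensor([1.0, 2.0, float('nan'), 4.0])
print(torch.nanmean(vals))   # tensor(2.3333)
print(nanstd_sketch(vals))   # ~1.5275, the std over [1, 2, 4]
```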
@@ -1071,7 +1083,8 @@ def _compute_loss(self, model, inputs): |
1071 | 1083 | # apply the completion_mask to exclude loss and metrics for overlong completions |
1072 | 1084 | if self.args.overlong_filter and any(truncated_mask): |
1073 | 1085 | if all(truncated_mask): |
1074 | | - logger.info('All completions are overlong, loss and KL will be zero') |
| 1086 | + logger.info('All completions are overlong and truncated, ' |
| 1087 | + 'resulting in NaN values for some metrics (e.g., KL)')
1075 | 1088 | truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) |
1076 | 1089 | completion_mask = completion_mask * (~truncated_mask) |
1077 | 1090 |
|
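Context for the reworded log message: the overlong filter zeroes the completion mask for truncated generations so they contribute neither loss nor KL. When every completion in the batch is truncated, the mask sums to zero and the per-token averages divide by zero, which surfaces as NaN metrics rather than a zero loss. A toy sketch of the same masking arithmetic:

```python
import torch

# Toy shapes: 2 completions, 4 completion tokens each.
completion_mask = torch.ones(2, 4, dtype=torch.long)
truncated_mask = torch.tensor([True, False])  # completion 0 hit the length limit

# Same broadcasting as the hunk above: expand per-sample flags to per-token
# shape, then zero out every token of a truncated completion.
expanded = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
completion_mask = completion_mask * (~expanded)
print(completion_mask)
# tensor([[0, 0, 0, 0],
#         [1, 1, 1, 1]])

# If *all* completions were truncated, completion_mask.sum() == 0 and any
# masked mean (loss, KL) divides by zero, showing up as NaN in the logs.
```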
@@ -1341,11 +1354,12 @@ def _engine_infer( |
1341 | 1354 | *, |
1342 | 1355 | use_tqdm: Optional[bool] = False, |
1343 | 1356 | ): |
1344 | | - if self.vllm_mode == 'server': |
1345 | | - self._process_infer_requests_images(infer_requests) |
1346 | | - return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) |
1347 | | - else: |
1348 | | - return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) |
| 1357 | + with profiling_context(self, 'generate'): |
| 1358 | + if self.vllm_mode == 'server': |
| 1359 | + self._process_infer_requests_images(infer_requests) |
| 1360 | + return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) |
| 1361 | + else: |
| 1362 | + return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) |
1349 | 1363 |
|
1350 | 1364 | def _process_infer_requests_images(self, infer_requests: List[InferRequest]): |
1351 | 1365 | # Process image format into a format that session.post can accept |
|
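The `profiling_context(self, 'generate')` wrapper (and the per-reward-function one earlier) comes from `trl.extras.profiling` and is used here as a named context manager around the timed block. As a rough, assumption-laden sketch of what such a helper typically does (the real TRL implementation may store or report timings differently):

```python
import time
from contextlib import contextmanager

@contextmanager
def profiling_context_sketch(trainer, name: str):
    """Illustrative stand-in for trl.extras.profiling.profiling_context.

    Assumption: the real helper also times the wrapped block and reports the
    duration under a per-name key; storage and reporting details may differ.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start
        trainer._metrics.setdefault(f'profiling/{name}', []).append(duration)

class DummyTrainer:
    def __init__(self):
        self._metrics = {}

trainer = DummyTrainer()
with profiling_context_sketch(trainer, 'generate'):
    time.sleep(0.01)  # stands in for self.engine.infer(...) / vllm_client.infer(...)
print(trainer._metrics)
```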