27 | 27 | from transformers import PreTrainedModel, TrainerCallback
28 | 28 | from transformers.trainer import Trainer
29 | 29 | from trl import GRPOTrainer as HFGRPOTrainer
30 | | -from trl.extras.profiling import profiling_context, profiling_decorator
31 | 30 | from trl.models import prepare_deepspeed
32 | 31 | from trl.trainer.callbacks import SyncRefModelCallback
33 | 32 | from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd
39 | 38 | from swift.llm.template.template_inputs import StdTemplateInputs
40 | 39 | from swift.plugin import loss_scale_map, multi_turns, orms, rm_plugins
41 | 40 | from swift.plugin.multi_turn import MultiTurnScheduler
42 | | -from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_device, get_logger, is_vllm_available,
43 | | -                         is_wandb_available, seed_worker, unwrap_model_for_generation)
| 41 | +from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_device, get_logger, is_swanlab_available,
| 42 | +                         is_vllm_available, is_wandb_available, seed_worker, unwrap_model_for_generation)
44 | 43 | from ..mixin import SwiftMixin
45 | 44 | from .rlhf_mixin import RLHFTrainerMixin
46 | | -from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge
| 45 | +from .utils import (_ForwardRedirection, patch_lora_merge, patch_lora_unmerge, patch_profiling_context,
| 46 | +                    patch_profiling_decorator)
47 | 47 | from .vllm_client import VLLMClient
48 | 48 |
49 | 49 | del HFGRPOTrainer.__init__
52 | 52 | logger = get_logger()
53 | 53 | if is_wandb_available():
54 | 54 |     import wandb
| 55 | +if is_swanlab_available():
| 56 | +    import swanlab
55 | 57 |
56 | 58 | InputsType = List[Dict[str, Union[torch.Tensor, Any]]]
57 | 59 | # tuple: (messages, finish_reason)
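Note: the bodies of `patch_profiling_context` and `patch_profiling_decorator` live in `.utils` and are not part of this diff. As a rough sketch only — assuming they mirror trl's `profiling_context` behaviour (time a named block and log a `profiling/Time taken: ...` metric to whichever trackers are active), which this diff does not confirm — such a shim could look like this:

import time
from contextlib import contextmanager
from functools import wraps


@contextmanager
def patch_profiling_context(trainer, name: str):
    # Time a named block and report the duration to the active trackers.
    start = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start
        metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}
        report_to = getattr(trainer.args, 'report_to', None) or []
        if 'wandb' in report_to:
            import wandb
            if wandb.run is not None:
                wandb.log(metrics)
        if 'swanlab' in report_to:
            import swanlab
            if swanlab.get_run() is not None:
                swanlab.log(metrics)


def patch_profiling_decorator(func):
    # Decorator form: run the wrapped trainer method inside the same profiling context,
    # keyed by the method's name.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with patch_profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)
    return wrapper

The decorator/context swap seen in the hunks below is mechanical: every former `profiling_context`/`profiling_decorator` usage now goes through this local shim instead of trl's.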
@@ -325,7 +327,7 @@ def cyclic_iter(iterable):
325 | 327 |         # flag indicating whether the evaluation has started
326 | 328 |         self.eval_flag = False
327 | 329 |
328 | | -    @profiling_decorator
| 330 | +    @patch_profiling_decorator
329 | 331 |     def _prepare_inputs(self, generation_batch: dict[str, Union[torch.Tensor,
330 | 332 |                                                                  Any]]) -> dict[str, Union[torch.Tensor, Any]]:
331 | 333 |         # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
@@ -479,7 +481,7 @@ def _template_context(self, template: Template):
479 | 481 |         template.set_mode(mode)
480 | 482 |         template.max_length = max_length
481 | 483 |
482 | | -    @profiling_decorator
| 484 | +    @patch_profiling_decorator
483 | 485 |     def _move_model_to_vllm(self, skip_async_check=False):
484 | 486 |         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
485 | 487 |         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
@@ -906,7 +908,7 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
906 | 908 |
907 | 909 |         for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
908 | 910 |                 zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):
909 | | -            with profiling_context(self, reward_func_name):
| 911 | +            with patch_profiling_context(self, reward_func_name):
910 | 912 |                 # reward model
911 | 913 |                 if isinstance(reward_func, nn.Module):
912 | 914 |                     output_reward_func = reward_model_plugin(inputs=inputs)
|
@@ -1110,7 +1112,7 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
1110 | 1112 |             prompts_text.append(''.join(processed_context))
1111 | 1113 |         return prompts_text
1112 | 1114 |
1113 | | -    @profiling_decorator
| 1115 | +    @patch_profiling_decorator
1114 | 1116 |     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
1115 | 1117 |         # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training
1116 | 1118 |         if isinstance(inputs, list):
|
@@ -1275,7 +1277,7 @@ def _padding_free_output_hook(module, args, kwargs, result):
1275 | 1277 |         remove_handle2.remove()
1276 | 1278 |
1277 | 1279 |     # Get the per-token log probabilities for the completions for the model and the reference model
1278 | | -    @profiling_decorator
| 1280 | +    @patch_profiling_decorator
1279 | 1281 |     def _get_per_token_logps(self, model, inputs):
1280 | 1282 |         from trl.trainer.utils import selective_log_softmax
1281 | 1283 |         logits_to_keep = inputs['logits_to_keep']
|
@@ -1305,7 +1307,7 @@ def _get_per_token_logps(self, model, inputs):
1305 | 1307 |             input_ids = input_ids[:, -logits_to_keep:]
1306 | 1308 |         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
1307 | 1309 |
1308 | | -    @profiling_decorator
| 1310 | +    @patch_profiling_decorator
1309 | 1311 |     def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
1310 | 1312 |         # unwrap the model to access the model.model
1311 | 1313 |         if is_peft_model(unwrapped_model):
|
@@ -1399,7 +1401,7 @@ def _engine_infer(
1399 | 1401 |         *,
1400 | 1402 |         use_tqdm: Optional[bool] = False,
1401 | 1403 |     ) -> List[ChatCompletionResponse]:
1402 | | -        with profiling_context(self, 'generate'):
| 1404 | +        with patch_profiling_context(self, 'generate'):
1403 | 1405 |             if self.vllm_mode == 'server':
1404 | 1406 |                 request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
1405 | 1407 |
|
@@ -1586,6 +1588,16 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
1586 | 1588 |             df = df.drop_duplicates(subset=['prompt'])
1587 | 1589 |             wandb.log({'completions': wandb.Table(dataframe=df)})
1588 | 1590 |
| 1591 | +        if self.args.report_to and 'swanlab' in self.args.report_to and swanlab.get_run() is not None:
| 1592 | +            headers = list(table.keys())
| 1593 | +            rows = []
| 1594 | +            for i in range(len(table['step'])):
| 1595 | +                row = []
| 1596 | +                for header in headers:
| 1597 | +                    row.append(table[header][i])
| 1598 | +                rows.append(row)
| 1599 | +            swanlab.log({'completions': swanlab.echarts.Table().add(headers, rows)})
| 1600 | +
1589 | 1601 |     def is_async_generate_eval_rollout_done(self):
1590 | 1602 |         return not self.eval_flag or not self.eval_queue.empty()
1591 | 1603 |
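For reference, the added SwanLab branch transposes the column-oriented `table` dict (the same object the wandb branch appears to feed into a DataFrame a few lines above) into a header list plus row-major data, since `swanlab.echarts.Table().add(headers, rows)` expects rows rather than columns. A small illustration with made-up values:

# Hypothetical column-oriented completions table; values are invented for illustration only.
table = {
    'step': ['10', '10'],
    'prompt': ['What is 2 + 2?', 'Name a prime number.'],
    'completion': ['4', '7'],
    'reward': [1.0, 0.5],
}

headers = list(table.keys())                # ['step', 'prompt', 'completion', 'reward']
rows = [[table[h][i] for h in headers]      # transpose: one row per logged completion
        for i in range(len(table['step']))]
# rows == [['10', 'What is 2 + 2?', '4', 1.0],
#          ['10', 'Name a prime number.', '7', 0.5]]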