Skip to content

Commit 1e8727f

Browse files
authored
feat: rlhf generation samples log to swanlab (#4907)
* visualize samples add swanlab * fix * fix lint * patch_profiling_context * patch_profiling_decorator
1 parent 67458df commit 1e8727f

File tree

3 files changed

+65
-11
lines changed

3 files changed

+65
-11
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from transformers import PreTrainedModel, TrainerCallback
2828
from transformers.trainer import Trainer
2929
from trl import GRPOTrainer as HFGRPOTrainer
30-
from trl.extras.profiling import profiling_context, profiling_decorator
3130
from trl.models import prepare_deepspeed
3231
from trl.trainer.callbacks import SyncRefModelCallback
3332
from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd
@@ -39,11 +38,12 @@
3938
from swift.llm.template.template_inputs import StdTemplateInputs
4039
from swift.plugin import loss_scale_map, multi_turns, orms, rm_plugins
4140
from swift.plugin.multi_turn import MultiTurnScheduler
42-
from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_device, get_logger, is_vllm_available,
43-
is_wandb_available, seed_worker, unwrap_model_for_generation)
41+
from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_device, get_logger, is_swanlab_available,
42+
is_vllm_available, is_wandb_available, seed_worker, unwrap_model_for_generation)
4443
from ..mixin import SwiftMixin
4544
from .rlhf_mixin import RLHFTrainerMixin
46-
from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge
45+
from .utils import (_ForwardRedirection, patch_lora_merge, patch_lora_unmerge, patch_profiling_context,
46+
patch_profiling_decorator)
4747
from .vllm_client import VLLMClient
4848

4949
del HFGRPOTrainer.__init__
@@ -52,6 +52,8 @@
5252
logger = get_logger()
5353
if is_wandb_available():
5454
import wandb
55+
if is_swanlab_available():
56+
import swanlab
5557

5658
InputsType = List[Dict[str, Union[torch.Tensor, Any]]]
5759
# tuple: (messages, finish_reason)
@@ -325,7 +327,7 @@ def cyclic_iter(iterable):
325327
# flag indicating whether the evaluation has started
326328
self.eval_flag = False
327329

328-
@profiling_decorator
330+
@patch_profiling_decorator
329331
def _prepare_inputs(self, generation_batch: dict[str, Union[torch.Tensor,
330332
Any]]) -> dict[str, Union[torch.Tensor, Any]]:
331333
# Prepares inputs for model training/evaluation by managing completion generation and batch handling.
@@ -479,7 +481,7 @@ def _template_context(self, template: Template):
479481
template.set_mode(mode)
480482
template.max_length = max_length
481483

482-
@profiling_decorator
484+
@patch_profiling_decorator
483485
def _move_model_to_vllm(self, skip_async_check=False):
484486
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
485487
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
@@ -906,7 +908,7 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
906908

907909
for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
908910
zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):
909-
with profiling_context(self, reward_func_name):
911+
with patch_profiling_context(self, reward_func_name):
910912
# reward model
911913
if isinstance(reward_func, nn.Module):
912914
output_reward_func = reward_model_plugin(inputs=inputs)
@@ -1110,7 +1112,7 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
11101112
prompts_text.append(''.join(processed_context))
11111113
return prompts_text
11121114

1113-
@profiling_decorator
1115+
@patch_profiling_decorator
11141116
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
11151117
# Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training
11161118
if isinstance(inputs, list):
@@ -1275,7 +1277,7 @@ def _padding_free_output_hook(module, args, kwargs, result):
12751277
remove_handle2.remove()
12761278

12771279
# Get the per-token log probabilities for the completions for the model and the reference model
1278-
@profiling_decorator
1280+
@patch_profiling_decorator
12791281
def _get_per_token_logps(self, model, inputs):
12801282
from trl.trainer.utils import selective_log_softmax
12811283
logits_to_keep = inputs['logits_to_keep']
@@ -1305,7 +1307,7 @@ def _get_per_token_logps(self, model, inputs):
13051307
input_ids = input_ids[:, -logits_to_keep:]
13061308
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
13071309

1308-
@profiling_decorator
1310+
@patch_profiling_decorator
13091311
def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
13101312
# unwrap the model to access the model.model
13111313
if is_peft_model(unwrapped_model):
@@ -1399,7 +1401,7 @@ def _engine_infer(
13991401
*,
14001402
use_tqdm: Optional[bool] = False,
14011403
) -> List[ChatCompletionResponse]:
1402-
with profiling_context(self, 'generate'):
1404+
with patch_profiling_context(self, 'generate'):
14031405
if self.vllm_mode == 'server':
14041406
request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
14051407

@@ -1586,6 +1588,16 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
15861588
df = df.drop_duplicates(subset=['prompt'])
15871589
wandb.log({'completions': wandb.Table(dataframe=df)})
15881590

1591+
if self.args.report_to and 'swanlab' in self.args.report_to and swanlab.get_run() is not None:
1592+
headers = list(table.keys())
1593+
rows = []
1594+
for i in range(len(table['step'])):
1595+
row = []
1596+
for header in headers:
1597+
row.append(table[header][i])
1598+
rows.append(row)
1599+
swanlab.log({'completions': swanlab.echarts.Table().add(headers, rows)})
1600+
15891601
def is_async_generate_eval_rollout_done(self):
15901602
return not self.eval_flag or not self.eval_queue.empty()
15911603

swift/trainers/rlhf_trainer/reward_trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,10 @@ def visualize_samples(self, num_print_samples: int):
7777

7878
if wandb.run is not None:
7979
wandb.log({'completions': wandb.Table(dataframe=df)})
80+
81+
if 'swanlab' in self.args.report_to:
82+
import swanlab
83+
if swanlab.get_run() is not None:
84+
swanlab_table = swanlab.echarts.Table()
85+
swanlab_table.add(headers=df.columns.tolist(), rows=df.values.tolist())
86+
swanlab.log({'completions': swanlab_table})

swift/trainers/rlhf_trainer/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import functools
3+
import time
24
from contextlib import contextmanager
35
from types import MethodType
46
from typing import Any, Optional
@@ -8,6 +10,13 @@
810
from peft.tuners.lora import LoraLayer
911
from torch import nn
1012

13+
from swift.utils import is_swanlab_available, is_wandb_available
14+
15+
if is_wandb_available():
16+
import wandb
17+
if is_swanlab_available():
18+
import swanlab
19+
1120

1221
def round_robin(num_reqs, num_workers):
1322
"""Distribute requests evenly across workers using round-robin algorithm.
@@ -125,6 +134,32 @@ def unmerge_patched(self):
125134
del module.unmerge_origin
126135

127136

137+
@contextmanager
def patch_profiling_context(trainer, name: str):
    """Time the wrapped block and report the duration to wandb and/or swanlab.

    The elapsed wall-clock time (``time.perf_counter``) is logged under the key
    ``profiling/Time taken: <TrainerClassName>.<name>``.

    Args:
        trainer: Object exposing ``args.report_to`` and, when a backend is
            selected, ``accelerator.is_main_process``.
        name: Label of the profiled section, appended to the trainer class name.

    Yields:
        None. Control returns to the ``with`` body while the timer is running.
    """
    start_time = time.perf_counter()
    # If the ``with`` body raises, the exception is re-raised at this ``yield``
    # and nothing below runs, so failed sections are never logged.
    yield
    duration = time.perf_counter() - start_time

    profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}

    # ``report_to`` may be unset; treat that as "no reporting backend" instead of
    # failing on the membership test below.
    report_to = trainer.args.report_to or ()

    # The membership test stays first so the optional ``wandb`` / ``swanlab``
    # module names are only touched when that backend was actually requested.
    if 'wandb' in report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    if 'swanlab' in report_to and swanlab.get_run() is not None and trainer.accelerator.is_main_process:
        swanlab.log(profiling_metrics)
151+
152+
153+
def patch_profiling_decorator(func):
    """Decorate a trainer method so every call is timed by ``patch_profiling_context``.

    Args:
        func: The method to profile. Its first positional argument is the
            trainer instance, which is also passed to ``patch_profiling_context``.

    Returns:
        A wrapper with the same metadata as ``func`` (via ``functools.wraps``)
        that returns whatever ``func`` returns.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with patch_profiling_context(self, func.__name__):
            # ``func`` is the plain function captured at class-definition time,
            # not a bound method, so the instance must be forwarded explicitly.
            return func(self, *args, **kwargs)

    return wrapper
161+
162+
128163
class _ForwardRedirection:
129164
"""Implements the `forward-redirection`.
130165
Taken from Pytorch-lightning:

0 commit comments

Comments
 (0)