diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainer/reward_trainer.py index 0355343909..069d0b0cc5 100644 --- a/swift/trainers/rlhf_trainer/reward_trainer.py +++ b/swift/trainers/rlhf_trainer/reward_trainer.py @@ -76,3 +76,9 @@ def visualize_samples(self, num_print_samples: int): if wandb.run is not None: wandb.log({'completions': wandb.Table(dataframe=df)}) + elif 'neptune' in self.args.report_to: + import neptune + from neptune.types import File + + if neptune.run is not None: + neptune.run['completions'].upload(File.as_html(df))