
Commit fe1f429

Merge branch 'grpo-latest-rebase-main' of https://github.com/hpcaitech/ColossalAI into grpo-latest-rebase-main

2 parents 4152c0b + 73bdfd8 · commit fe1f429

5 files changed: +10 -10 lines changed

applications/ColossalChat/coati/experience_maker/naive.py

Lines changed: 3 additions & 1 deletion
@@ -119,7 +119,9 @@ def make_experience(
     generate_kwargs["stop_token_ids"] = stop_token_ids
     # Hack: manually initialize cache_position to address transformer version conflict
     if generate_kwargs.get("cache_position", None) is None and generate_kwargs.get("use_cache", False) is True:
-        generate_kwargs["cache_position"] = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        generate_kwargs["cache_position"] = torch.arange(
+            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
+        )
     torch.manual_seed(41)  # for tp, gurantee the same input for reward model

     if self.use_grpo and self.num_generation > 1:
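For context, the reformatted lines implement a small workaround: when generation relies on the KV cache and the caller has not supplied cache_position, it is pre-filled with positions 0..seq_len-1 of the prompt tokens. A minimal standalone sketch of the same idea follows; init_cache_position is an illustrative helper name, not a function in the repository:

import torch

def init_cache_position(input_ids: torch.Tensor, generate_kwargs: dict) -> dict:
    # Pre-fill cache_position only when the KV cache is in use and the caller
    # did not set it explicitly, mirroring the hack in make_experience above.
    if generate_kwargs.get("cache_position") is None and generate_kwargs.get("use_cache", False):
        generate_kwargs["cache_position"] = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    return generate_kwargs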

applications/ColossalChat/coati/trainer/kto.py

Lines changed: 4 additions & 2 deletions
@@ -193,12 +193,14 @@ def _train(self, epoch: int):
     loss_mean = all_reduce_mean(tensor=loss)
     chosen_reward_mean = chosen_rewards.mean()
     chosen_rewards_list = [
-        torch.tensor(0, dtype=chosen_reward_mean.dtype, device=loss.device) for _ in range(dist.get_world_size())
+        torch.tensor(0, dtype=chosen_reward_mean.dtype, device=loss.device)
+        for _ in range(dist.get_world_size())
     ]
     dist.all_gather(chosen_rewards_list, chosen_reward_mean)
     rejected_reward_mean = rejected_rewards.mean()
     rejected_rewards_list = [
-        torch.tensor(0, dtype=rejected_reward_mean.dtype, device=loss.device) for _ in range(dist.get_world_size())
+        torch.tensor(0, dtype=rejected_reward_mean.dtype, device=loss.device)
+        for _ in range(dist.get_world_size())
     ]
     dist.all_gather(rejected_rewards_list, rejected_reward_mean)
     chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
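The reformatted list comprehensions feed a standard all_gather pattern: each rank pre-allocates one placeholder tensor per process, gathers the per-rank reward means, and the surrounding code then filters out NaN entries before averaging. A minimal sketch of that pattern, assuming torch.distributed is already initialized (gather_scalar_means is an illustrative name, not part of kto.py):

import torch
import torch.distributed as dist

def gather_scalar_means(local_mean: torch.Tensor) -> list:
    # One placeholder tensor per rank, then gather every rank's scalar mean.
    buffer = [torch.zeros_like(local_mean) for _ in range(dist.get_world_size())]
    dist.all_gather(buffer, local_mean)
    # Drop NaN entries, as the trainer does with i.isnan() above.
    return [v for v in buffer if not v.isnan()]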

applications/ColossalChat/examples/training_scripts/train_grpo.py

Lines changed: 1 addition & 3 deletions
@@ -89,9 +89,7 @@ def train(args):
     actor = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)
     if args.rm_pretrain:
         reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
-    ref_model = AutoModelForCausalLM.from_pretrained(
-        args.pretrain, trust_remote_code=True
-    )
+    ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)

     if args.lora_config is not None:
         actor = convert_to_lora_module(actor, lora_config=lora_config)
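The change here is purely cosmetic: the ref_model call is collapsed onto one line. The underlying pattern, shared by the GRPO and PPO scripts, is to load the reference policy as a second copy of the pretrained weights that is kept fixed while the actor is trained. A hedged sketch of that pattern (the model name and the freezing code are illustrative, not taken from the script):

from transformers import AutoModelForCausalLM

pretrain = "your-pretrained-model"  # placeholder for args.pretrain
actor = AutoModelForCausalLM.from_pretrained(pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(pretrain, trust_remote_code=True)

# Keep the reference policy fixed; only the actor receives gradient updates.
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)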

applications/ColossalChat/examples/training_scripts/train_ppo.py

Lines changed: 1 addition & 3 deletions
@@ -102,9 +102,7 @@ def train(args):
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
         actor = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)
-        ref_model = AutoModelForCausalLM.from_pretrained(
-            args.pretrain, trust_remote_code=True
-        )
+        ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)
     if not args.no_neural_reward_model:
         reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
     critic = Critic(args.rm_pretrain)

applications/ColossalChat/tests/test_train.sh

Lines changed: 1 addition & 1 deletion
@@ -631,7 +631,7 @@ for lora_rank in ${LORA_RANK[@]}; do
        done
    done
done
-
+

echo "[Test]: testing ORPO ..."
