Commit 65f5289

support resume training, still buggy

1 parent 5c5cb18 commit 65f5289
16 files changed: +368 -124 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 23 additions & 2 deletions

```diff
@@ -55,6 +55,7 @@ def __init__(
         self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
 
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -143,6 +144,26 @@ def calculate_effective_group_to_raw_group_mapping(self, step):
         return effective_group_to_raw_group_mapping
 
     def loop(self) -> None:
+        self.profiler.enter("sync_model")
+        torch.cuda.empty_cache()
+        state_dict = self.state_dict()
+        if self.pp_size > 1:
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict,
+                    src=self.num_producers,
+                    device=self.device,
+                    group_name=f"sync_model_{self.pp_rank}",
+                )
+        else:
+            if self.rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                )
+        del state_dict
+        torch.cuda.empty_cache()
+        self.profiler.exit("sync_model")
+
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
@@ -286,7 +307,7 @@ def loop(self) -> None:
                 if self.rank == 0:
                     print(f"Start saving policy model at step {step + 1}.")
                 save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
-                self.booster.save_model(self.policy_model, save_path, shard=True)
+                self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True)
                 if self.rank == 0:
                     print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
@@ -365,7 +386,7 @@ def __init__(
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.model.train()
         self.model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
+        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
         self.accum_loss = torch.zeros(1, device=self.device)
 
     def setup(self):
```
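The new `checkpoint_path` key is popped out of `model_config` before that dict reaches `AutoModelForCausalLM.from_pretrained`, so resuming is configured entirely through the model config. A hedged sketch of the caller side (the other keys here are illustrative assumptions, not part of this commit):

```python
# Hypothetical launcher-side config for resuming training. "checkpoint_path" is
# consumed by the consumer's __init__ via model_config.pop(...), so it never
# reaches AutoModelForCausalLM.from_pretrained; the remaining keys do.
model_config = {
    "trust_remote_code": True,  # example kwarg that does reach from_pretrained
    "checkpoint_path": "./checkpoints/modeling-episode-0-step-100",  # example path in the save_model format above
}
```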

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -72,7 +72,11 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
+        self.optimizer = HybridAdam(
+            self.policy_model.parameters(),
+            lr=grpo_config.get("lr", 1e-6),
+            weight_decay=grpo_config.get("weight_decay", 0.01),
+        )
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
         self.accum_entropy = torch.zeros(1, device=self.device)
@@ -153,6 +157,8 @@ def setup(self):
         )
         if self.policy_loss_fn.beta > 0:
             self.reference_model, *_ = self.booster.boost(self.reference_model)
+        if self.checkpoint_path is not None:
+            self.booster.load_model(self.policy_model, self.checkpoint_path)
         self.plugin.logger.set_level("ERROR")
 
     def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
```
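Combined with the `checkpoint_path` pop in `consumer.py`, the resume path is: read the path out of `model_config`, boost the model, then load the checkpoint into the boosted model. A minimal sketch of that ordering, assuming a checkpoint written earlier by `booster.save_model(..., shard=True, use_safetensors=True)`; the surrounding `Booster`/plugin setup is elided:

```python
# Sketch of the resume ordering; simplified from __init__ plus setup() above.
checkpoint_path = model_config.pop("checkpoint_path", None)
policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
optimizer = HybridAdam(policy_model.parameters(), lr=1e-6, weight_decay=0.01)
policy_model, optimizer, *_ = booster.boost(policy_model, optimizer)
if checkpoint_path is not None:
    booster.load_model(policy_model, checkpoint_path)  # handles sharded checkpoint directories
```

Note that only the model weights are restored here; optimizer state and the consumer's step counter are not reloaded.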

applications/ColossalChat/coati/distributed/launch_zero_bubble.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -55,7 +55,6 @@ def launch_distributed(
     eval_dataset_config: Optional[Dict[str, Any]] = None,
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
-    eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
     enable_profiling: bool = False,
@@ -139,7 +138,6 @@ def launch_distributed(
        eval_interval=eval_interval,
        grpo_config=grpo_config,
        eval_save_dir=eval_save_dir,
-       eval_generation_config=eval_generation_config,
        project_name=project_name,
        run_name=run_name,
        wandb_group_name=wandb_group_name,
```

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -203,6 +203,28 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
         raise NotImplementedError
 
     def loop(self) -> None:
+
+        torch.cuda.empty_cache()
+        self.profiler.enter("sync_model")
+        if self.consumer_pp_size > 1:
+            for pp_idx in range(self.consumer_pp_size):
+                state_dict = ray_broadcast_tensor_dict(
+                    None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                )
+                if "consumer_global_step" in state_dict:
+                    self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+                self.load_state_dict(state_dict)
+        else:
+            state_dict = ray_broadcast_tensor_dict(
+                None, self.num_producers, device=self.device, group_name="sync_model"
+            )
+            if "consumer_global_step" in state_dict:
+                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+            self.load_state_dict(state_dict)
+        self.profiler.exit("sync_model")
+        del state_dict
+        torch.cuda.empty_cache()
+
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
 
```
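This block is the producer-side half of the consumer's new startup broadcast: before the first rollout, every producer receives the consumer's (possibly just-resumed) weights, so generation never starts from stale parameters. Schematically, with the same `ray_broadcast_tensor_dict` helper used in both diffs (a sketch of the pairing, not runnable on its own):

```python
# Consumer (global rank num_producers) publishes once before its train loop:
ray_broadcast_tensor_dict(state_dict, src=num_producers, device=device, group_name="sync_model")

# Each producer receives the same dict, pops the bookkeeping scalar, loads the rest:
state_dict = ray_broadcast_tensor_dict(None, num_producers, device=device, group_name="sync_model")
if "consumer_global_step" in state_dict:
    consumer_global_step = state_dict.pop("consumer_global_step").item()
load_state_dict(state_dict)
```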
applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 18 additions & 1 deletion

```diff
@@ -25,7 +25,12 @@
 from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .code_reward.utils import check_correctness_code_api as check_correctness_code
-from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
+from .reward_utils import (
+    extract_boxed_solution,
+    extract_solution,
+    find_infinite_loop_start,
+    validate_response_structure,
+)
 
 CANNOT_PARSE_GT_ANSWER = -1
 CANNOT_PARSE_PREDICTION = -2
@@ -122,6 +127,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
 
+    repetition_reward = 1.0 if detect_repetition(decoded_final_answer) == [] else 0.0
+
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
@@ -137,6 +144,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     if format_valid:
         format_acc += 1
 
+    # Add repetition reward
+    if not eval_mode:
+        reward += repetition_reward
+
     # Check if the sequence is over length
     if not eval_mode and res_length >= max_new_tokens:
         reward *= 0.0
@@ -182,6 +193,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    print(f"Decoded final answer: {decoded_final_answer[-500:]}")
+    repetition_score = find_infinite_loop_start(input_ids[s : e + 1], min_repeats=2, distance=False)
 
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
@@ -202,6 +215,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     if format_valid:
         format_acc += 1
 
+    if not repetition_score > 0 and not eval_mode:
+        # award for non-repetition
+        reward += 2
+
     # Check if the sequence is over length
     if not eval_mode and res_length >= max_new_tokens:
         reward *= 0.0
```
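`find_infinite_loop_start` returns 0.0 when no trailing loop is detected, so `not repetition_score > 0` is just `repetition_score == 0`: loop-free responses earn a flat +2 during training, never during eval. A toy illustration of the branch (reward values invented):

```python
# Toy illustration of the non-repetition bonus; numbers are made up.
reward, eval_mode = 8.0, False
repetition_score = 0.0  # find_infinite_loop_start found no trailing loop
if not repetition_score > 0 and not eval_mode:
    reward += 2  # flat bonus for loop-free responses, training only
assert reward == 10.0
```

One caveat: `math_reward_fn` above calls `detect_repetition`, which is not among the names imported in this commit; if it is not defined elsewhere in the module, that path raises a `NameError`, plausibly part of the "still buggy" caveat in the commit message.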

applications/ColossalChat/coati/distributed/reward/reward_utils.py

Lines changed: 51 additions & 1 deletion

```diff
@@ -14,7 +14,9 @@
 # limitations under the License.
 
 import re
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
+
+import torch
 
 
 def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
@@ -122,3 +124,51 @@ def extract_boxed_solution(text: str) -> Optional[str]:
     except Exception:
         # Any other unexpected error
         return None
+
+
+import Levenshtein
+
+
+def is_similar(seq1: List[int], seq2: List[int], threshold: float = 0.9) -> bool:
+    ratio = Levenshtein.ratio(seq1, seq2)
+    return ratio >= threshold
+
+
+def find_infinite_loop_start(token_ids: List[int], min_repeats: int = 2, distance: bool = False) -> float:
+    n = len(token_ids)
+
+    # Step 1: Detect the repeating segment at the end using two pointers
+    longest_valid_length = 0
+    start_of_loop = n
+
+    for length in range(1, n // min_repeats + 1):  # Try different phrase lengths
+        count = 1  # Reset repetition counter
+        right = n - length  # Start comparing from the second last occurrence
+
+        while right - length >= 0:
+            # Check if the current phrase matches the previous phrase
+            if distance:
+                if is_similar(token_ids[right - length : right], token_ids[right : right + length]):
+                    count += 1
+                else:
+                    break  # Stop if repetition is broken
+            else:
+                # Use torch.equal() for tensor comparison
+                if torch.equal(token_ids[right - length : right], token_ids[right : right + length]):
+                    count += 1
+                else:
+                    break  # Stop if repetition is broken
+
+            right -= length  # Move left to check further
+
+        if count >= min_repeats:  # Found a valid repeating phrase
+            longest_valid_length = length
+            start_of_loop = right  # This is where the first cycle of the repetition begins
+
+    if longest_valid_length == 0:
+        return 0.0  # No infinite loop found, return repetition ratio as 0
+
+    # Step 2: Compute the repetition ratio
+    repetition_ratio = (n - start_of_loop) / n
+
+    return repetition_ratio
```
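Despite the `List[int]` annotation, the default `distance=False` path compares slices with `torch.equal`, so in practice the function expects a 1-D token tensor. A quick usage sketch with an invented sequence whose tail repeats the 3-token phrase `[7, 8, 9]` three times:

```python
import torch

# 9 of the 12 tokens belong to the repeated cycle [7, 8, 9], so the ratio is 9/12.
token_ids = torch.tensor([1, 2, 3, 7, 8, 9, 7, 8, 9, 7, 8, 9])
print(find_infinite_loop_start(token_ids, min_repeats=2))  # 0.75
```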
