Commit 773e392

fix rollout, action mask, attention mask bugs
1 parent 5d6aecf commit 773e392

File tree

10 files changed: +80 additions, -45 deletions


applications/ColossalChat/coati/distributed/agent/agentic_producer.py

Lines changed: 1 addition & 8 deletions
@@ -6,7 +6,6 @@
 
 import ray
 from coati.distributed.agent.base import BaseAgenticProducer
-from transformers import AutoTokenizer
 
 DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>."""
 
@@ -88,13 +87,6 @@ def __init__(
         self.tool_workers = tool_workers
         self.agentic_config = model_config if not agentic_config else agentic_config
         self.agentic_config.update({"model": model_config["path"]})
-        tokenizer_path = None
-        if tokenizer_config and "path" in tokenizer_config:
-            tokenizer_path = tokenizer_config["path"]
-        elif "path" in model_config:
-            tokenizer_path = model_config["path"]
-        assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config."
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
         self.tools_schema = []
         self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
         self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
@@ -258,6 +250,7 @@ def _run_agentic_pipeline(self, messages):
                 )
             )
             llm_call_count += 1
+            self.consumer_global_step = response.pop("consumer_global_step")
             response_input_ids = response["input_ids"]
             logprobs = response["action_log_probs"]
             response_text = self.tokenizer.decode(
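Two things change in this file: the agentic producer no longer builds its own AutoTokenizer (the tokenizer path now travels through the tokenizer_config added in rl_example.py below, so self.tokenizer is presumably set up by the shared producer base class), and each LLM call now pops the trainer's global step out of the inference response before the tensor fields are used. A minimal, self-contained sketch of that pop pattern; the dictionary values are illustrative, not real data:

# Illustrative response dict, shaped like what the async producer's generate call returns.
response = {
    "input_ids": [[101, 2054, 2003]],
    "action_log_probs": [[-0.11, -0.52, -0.03]],
    "consumer_global_step": 42,  # scalar metadata, not a per-token tensor
}

# pop() removes the scalar before the remaining entries are decoded / stacked,
# so the metadata never needs to be padded or collated with the tensors.
consumer_global_step = response.pop("consumer_global_step")
assert "consumer_global_step" not in response
print(consumer_global_step)  # 42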

applications/ColossalChat/coati/distributed/agent/base.py

Lines changed: 7 additions & 8 deletions
@@ -135,15 +135,13 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
             input_ids = torch.nn.functional.pad(
                 response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
             ) # [1, max_length]
-            attention_mask = torch.nn.functional.pad(
-                torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
-            ) # [1, max_length]
-            action_mask = torch.nn.functional.pad(
-                torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
-            ) # [1, max_length-prompt_length]
+            attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int() # [1, max_length]
+            action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int()
             rollouts["attention_mask"].append(attention_mask)
             rollouts["action_mask"].append(action_mask)
-            truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]]
+            truncated_logprobs = logprobs[
+                :, :, prompt_length : prompt_length + self.generate_config["max_tokens"]
+            ] # truncate to max_new_tokens
             logprobs_padded = torch.nn.functional.pad(
                 truncated_logprobs,
                 (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
@@ -177,7 +175,8 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
                         "rollout": self.tokenizer.batch_decode(
                             rollouts["input_ids"][:, 0], skip_special_tokens=True
                         ),
-                    }
+                    },
+                    ensure_ascii=False,
                 )
                 + "\n"
             )
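The mask fix here is the heart of the commit: the old code padded all-ones tensors, which drifts out of sync with the sequence once the concatenated multi-turn rollout itself contains padding; the new code derives both masks directly from the padded input_ids. A standalone sketch of the idea, assuming a pad id of 0 and a prompt region of 4 tokens (both values are illustrative):

import torch

pad_token_id = 0       # illustrative pad id
max_prompt_length = 4  # illustrative prompt budget

# [1, max_length]: 2 left pads + real tokens + 3 right pads
input_ids = torch.tensor([[0, 0, 5, 6, 7, 8, 9, 0, 0, 0]])

# attention mask marks every non-pad token, prompt and response alike
attention_mask = input_ids.ne(pad_token_id).int()
# -> tensor([[0, 0, 1, 1, 1, 1, 1, 0, 0, 0]])

# action mask marks only the non-pad tokens after the (padded) prompt region
action_mask = input_ids[:, max_prompt_length:].ne(pad_token_id).int()
# -> tensor([[1, 1, 1, 0, 0, 0]])

This relies on the same assumption the diff makes: the pad id does not occur as a real token inside the sequence.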

applications/ColossalChat/coati/distributed/agent/math_tools.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def run_python_code(code: str) -> str:
         code = code.replace("```python", "```", 1).strip()
     if code.startswith("```py"): # qwen3 uses ```py
         code = code.replace("```py", "```", 1).strip()
-    return python_repl.run(code, timeout=20)
+    return python_repl.run(code, timeout=30)
 
 
 repl_tool = Tool(

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 0 additions & 1 deletion
@@ -325,7 +325,6 @@ def loop(self) -> None:
                    ) # for setting start index when resuming training
                    if self.rank == 0:
                        print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
-                    # breakpoint()
                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                        episode != 0 or step >= self.n_behind
                    ):

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ async def generate(
                     log_probs[generation_id].extend(p)
         self.profiler.exit(f"vllm generate {request_id}")
         # pad them
-        max_len = self.sample_params.max_tokens
+        max_len = sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
         for i, new_token_ids in enumerate(out_tokens):
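This one-character fix matters whenever the caller supplies request-specific sampling parameters: the padding width must come from this request's sample_params.max_tokens, not from the backend-wide default, or the padded outputs and the action mask end up with mismatched widths. A small, self-contained sketch of the padding step (token values and the pad id are illustrative):

import torch

max_len = 6       # sample_params.max_tokens for this particular request
pad_token_id = 0  # illustrative pad id
out_tokens = [[11, 12, 13], [21, 22, 23, 24, 25]]  # two sampled generations

action_mask = torch.ones(len(out_tokens), max_len, dtype=torch.int64)
padded = torch.full((len(out_tokens), max_len), pad_token_id, dtype=torch.int64)
for i, new_token_ids in enumerate(out_tokens):
    padded[i, : len(new_token_ids)] = torch.tensor(new_token_ids)
    action_mask[i, len(new_token_ids) :] = 0  # mask out the padded tail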

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
-    log_rollout_interval: int = 20,
+    log_rollout_interval: int = 1,
     rollout_save_dir: str = "./rollout",
     enable_profiling: bool = False,
     n_behind: int = 0,

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 36 additions & 21 deletions
@@ -93,7 +93,14 @@ def __init__(
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
+            if k
+            in [
+                "soft_over_length_punishment",
+                "max_new_tokens",
+                "cache_length",
+                "code_verifier_api_url",
+                "forced_patterns",
+            ]
         }
         self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0 and rollout_log_file is not None:
@@ -103,7 +110,7 @@ def __init__(
                 )
             else:
                 os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
-                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
+                self.rollout_log_file = open(rollout_log_file, "a", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -257,6 +264,9 @@ def sync_model(self, episode, step) -> None:
            state_dict = ray_broadcast_tensor_dict(
                None, self.num_producers, device=self.device, group_name="sync_model"
            )
+            print(
+                f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1} done"
+            )
            if "consumer_global_step" in state_dict:
                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
            self.load_state_dict(state_dict)
@@ -495,7 +505,8 @@ def rollout(self, input_ids, attention_mask, **kwargs):
                        "rollout": self.tokenizer.batch_decode(
                            rollouts["input_ids"][:, 0], skip_special_tokens=True
                        ),
-                    }
+                    },
+                    ensure_ascii=False,
                )
                + "\n"
            )
@@ -580,8 +591,10 @@ def __init__(
         self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
         self.eval_generation_config.update(eval_generation_config)
         self.eval_sample_params = SamplingParams(**self.eval_generation_config)
-        self.ready_processes = 0
-        self.condition = asyncio.Condition()
+        self.ready_processes_sync_model = 0
+        self.ready_processes_sync_data = 0
+        self.sync_model_condition = asyncio.Condition()
+        self.sync_data_condition = asyncio.Condition()
         self.data_ready_for_sending = []
 
     @torch.no_grad()
@@ -610,6 +623,7 @@ async def generate(self, input_ids, attention_mask, **kwargs):
            ).cpu() # CUDA tensor is not serializable by ray
            for k in rollouts[0].keys()
        }
+        rollouts["consumer_global_step"] = self.consumer_global_step
        return rollouts
 
    @torch.no_grad()
@@ -631,33 +645,33 @@ async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
        Asyncronous version to sync model from consumer to producer.
        called by another producer, such as agentic producer.
        """
-        async with self.condition:
-            self.ready_processes += 1
+        async with self.sync_model_condition:
+            self.ready_processes_sync_model += 1
            # Wait until all processes are ready
-            if self.ready_processes < num_processes:
-                await self.condition.wait()
+            if self.ready_processes_sync_model < num_processes:
+                await self.sync_model_condition.wait()
 
-            # Only one process should reset `ready_processes` and perform the sync
-            if self.ready_processes == num_processes:
-                self.ready_processes = 0
-                self.condition.notify_all() # Notify all waiting processes
+            # Only one process should reset `ready_processes_sync_model` and perform the sync
+            if self.ready_processes_sync_model == num_processes:
+                self.ready_processes_sync_model = 0
+                self.sync_model_condition.notify_all() # Notify all waiting processes
                self.sync_model(episode, step)
 
    async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None:
        # merge data dict
-        async with self.condition:
-            self.ready_processes += 1
+        async with self.sync_data_condition:
+            self.ready_processes_sync_data += 1
            if data:
                self.data_ready_for_sending.append(data)
 
            # Wait until all processes are ready
-            if self.ready_processes < num_processes:
-                await self.condition.wait()
+            if self.ready_processes_sync_data < num_processes:
+                await self.sync_data_condition.wait()
 
            # Only one process should reset `ready_processes` and perform the sync
-            if self.ready_processes == num_processes: # wait for all producers to join
-                self.ready_processes = 0
-                self.condition.notify_all()
+            if self.ready_processes_sync_data == num_processes: # wait for all producers to join
+                self.ready_processes_sync_data = 0
+                self.sync_data_condition.notify_all()
                # merge data for sending
                if len(self.data_ready_for_sending) >= 1:
                    batch_rollout_data = {}
@@ -853,7 +867,8 @@ async def rollout(self, input_ids, attention_mask, **kwargs):
                        "rollout": self.tokenizer.batch_decode(
                            rollouts["input_ids"][:, 0], skip_special_tokens=True
                        ),
-                    }
+                    },
+                    ensure_ascii=False,
                )
                + "\n"
            )
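The key fix in this file is the barrier split: model sync and data sync previously shared one ready_processes counter and one asyncio.Condition, so waiters at one synchronization point could be counted or woken by the other; each now gets its own pair. A minimal runnable sketch of the barrier pattern (the class and variable names are illustrative, not the actual ColossalChat API):

import asyncio

class ConditionBarrier:
    """One barrier per purpose, each with its own counter and Condition."""

    def __init__(self, num_processes: int):
        self.num_processes = num_processes
        self.ready_processes = 0
        self.condition = asyncio.Condition()

    async def wait(self):
        async with self.condition:
            self.ready_processes += 1
            # Wait until all participants have arrived
            if self.ready_processes < self.num_processes:
                await self.condition.wait()
            # The last arrival resets the counter and wakes everyone else
            if self.ready_processes == self.num_processes:
                self.ready_processes = 0
                self.condition.notify_all()

async def main():
    model_sync = ConditionBarrier(3)  # analogous to sync_model_condition
    data_sync = ConditionBarrier(3)   # analogous to sync_data_condition

    async def worker(i: int):
        await model_sync.wait()  # with a single shared Condition, arrivals at
        await data_sync.wait()   # these two points could get mixed together
        print(f"worker {i} passed both barriers")

    await asyncio.gather(*(worker(i) for i in range(3)))

asyncio.run(main())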

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 13 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 
 import json
+import re
 
 import torch
 from latex2sympy2_extended import NormalizationConfig
@@ -126,6 +127,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
 
+    if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+        forced_patterns = kwargs["forced_patterns"]
+        format_valid = format_valid and all(
+            [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+        )
+
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if final_answer is not None:
         if eval_mode or format_valid:
@@ -161,7 +168,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
     eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
-    acc_score = 10.0
+    acc_score = 1.0
     reward = torch.tensor(0.0)
     format_acc = torch.tensor(0.0)
     ans_acc = torch.tensor(0.0)
@@ -182,15 +189,18 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    # print(f"decoded_final_answer: {decoded_final_answer[-100:]}", gt_answer)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
     if "tags" in kwargs and kwargs["tags"]:
         tags = kwargs["tags"]
         format_valid = format_valid and all(
             [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
         )
-
+    if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+        forced_patterns = kwargs["forced_patterns"]
+        format_valid = format_valid and all(
+            [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+        )
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if final_answer is not None:
         if eval_mode or format_valid:
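The forced_patterns gate keeps format_valid true only if every configured regex matches the decoded completion; with the pattern registered in rl_example.py below, a rollout must contain at least one tool-response block to earn the format reward. A small sketch of the check in isolation (the sample answer text is made up):

import re

forced_patterns = [r"<tool_response>\n.+\n</tool_response>"]  # pattern from rl_example.py

decoded_final_answer = (
    "<tool_call>\n{\"name\": \"python_repl\", \"arguments\": {\"code\": \"print(4 * 4)\"}}\n</tool_call>\n"
    "<tool_response>\n16\n</tool_response>\n"
    "The answer is \\boxed{16}."
)

format_valid = all(
    re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns
)
print(format_valid)  # True: the completion contains a well-formed tool response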
Lines changed: 8 additions & 0 deletions (new file)
@@ -0,0 +1,8 @@
{
"chat_template": "{%- if tools %}\\n {{- \'<|im_start|>system\\\\n\' }}\\n {%- if messages[0].role == \'system\' %}\\n {{- messages[0].content + \'\\\\n\\\\n\' }}\\n {%- endif %}\\n {{- \\"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\" }}\\n {%- for tool in tools %}\\n {{- \\"\\\\n\\" }}\\n {{- tool | tojson }}\\n {%- endfor %}\\n {{- \\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\"name\\\\\\": <function-name>, \\\\\\"arguments\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\" }}\\n{%- else %}\\n {%- if messages[0].role == \'system\' %}\\n {{- \'<|im_start|>system\\\\n\' + messages[0].content + \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n{%- endif %}\\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\\n{%- for message in messages[::-1] %}\\n {%- set index = (messages|length - 1) - loop.index0 %}\\n {%- if ns.multi_step_tool and message.role == \\"user\\" and message.content is string and not(message.content.startswith(\'<tool_response>\') and message.content.endswith(\'</tool_response>\')) %}\\n {%- set ns.multi_step_tool = false %}\\n {%- set ns.last_query_index = index %}\\n {%- endif %}\\n{%- endfor %}\\n{%- for message in messages %}\\n {%- if message.content is string %}\\n {%- set content = message.content %}\\n {%- else %}\\n {%- set content = \'\' %}\\n {%- endif %}\\n {%- if (message.role == \\"user\\") or (message.role == \\"system\\" and not loop.first) %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content + \'<|im_end|>\' + \'\\\\n\' }}\\n {%- elif message.role == \\"assistant\\" %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content }}\\n {%- if message.tool_calls %}\\n {%- for tool_call in message.tool_calls %}\\n {%- if (loop.first and content) or (not loop.first) %}\\n {{- \'\\\\n\' }}\\n {%- endif %}\\n {%- if tool_call.function %}\\n {%- set tool_call = tool_call.function %}\\n {%- endif %}\\n {{- \'<tool_call>\\\\n{\\"name\\": \\"\' }}\\n {{- tool_call.name }}\\n {{- \'\\", \\"arguments\\": \' }}\\n {%- if tool_call.arguments is string %}\\n {{- tool_call.arguments }}\\n {%- else %}\\n {{- tool_call.arguments | tojson }}\\n {%- endif %}\\n {{- \'}\\\\n</tool_call>\' }}\\n {%- endfor %}\\n {%- endif %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- elif message.role == \\"tool\\" %}\\n {%- if loop.first or (messages[loop.index0 - 1].role != \\"tool\\") %}\\n {{- \'<|im_start|>user\' }}\\n {%- endif %}\\n {{- \'\\\\n<tool_response>\\\\n\' }}\\n {{- content }}\\n {{- \'\\\\n</tool_response>\' }}\\n {%- if loop.last or (messages[loop.index0 + 1].role != \\"tool\\") %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n {{- \'<|im_start|>assistant\\\\n\' }}\\n {%- if enable_thinking is defined and enable_thinking is false %}\\n {{- \'<think>\\\\n\\\\n</think>\\\\n\\\\n\' }}\\n {%- endif %}\\n{%- endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
    7
],
"end_of_assistant": "<|im_end|>"
}
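This new conversation-template file bundles a Qwen-style chat template with tool-call support, plus the stop id and end-of-assistant marker. A hedged sketch of how such a template can be exercised with transformers; the file name and model id below are placeholders, and the actual loading path inside ColossalChat may differ:

import json
from transformers import AutoTokenizer

# Placeholder path for the template file shown above.
with open("qwen_tool_chat_template.json", "r", encoding="utf8") as f:
    template_cfg = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)  # assumed model
tokenizer.chat_template = template_cfg["chat_template"]  # override the built-in template

messages = [
    {"role": "system", "content": template_cfg["system_message"]},
    {"role": "user", "content": "What is 12 * 7?"},
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # rendered prompt ending with the assistant header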

applications/ColossalChat/rl_example.py

Lines changed: 12 additions & 1 deletion
@@ -131,6 +131,7 @@
         default=1.0,
         help="Top p for sampling. Please check the generation arguments documentation for your backend.",
     )
+    parser.add_argument("-ct", "--chat-template", type=str, default=None, help="Chat template to use for the model.")
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
@@ -381,11 +382,20 @@
                "llm_call_budget": 10,
                "max_tokens": 2048,
            }
+            grpo_config["forced_patterns"] = [
+                r"<tool_response>\n.+\n</tool_response>"
+            ] # force at least one correct tool call
        else:
            raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
    else:
        agentic_config = None
 
+    tokenizer_config = {
+        "path": args.model,
+        "trust_remote_code": True,
+        "chat_template": args.chat_template,
+    }
+
    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size)
@@ -407,6 +417,7 @@
        train_model_config=train_model_config,
        grpo_config=grpo_config,
        agentic_config=agentic_config,
+        tokenizer_config=tokenizer_config,
        plugin_config={
            "tp_size": args.tensor_parallel_size,
            "pp_size": args.pipeline_parallel_size,
@@ -434,7 +445,7 @@
        eval_interval=args.eval_interval,
        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
        eval_generation_config=eval_generation_config,
-        log_rollout_interval=20,
+        log_rollout_interval=1,
        rollout_save_dir=args.rollout_save_dir,
        enable_profiling=args.enable_profiling,
        n_behind=args.n_behind,