
Commit d47c563

fix rollout, action mask, attention mask bugs
1 parent b6391bd commit d47c563

File tree

10 files changed, +80 -45 lines changed


applications/ColossalChat/coati/distributed/agent/agentic_producer.py

Lines changed: 1 addition & 8 deletions
@@ -6,7 +6,6 @@
 
 import ray
 from coati.distributed.agent.base import BaseAgenticProducer
-from transformers import AutoTokenizer
 
 DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>."""
 
@@ -88,13 +87,6 @@ def __init__(
         self.tool_workers = tool_workers
         self.agentic_config = model_config if not agentic_config else agentic_config
         self.agentic_config.update({"model": model_config["path"]})
-        tokenizer_path = None
-        if tokenizer_config and "path" in tokenizer_config:
-            tokenizer_path = tokenizer_config["path"]
-        elif "path" in model_config:
-            tokenizer_path = model_config["path"]
-        assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config."
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
         self.tools_schema = []
         self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
         self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
@@ -258,6 +250,7 @@ def _run_agentic_pipeline(self, messages):
                 )
             )
             llm_call_count += 1
+            self.consumer_global_step = response.pop("consumer_global_step")
             response_input_ids = response["input_ids"]
             logprobs = response["action_log_probs"]
             response_text = self.tokenizer.decode(

applications/ColossalChat/coati/distributed/agent/base.py

Lines changed: 7 additions & 8 deletions
@@ -135,15 +135,13 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
             input_ids = torch.nn.functional.pad(
                 response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
             )  # [1, max_length]
-            attention_mask = torch.nn.functional.pad(
-                torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
-            )  # [1, max_length]
-            action_mask = torch.nn.functional.pad(
-                torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
-            )  # [1, max_length-prompt_length]
+            attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int()  # [1, max_length]
+            action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int()
             rollouts["attention_mask"].append(attention_mask)
             rollouts["action_mask"].append(action_mask)
-            truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]]
+            truncated_logprobs = logprobs[
+                :, :, prompt_length : prompt_length + self.generate_config["max_tokens"]
+            ]  # truncate to max_new_tokens
             logprobs_padded = torch.nn.functional.pad(
                 truncated_logprobs,
                 (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
@@ -177,7 +175,8 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
                         "rollout": self.tokenizer.batch_decode(
                             rollouts["input_ids"][:, 0], skip_special_tokens=True
                         ),
-                    }
+                    },
+                    ensure_ascii=False,
                 )
                 + "\n"
             )
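The mask change above derives both masks directly from the padded sequence instead of padding separate ones-tensors, so the masks can never drift out of sync with input_ids. A minimal standalone sketch of the idea (pad_token_id, max_prompt_length and the example lengths are illustrative; it assumes the pad token never occurs as a real token):

import torch

pad_token_id = 0
max_prompt_length = 4  # prompts are left-padded to this length
max_length = 10        # full sequences are right-padded to this length

# prompt (3 tokens) followed by generated tokens (4 tokens)
response_input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24]])
prompt_length = 3

to_pad_left = max_prompt_length - prompt_length
to_pad_right = max_length - response_input_ids.size(1) - to_pad_left

input_ids = torch.nn.functional.pad(
    response_input_ids, (to_pad_left, to_pad_right), "constant", value=pad_token_id
)  # [1, max_length]

# Old approach: pad ones-tensors with the same offsets; any mismatch between those
# offsets and the ones used for input_ids silently desynchronizes the masks.
# New approach: read the masks off the padded ids themselves.
attention_mask = input_ids.ne(pad_token_id).int()                      # [1, max_length]
action_mask = input_ids[:, max_prompt_length:].ne(pad_token_id).int()  # [1, max_length - max_prompt_length]

print(input_ids)       # tensor([[ 0, 11, 12, 13, 21, 22, 23, 24,  0,  0]])
print(attention_mask)  # tensor([[0, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
print(action_mask)     # tensor([[1, 1, 1, 1, 0, 0]])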

applications/ColossalChat/coati/distributed/agent/math_tools.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def run_python_code(code: str) -> str:
         code = code.replace("```python", "```", 1).strip()
     if code.startswith("```py"):  # qwen3 uses ```py
         code = code.replace("```py", "```", 1).strip()
-    return python_repl.run(code, timeout=20)
+    return python_repl.run(code, timeout=30)
 
 
 repl_tool = Tool(

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 0 additions & 1 deletion
@@ -325,7 +325,6 @@ def loop(self) -> None:
                     )  # for setting start index when resuming training
                     if self.rank == 0:
                         print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
-                # breakpoint()
                 if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                     episode != 0 or step >= self.n_behind
                 ):

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ async def generate(
                     log_probs[generation_id].extend(p)
         self.profiler.exit(f"vllm generate {request_id}")
         # pad them
-        max_len = self.sample_params.max_tokens
+        max_len = sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
         for i, new_token_ids in enumerate(out_tokens):
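The one-line fix above pads generations using the sampling parameters of the current request (sample_params) rather than the instance-wide self.sample_params, which matters when a call (for example evaluation) passes a different max_tokens. A standalone sketch of the padding and action-mask construction that follows it (token values, max_len and pad_token_id are illustrative):

import torch

pad_token_id = 0
max_len = 8  # stands in for sample_params.max_tokens of the current request
out_tokens = [[5, 6, 7], [5, 6, 7, 8, 9]]  # variable-length generations

# every generation is padded to max_len; positions past the real tokens are masked out
action_mask = torch.ones(len(out_tokens), max_len, dtype=torch.int64)
padded_tokens = torch.full((len(out_tokens), max_len), pad_token_id, dtype=torch.int64)
for i, new_token_ids in enumerate(out_tokens):
    padded_tokens[i, : len(new_token_ids)] = torch.tensor(new_token_ids)
    action_mask[i, len(new_token_ids) :] = 0

print(padded_tokens)
print(action_mask)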

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
-    log_rollout_interval: int = 20,
+    log_rollout_interval: int = 1,
     rollout_save_dir: str = "./rollout",
     enable_profiling: bool = False,
     n_behind: int = 0,

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 36 additions & 21 deletions
@@ -93,7 +93,14 @@ def __init__(
             reward_model_kwargs = {
                 k: v
                 for k, v in grpo_config.items()
-                if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
+                if k
+                in [
+                    "soft_over_length_punishment",
+                    "max_new_tokens",
+                    "cache_length",
+                    "code_verifier_api_url",
+                    "forced_patterns",
+                ]
             }
             self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0 and rollout_log_file is not None:
@@ -103,7 +110,7 @@ def __init__(
                 )
             else:
                 os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
-                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
+                self.rollout_log_file = open(rollout_log_file, "a", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -260,6 +267,9 @@ def sync_model(self, episode, step) -> None:
            state_dict = ray_broadcast_tensor_dict(
                None, self.num_producers, device=self.device, group_name="sync_model"
            )
+            print(
+                f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1} done"
+            )
            if "consumer_global_step" in state_dict:
                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
            self.load_state_dict(state_dict)
@@ -498,7 +508,8 @@ def rollout(self, input_ids, attention_mask, **kwargs):
                        "rollout": self.tokenizer.batch_decode(
                            rollouts["input_ids"][:, 0], skip_special_tokens=True
                        ),
-                    }
+                    },
+                    ensure_ascii=False,
                )
                + "\n"
            )
@@ -583,8 +594,10 @@ def __init__(
        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
        self.eval_generation_config.update(eval_generation_config)
        self.eval_sample_params = SamplingParams(**self.eval_generation_config)
-        self.ready_processes = 0
-        self.condition = asyncio.Condition()
+        self.ready_processes_sync_model = 0
+        self.ready_processes_sync_data = 0
+        self.sync_model_condition = asyncio.Condition()
+        self.sync_data_condition = asyncio.Condition()
        self.data_ready_for_sending = []
 
    @torch.no_grad()
@@ -613,6 +626,7 @@ async def generate(self, input_ids, attention_mask, **kwargs):
            ).cpu()  # CUDA tensor is not serializable by ray
            for k in rollouts[0].keys()
        }
+        rollouts["consumer_global_step"] = self.consumer_global_step
        return rollouts
 
    @torch.no_grad()
@@ -634,33 +648,33 @@ async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
        Asyncronous version to sync model from consumer to producer.
        called by another producer, such as agentic producer.
        """
-        async with self.condition:
-            self.ready_processes += 1
+        async with self.sync_model_condition:
+            self.ready_processes_sync_model += 1
            # Wait until all processes are ready
-            if self.ready_processes < num_processes:
-                await self.condition.wait()
+            if self.ready_processes_sync_model < num_processes:
+                await self.sync_model_condition.wait()
 
-            # Only one process should reset `ready_processes` and perform the sync
-            if self.ready_processes == num_processes:
-                self.ready_processes = 0
-                self.condition.notify_all()  # Notify all waiting processes
+            # Only one process should reset `ready_processes_sync_model` and perform the sync
+            if self.ready_processes_sync_model == num_processes:
+                self.ready_processes_sync_model = 0
+                self.sync_model_condition.notify_all()  # Notify all waiting processes
                self.sync_model(episode, step)
 
    async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None:
        # merge data dict
-        async with self.condition:
-            self.ready_processes += 1
+        async with self.sync_data_condition:
+            self.ready_processes_sync_data += 1
            if data:
                self.data_ready_for_sending.append(data)
 
            # Wait until all processes are ready
-            if self.ready_processes < num_processes:
-                await self.condition.wait()
+            if self.ready_processes_sync_data < num_processes:
+                await self.sync_data_condition.wait()
 
            # Only one process should reset `ready_processes` and perform the sync
-            if self.ready_processes == num_processes:  # wait for all producers to join
-                self.ready_processes = 0
-                self.condition.notify_all()
+            if self.ready_processes_sync_data == num_processes:  # wait for all producers to join
+                self.ready_processes_sync_data = 0
+                self.sync_data_condition.notify_all()
                # merge data for sending
                if len(self.data_ready_for_sending) >= 1:
                    batch_rollout_data = {}
@@ -856,7 +870,8 @@ async def rollout(self, input_ids, attention_mask, **kwargs):
                        "rollout": self.tokenizer.batch_decode(
                            rollouts["input_ids"][:, 0], skip_special_tokens=True
                        ),
-                    }
+                    },
+                    ensure_ascii=False,
                )
                + "\n"
            )
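The async_sync_model / async_sync_data changes replace the single shared counter and asyncio.Condition with a dedicated pair per rendezvous, so a model-sync barrier and a data-sync barrier running concurrently can no longer count each other's arrivals and release early. A self-contained sketch of the barrier pattern with two independent conditions (the class and names are illustrative, not the repository API):

import asyncio

class TwoBarriers:
    def __init__(self):
        self.ready_sync_model = 0
        self.ready_sync_data = 0
        self.sync_model_condition = asyncio.Condition()
        self.sync_data_condition = asyncio.Condition()

    async def sync_model(self, num_processes: int):
        async with self.sync_model_condition:
            self.ready_sync_model += 1
            # wait until every participant of *this* barrier has arrived
            if self.ready_sync_model < num_processes:
                await self.sync_model_condition.wait()
            # the last arrival resets the counter, wakes the others and does the work
            if self.ready_sync_model == num_processes:
                self.ready_sync_model = 0
                self.sync_model_condition.notify_all()
                # ... perform the actual model sync here ...

    async def sync_data(self, num_processes: int):
        async with self.sync_data_condition:
            self.ready_sync_data += 1
            if self.ready_sync_data < num_processes:
                await self.sync_data_condition.wait()
            if self.ready_sync_data == num_processes:
                self.ready_sync_data = 0
                self.sync_data_condition.notify_all()
                # ... merge and send the collected data here ...

async def main():
    b = TwoBarriers()
    # with one shared counter, these four concurrent calls could release each other too early
    await asyncio.gather(b.sync_model(2), b.sync_model(2), b.sync_data(2), b.sync_data(2))

asyncio.run(main())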

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 13 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 
 import json
+import re
 
 import torch
 from latex2sympy2_extended import NormalizationConfig
@@ -126,6 +127,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
 
+    if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+        forced_patterns = kwargs["forced_patterns"]
+        format_valid = format_valid and all(
+            [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+        )
+
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if final_answer is not None:
         if eval_mode or format_valid:
@@ -161,7 +168,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
     eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
-    acc_score = 10.0
+    acc_score = 1.0
     reward = torch.tensor(0.0)
     format_acc = torch.tensor(0.0)
     ans_acc = torch.tensor(0.0)
@@ -182,15 +189,18 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    # print(f"decoded_final_answer: {decoded_final_answer[-100:]}", gt_answer)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
     if "tags" in kwargs and kwargs["tags"]:
         tags = kwargs["tags"]
         format_valid = format_valid and all(
             [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
         )
-
+    if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+        forced_patterns = kwargs["forced_patterns"]
+        format_valid = format_valid and all(
+            [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+        )
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if final_answer is not None:
         if eval_mode or format_valid:
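The forced_patterns check added in both reward functions makes format validity additionally require that every configured regular expression matches the decoded response. A standalone illustration using the pattern registered in rl_example.py below (the sample response text is made up):

import re

decoded_final_answer = (
    "<tool_call>{\"name\": \"python_repl\"}</tool_call>\n"
    "<tool_response>\n42\n</tool_response>\n"
    "The answer is \\boxed{42}."
)
forced_patterns = [r"<tool_response>\n.+\n</tool_response>"]  # require at least one tool response

format_valid = True  # outcome of the existing tag / boxed-answer checks
format_valid = format_valid and all(
    re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns
)
print(format_valid)  # True only if every forced pattern appears in the response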
Lines changed: 8 additions & 0 deletions (new file)
@@ -0,0 +1,8 @@
{
"chat_template": "{%- if tools %}\\n {{- \'<|im_start|>system\\\\n\' }}\\n {%- if messages[0].role == \'system\' %}\\n {{- messages[0].content + \'\\\\n\\\\n\' }}\\n {%- endif %}\\n {{- \\"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\" }}\\n {%- for tool in tools %}\\n {{- \\"\\\\n\\" }}\\n {{- tool | tojson }}\\n {%- endfor %}\\n {{- \\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\"name\\\\\\": <function-name>, \\\\\\"arguments\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\" }}\\n{%- else %}\\n {%- if messages[0].role == \'system\' %}\\n {{- \'<|im_start|>system\\\\n\' + messages[0].content + \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n{%- endif %}\\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\\n{%- for message in messages[::-1] %}\\n {%- set index = (messages|length - 1) - loop.index0 %}\\n {%- if ns.multi_step_tool and message.role == \\"user\\" and message.content is string and not(message.content.startswith(\'<tool_response>\') and message.content.endswith(\'</tool_response>\')) %}\\n {%- set ns.multi_step_tool = false %}\\n {%- set ns.last_query_index = index %}\\n {%- endif %}\\n{%- endfor %}\\n{%- for message in messages %}\\n {%- if message.content is string %}\\n {%- set content = message.content %}\\n {%- else %}\\n {%- set content = \'\' %}\\n {%- endif %}\\n {%- if (message.role == \\"user\\") or (message.role == \\"system\\" and not loop.first) %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content + \'<|im_end|>\' + \'\\\\n\' }}\\n {%- elif message.role == \\"assistant\\" %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content }}\\n {%- if message.tool_calls %}\\n {%- for tool_call in message.tool_calls %}\\n {%- if (loop.first and content) or (not loop.first) %}\\n {{- \'\\\\n\' }}\\n {%- endif %}\\n {%- if tool_call.function %}\\n {%- set tool_call = tool_call.function %}\\n {%- endif %}\\n {{- \'<tool_call>\\\\n{\\"name\\": \\"\' }}\\n {{- tool_call.name }}\\n {{- \'\\", \\"arguments\\": \' }}\\n {%- if tool_call.arguments is string %}\\n {{- tool_call.arguments }}\\n {%- else %}\\n {{- tool_call.arguments | tojson }}\\n {%- endif %}\\n {{- \'}\\\\n</tool_call>\' }}\\n {%- endfor %}\\n {%- endif %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- elif message.role == \\"tool\\" %}\\n {%- if loop.first or (messages[loop.index0 - 1].role != \\"tool\\") %}\\n {{- \'<|im_start|>user\' }}\\n {%- endif %}\\n {{- \'\\\\n<tool_response>\\\\n\' }}\\n {{- content }}\\n {{- \'\\\\n</tool_response>\' }}\\n {%- if loop.last or (messages[loop.index0 + 1].role != \\"tool\\") %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n {{- \'<|im_start|>assistant\\\\n\' }}\\n {%- if enable_thinking is defined and enable_thinking is false %}\\n {{- \'<think>\\\\n\\\\n</think>\\\\n\\\\n\' }}\\n {%- endif %}\\n{%- endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
    7
],
"end_of_assistant": "<|im_end|>"
}

applications/ColossalChat/rl_example.py

Lines changed: 12 additions & 1 deletion
@@ -131,6 +131,7 @@
        default=1.0,
        help="Top p for sampling. Please check the generation arguments documentation for your backend.",
    )
+    parser.add_argument("-ct", "--chat-template", type=str, default=None, help="Chat template to use for the model.")
    parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
    parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
    parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
@@ -427,11 +428,20 @@
                "llm_call_budget": 10,
                "max_tokens": 2048,
            }
+            grpo_config["forced_patterns"] = [
+                r"<tool_response>\n.+\n</tool_response>"
+            ]  # force at least one correct tool call
        else:
            raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
    else:
        agentic_config = None
 
+    tokenizer_config = {
+        "path": args.model,
+        "trust_remote_code": True,
+        "chat_template": args.chat_template,
+    }
+
    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size)
@@ -453,6 +463,7 @@
        train_model_config=train_model_config,
        grpo_config=grpo_config,
        agentic_config=agentic_config,
+        tokenizer_config=tokenizer_config,
        plugin_config={
            "tp_size": args.tensor_parallel_size,
            "pp_size": args.pipeline_parallel_size,
@@ -480,7 +491,7 @@
        eval_interval=args.eval_interval,
        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
        eval_generation_config=eval_generation_config,
-        log_rollout_interval=20,
+        log_rollout_interval=1,
        rollout_save_dir=args.rollout_save_dir,
        enable_profiling=args.enable_profiling,
        n_behind=args.n_behind,
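rl_example.py now assembles a tokenizer_config (model path, trust_remote_code, optional chat template file) and passes it to launch_distributed instead of having the agentic producer build its own tokenizer from AutoTokenizer. How the config is consumed is not shown in this diff, so the following is only one plausible reading, under the assumption that the chat-template file is the JSON added in this commit and that its "chat_template" field overrides tokenizer.chat_template:

import json

from transformers import AutoTokenizer

# illustrative stand-ins for args.model and args.chat_template
tokenizer_config = {
    "path": "Qwen/Qwen2.5-7B-Instruct",
    "trust_remote_code": True,
    "chat_template": "path/to/chat_template.json",  # may be None if no template is given
}

tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_config["path"], trust_remote_code=tokenizer_config["trust_remote_code"]
)
if tokenizer_config.get("chat_template"):
    with open(tokenizer_config["chat_template"], encoding="utf8") as f:
        # assumption: the file has the same layout as the template JSON added above
        tokenizer.chat_template = json.load(f)["chat_template"]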
