
Commit 2b46ab1

simplify _run_agentic_pipeline; fix old_log_probs
1 parent d47c563 commit 2b46ab1

4 files changed, +17 -25 lines changed


applications/ColossalChat/coati/distributed/agent/agentic_producer.py

Lines changed: 4 additions & 8 deletions
@@ -224,9 +224,7 @@ def _run_agentic_pipeline(self, messages):
 if llm_call_count > self.llm_call_budget:
     print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.")
     del self.async_llm_engine_map[request_id]
-    while messages[-1]["role"] == "tool":
-        messages.pop()
-    return messages, logprobs
+    return messages, response_input_ids, logprobs
 inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt")
 if num_prompt_tokens == 0:
     num_prompt_tokens = inputs["input_ids"].size(-1)
@@ -235,9 +233,7 @@ def _run_agentic_pipeline(self, messages):
         f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping."
     )
     del self.async_llm_engine_map[request_id]
-    while messages[-1]["role"] == "tool":
-        messages.pop()
-    return messages, logprobs
+    return messages, response_input_ids, logprobs
 async_producer = self._select_async_producer(request_id=request_id)
 agentic_generate_config = copy.deepcopy(self.generate_config)
 agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
@@ -262,7 +258,7 @@ def _run_agentic_pipeline(self, messages):
 if tool_call_count > self.tool_call_budget:
     print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.")
     del self.async_llm_engine_map[request_id]
-    return messages, logprobs
+    return messages, response_input_ids, logprobs
 tool_call_count += len(assistant_message["tool_calls"])
 handlers = []
 for tool_call in assistant_message["tool_calls"]:
@@ -277,4 +273,4 @@ def _run_agentic_pipeline(self, messages):
 else:
     # no further tool call, return the messages
     del self.async_llm_engine_map[request_id]
-    return messages, logprobs
+    return messages, response_input_ids, logprobs
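
Taken together, the hunks above collapse _run_agentic_pipeline's exit paths onto one return shape: every branch now hands back the conversation, its token ids, and the matching log-probabilities, and trailing tool messages are kept rather than popped. A minimal, self-contained sketch of that contract (a toy stand-in, not the repository's class; names, types, and the simplified control flow are illustrative assumptions):

from typing import Any, Dict, List, Tuple

Message = Dict[str, Any]

def run_pipeline_sketch(
    messages: List[Message],
    response_input_ids: Any,  # token ids of the conversation rendered so far
    logprobs: Any,            # per-token log-probabilities aligned with those ids
    llm_call_count: int,
    llm_call_budget: int,
) -> Tuple[List[Message], Any, Any]:
    # Every exit path returns the same triple, so the caller (rollout() in base.py)
    # can rebuild attention/action masks without re-tokenizing the messages.
    if llm_call_count > llm_call_budget:
        # budget exceeded: stop and return whatever has been accumulated
        return messages, response_input_ids, logprobs
    # ... generation and tool-call handling would continue here in the real pipeline ...
    return messages, response_input_ids, logprobs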

applications/ColossalChat/coati/distributed/agent/base.py

Lines changed: 7 additions & 9 deletions
@@ -123,32 +123,30 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
 )

 for i in range(self.num_generations):
-    _messages, logprobs = results[i]
-    response_input_ids = self._build_prompt(
-        _messages, return_dict=True, return_tensors="pt", add_generation_prompt=False
-    )["input_ids"]
+    # due to the multiround feature, action_mask and attention_mask need to be recomputed
+    _messages, response_input_ids, logprobs = results[i]
     # truncate if too long
-    response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
+    response_input_ids = response_input_ids[0, :, : self.grpo_config["max_length"] - to_pad_left]
     # add left right padding
-    to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left
-    response_length = response_input_ids.shape[1] - prompt_length
+    to_pad_right = self.grpo_config["max_length"] - response_input_ids.size(-1) - to_pad_left
     input_ids = torch.nn.functional.pad(
         response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
     ) # [1, max_length]
     attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int() # [1, max_length]
     action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int()
+    response_length = action_mask.sum().item()
     rollouts["attention_mask"].append(attention_mask)
     rollouts["action_mask"].append(action_mask)
     truncated_logprobs = logprobs[
-        :, :, prompt_length : prompt_length + self.generate_config["max_tokens"]
+        0, :, prompt_length : prompt_length + self.generate_config["max_tokens"]
     ] # truncate to max_new_tokens
     logprobs_padded = torch.nn.functional.pad(
         truncated_logprobs,
         (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
         "constant",
         value=0.0,
     ) # [1, max_new_tokens]
-    rollouts["action_log_probs"].append(logprobs_padded[0])
+    rollouts["action_log_probs"].append(logprobs_padded)
     rollouts["response_idx"].append(
         torch.tensor(
             [
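
In the rollout loop above, the masks are now derived from the padded ids rather than carried over from generation, since a multi-round conversation interleaves assistant and tool turns; response_length is likewise counted from the action mask instead of being inferred from tensor shapes. A small runnable sketch of that padding-and-masking step, with toy token ids and made-up values for pad_token_id, max_length, max_prompt_length, and to_pad_left:

import torch

pad_token_id = 0
max_length = 12        # stands in for grpo_config["max_length"]
max_prompt_length = 4  # stands in for the left-padded prompt budget
to_pad_left = 1        # left padding already reserved for the prompt

# pretend the agentic pipeline returned 9 response tokens for this sample
response_input_ids = torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108, 109]])

# truncate if too long, then pad on both sides up to max_length
response_input_ids = response_input_ids[:, : max_length - to_pad_left]
to_pad_right = max_length - response_input_ids.size(-1) - to_pad_left
input_ids = torch.nn.functional.pad(
    response_input_ids, (to_pad_left, to_pad_right), "constant", value=pad_token_id
)  # [1, max_length]

# masks are recomputed from the padded ids, mirroring the diff
attention_mask = input_ids.ne(pad_token_id).int()                      # [1, max_length]
action_mask = input_ids[:, max_prompt_length:].ne(pad_token_id).int()  # [1, max_length - max_prompt_length]
response_length = action_mask.sum().item()  # number of non-pad generated tokens
print(input_ids.shape, attention_mask.sum().item(), response_length)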

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 2 additions & 2 deletions
@@ -37,9 +37,9 @@ def forward(
     total_effective_tokens_in_batch: torch.Tensor = None,
 ) -> torch.Tensor:
     if action_mask is None:
-        ratio = (log_probs - log_probs.detach()).exp()
+        ratio = (log_probs - old_log_probs.detach()).exp()
     else:
-        ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+        ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()

     surr1 = ratio * advantages
     surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
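
The one-line change above is the old_log_probs fix from the commit title: (log_probs - log_probs.detach()).exp() evaluates to exactly 1 for every token, so the importance weight never reflected the gap between the current policy and the policy that produced the rollout, and the clip bounds below could never activate. With old_log_probs it becomes the usual clipped-surrogate ratio. A toy numeric illustration (made-up values; the final -min(surr1, surr2) reduction is the textbook form and may differ from the repository's masked reduction):

import torch

log_probs = torch.tensor([-1.2, -0.7, -2.0], requires_grad=True)  # current policy
old_log_probs = torch.tensor([-1.0, -0.9, -1.5])                  # rollout-time policy
advantages = torch.tensor([1.0, -0.5, 2.0])
clip_eps_low, clip_eps_high = 0.2, 0.2

buggy_ratio = (log_probs - log_probs.detach()).exp()      # always exactly [1., 1., 1.]
fixed_ratio = (log_probs - old_log_probs.detach()).exp()  # true importance ratio

surr1 = fixed_ratio * advantages
surr2 = fixed_ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
loss = -torch.min(surr1, surr2).mean()
print(buggy_ratio.detach(), fixed_ratio.detach(), loss.item())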

applications/ColossalChat/rl_example.py

Lines changed: 4 additions & 6 deletions
@@ -429,18 +429,16 @@
             "max_tokens": 2048,
         }
         grpo_config["forced_patterns"] = [
-            r"<tool_response>\n.+\n</tool_response>"
+            r"<tool_response>\n.+\n</tool_response>" # please modify based on your tool response format
         ] # force at least one correct tool call
     else:
         raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
 else:
     agentic_config = None

-tokenizer_config = {
-    "path": args.model,
-    "trust_remote_code": True,
-    "chat_template": args.chat_template,
-}
+tokenizer_config = {"path": args.model, "trust_remote_code": True}
+if args.chat_template is not None:
+    tokenizer_config["chat_template"] = args.chat_template

 launch_distributed(
     num_producers=args.num_inferencer,
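
The tokenizer_config change only forwards chat_template when one was actually supplied, so an unset CLI argument no longer pushes an explicit None into the config. A tiny illustration of the pattern (the model path and the loader's fallback behaviour are assumptions, not taken from the repository):

args_chat_template = None  # e.g. no chat template was passed on the command line

tokenizer_config = {"path": "path/to/your/model", "trust_remote_code": True}
if args_chat_template is not None:
    tokenizer_config["chat_template"] = args_chat_template

print(tokenizer_config)  # no "chat_template": None entry to shadow the model's default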
