Skip to content

Commit e501074

Browse files
committed
Fix bug, update verl
1 parent f683ded commit e501074

File tree

4 files changed

+19
-5
lines changed

4 files changed

+19
-5
lines changed

agents/agents/agents/chain/chain_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def prepare_chain_messages(self, start_messages: Union[List[dict], np.ndarray]):
214214

215215
return messages_list, other_info_list
216216

217-
def validate_run_args(self, max_steps: int, num_chains: int):
217+
def validate_run_args(self, max_steps: int, num_chains: int, enable_streaming: bool):
218218
assert max_steps >= 1, "max_steps must be at least 1."
219219
assert num_chains >= 1, "num_chains must be at least 1."
220220
for observer in self.streaming_manager.observers:
221-
if isinstance(observer, ConsoleStreamObserver):
221+
if isinstance(observer, ConsoleStreamObserver) and enable_streaming:
222222
assert num_chains == 1, "num_chains must be 1 when ConsoleStreamObserver is used."
223223

224224

@@ -241,7 +241,7 @@ async def run_async(self,
241241
enable_streaming: Whether to enable streaming mode.
242242
streaming_callback: Optional callback for streaming events.
243243
"""
244-
self.validate_run_args(max_steps, num_chains)
244+
self.validate_run_args(max_steps, num_chains, enable_streaming)
245245
Monitor.ensure_started()
246246
self.reset()
247247
messages_list, other_info_list = self.prepare_chain_messages(start_messages)

agents/agents/agents/templates/templates.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te
215215
attention_mask = []
216216
labels = []
217217
action_mask = []
218+
219+
if tokenizer.bos_token and tokenizer.add_bos_token:
220+
input_ids.append(tokenizer.bos_token_id)
221+
attention_mask.append(1)
222+
labels.append(-100)
223+
action_mask.append(0)
224+
218225
for element, mask_flag in zip(elements, mask_flags):
219226
cur_input_ids = tokenizer.encode(element, add_special_tokens=False)
220227
input_ids.extend(cur_input_ids)

agents/agents/agents/templates/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,14 @@ def tokenize_conversation(
139139
:return: input_ids, attention_mask, labels, action_mask
140140
"""
141141
chat = Chat(template=template, messages=messages, tokenizer=tokenizer)
142-
return chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools)
142+
inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools)
143+
if max_length is not None:
144+
inputs['input_ids'] = inputs['input_ids'][:, :max_length]
145+
inputs['attention_mask'] = inputs['attention_mask'][:, :max_length]
146+
inputs['labels'] = inputs['labels'][:, :max_length]
147+
inputs['action_mask'] = inputs['action_mask'][:, :max_length]
148+
149+
return inputs
143150

144151
def convert_inputs_to_vision_inputs(template: str,
145152
inputs: dict,

0 commit comments

Comments
 (0)