
Commit dc2153a

remove debug lines and fix pre-commit error
1 parent f21184c commit dc2153a

6 files changed: +31 lines, -42 lines

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 20 additions & 4 deletions
@@ -19,6 +19,7 @@ def __init__(
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
+        clip_ratio_c: Optional[float] = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
     ) -> None:
         super().__init__(backend=backend)
@@ -30,8 +31,13 @@ def __init__(
             self.clip_range_high = clip_range
         else:
             self.clip_range_high = clip_range_high
+        self.clip_ratio_c = clip_ratio_c
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
+        assert self.clip_ratio_c is not None and self.clip_ratio_c > 1.0, (
+            "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+            f" but got the value: {clip_ratio_c}."
+        )
         self.loss_agg_mode = loss_agg_mode

     def __call__(  # type: ignore
@@ -43,28 +49,38 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
+        negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)

-        pg_losses = -advantages * ratio
+        pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
         )
+        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)

-        pg_loss = masked_loss(
-            torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
+        pg_losses3 = -advantages * self.clip_ratio_c
+        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+        pg_clipfrac_lower = masked_mean(
+            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+
+        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+        pg_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+
         metrics = {
             "pg_clipfrac": pg_clipfrac.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
+            "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
         }
         return pg_loss, metrics

     @classmethod
     def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
+            "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
         }
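For reference, the change above extends the standard PPO clipped objective into its dual-clip variant: for tokens with negative advantage, clip_ratio_c caps the per-token loss so a very large importance ratio cannot dominate the update. Below is a minimal, self-contained sketch of the same computation in plain PyTorch; the function name and the local masked_mean helper are illustrative stand-ins for the repository's masked_mean/masked_loss utilities, not the actual implementation.

import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean of x over positions where mask == 1 (stand-in for the repo's helper).
    mask = mask.to(x.dtype)
    return (x * mask).sum() / mask.sum().clamp(min=1.0)


def dual_clip_ppo_loss(
    logprob: torch.Tensor,      # (batch, seq) log-probs under the current policy
    old_logprob: torch.Tensor,  # (batch, seq) log-probs under the behaviour policy
    advantages: torch.Tensor,   # (batch, seq) per-token advantages
    action_mask: torch.Tensor,  # (batch, seq) 1 for response tokens, 0 for padding
    clip_range_low: float = 0.2,
    clip_range_high: float = 0.2,
    clip_ratio_c: float = 3.0,  # dual-clip bound, must be > 1.0
) -> torch.Tensor:
    # Importance ratio; the log-ratio is clamped for numerical stability,
    # mirroring the clamp added in the diff above.
    log_ratio = torch.clamp(logprob - old_logprob, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio)

    # Standard PPO clipped surrogate (losses are negated objectives).
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)

    # Dual-clip: when the advantage is negative, cap the per-token loss at
    # -advantage * clip_ratio_c so an exploding ratio cannot blow it up.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.minimum(pg_losses3, clip_pg_losses1)

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    return masked_mean(pg_losses, action_mask)


if __name__ == "__main__":
    torch.manual_seed(0)
    shape = (2, 8)
    loss = dual_clip_ppo_loss(
        logprob=torch.randn(shape),
        old_logprob=torch.randn(shape),
        advantages=torch.randn(shape),
        action_mask=torch.ones(shape),
    )
    print(loss.item())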

trinity/common/config.py

Lines changed: 1 addition & 2 deletions
@@ -926,7 +926,6 @@ def _check_trainer_input(self) -> None:
         experience_buffer.batch_size = self.buffer.train_batch_size
         experience_buffer.tokenizer_path = self.model.model_path
         set_if_none(experience_buffer, "ray_namespace", self.ray_namespace)
-        # TODO: this cannot apply chat_template_path, as check_model is later than this line
         set_if_none(experience_buffer.format, "chat_template", self.model.custom_chat_template)
         for aux_name, aux_buffer in trainer_input.auxiliary_buffers.items():
             aux_buffer.batch_size = self.buffer.train_batch_size
@@ -1069,7 +1068,7 @@ def _check_model(self) -> None:
             model.critic_model_path = model.model_path

         # check template
-        if model.chat_template_path:
+        if model.chat_template_path and model.custom_chat_template is None:
             with open(model.chat_template_path, "r") as f:
                 model.custom_chat_template = f.read()
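The one-line guard above makes an explicitly configured custom_chat_template take precedence over chat_template_path. A small illustrative sketch of that precedence; SimpleNamespace and check_template are hypothetical stand-ins for the real config object and check logic, not the repository's API.

from types import SimpleNamespace


def check_template(model):
    # Read the template file only when no template was set explicitly.
    if model.chat_template_path and model.custom_chat_template is None:
        with open(model.chat_template_path, "r") as f:
            model.custom_chat_template = f.read()


# An explicitly set template is kept even if a template path is also given.
model = SimpleNamespace(
    chat_template_path="my_template.jinja",  # hypothetical path, never opened here
    custom_chat_template="{{ messages }}",
)
check_template(model)
assert model.custom_chat_template == "{{ messages }}"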

trinity/common/models/vllm_model.py

Lines changed: 1 addition & 6 deletions
@@ -370,19 +370,14 @@ async def convert_messages_to_experience(
             enable_thinking=self.enable_thinking,
         )  # (seq_length, ), (seq_length, )
         logprobs = await self.logprobs(token_ids=token_ids.tolist())  # (seq_length - 1,)
-        exp = Experience(
+        return Experience(
             tokens=token_ids,
             logprobs=logprobs[prompt_length - 1 :],
             prompt_length=prompt_length,
             action_mask=action_mask[prompt_length:],  # Exclude the prompt
             prompt_text=self.tokenizer.decode(token_ids[:prompt_length]),
             response_text=self.tokenizer.decode(token_ids[prompt_length:]),
         )
-        import torch
-        torch.set_printoptions(threshold=float('inf'))
-        print(f"!!!Debug: {exp.tokens=} {exp.logprobs=} {exp.prompt_length=} {exp.action_mask=} {exp.prompt_text=} {exp.response_text=}")
-        print("sum(action_mask): ", torch.sum(exp.action_mask))
-        return exp

     async def shutdown(self):
         """Shutdown the vLLM v1 engine. This kills child processes forked

trinity/common/workflows/envs/frozen_lake/workflow.py

Lines changed: 9 additions & 21 deletions
@@ -100,15 +100,13 @@ def __init__(
         self.use_multistep_prompt = workflow_args.get("use_multistep_prompt", False)
         self.desc = workflow_args.get("desc", None)
         self.is_slippery = workflow_args.get("is_slippery", False)
-        print(f"{self.rollout_args =}")
         self.max_response_tokens = self.rollout_args.get("max_response_tokens", 10240)

         # Extract task-specific arguments
         self.raw_task = task.raw_task if hasattr(task, "raw_task") else {}
         self.size = self.raw_task.get("size", 1)
         self.p = self.raw_task.get("p", 0.8)
         self.seed = self.raw_task.get("seed", 42)
-        print("self.size: ", self.size, "self.p: ", self.p, "self.seed: ", self.seed)

         if self.desc is None:
             random_map, goal_position = generate_random_map(
@@ -241,11 +239,17 @@ def render(self, mode="tiny_rgb_array"):
         room_state = self.render(mode="state").tolist()

         if mode == "list":
-            lookup = lambda cell: GRID_LOOKUP.get(cell, "?").strip("\t").strip()
+
+            def lookup(cell):
+                return GRID_LOOKUP.get(cell, "?").strip("\t").strip()
+
             return [" ".join(lookup(cell) for cell in row) for row in room_state]

         if mode == "tiny_rgb_array":
-            lookup = lambda cell: GRID_LOOKUP.get(cell, "?")
+
+            def lookup(cell):
+                return GRID_LOOKUP.get(cell, "?")
+
             result = "\n".join("".join(lookup(cell) for cell in row) for row in room_state)
             return result

@@ -271,7 +275,6 @@ async def run_async(self) -> List[Experience]:

         # Run episode until done or max_steps reached
         for step in range(self.max_steps):
-            print("Current step: ", step)
             # Format observation for the model
             current_obs_str = str(self.current_observation)
             user_prompt_content = (
@@ -301,11 +304,9 @@
             else:
                 response_token_len = messages_token_len - init_prompt_token_len
             max_tokens = self.max_response_tokens - response_token_len
-            print(
-                f"!!!Debug: {max_tokens=} used_response_tokens = {self.max_response_tokens-max_tokens} {messages_token_len=} {init_prompt_token_len=}"
-            )

             if max_tokens <= 0:
+                # messages = messages[:-1]  # TODO: apply this?
                 self.done = False
                 self.final_reward = 0
                 break
@@ -314,13 +315,9 @@
             rollout_args = self.rollout_args.copy()
             rollout_args["n"] = 1
             rollout_args["max_tokens"] = max_tokens
-            # print("Current step: ", step, rollout_args)
             responses = await self.model.chat_async(messages, **rollout_args)
             response_text = responses[0].response_text
             messages.append({"role": "assistant", "content": response_text})
-            print(
-                "raw response: ", response_text
-            )  # sometimes has <think></think> and <action>, somtimes not

             # Parse action from response
             _, action_str = self._parse_model_response(response_text)
@@ -349,15 +346,6 @@
                 "success": 1 if self.final_reward == 1.0 else 0,
             },
         )
-        print("\n\n\n")
-        print("full messages: ", messages)
-        # print("experience.tokens: ", len(experience.tokens))
-        # print("experience.logprobs: ", len(experience.logprobs))
-        # print("experience.action_mask: ", len(experience.action_mask))
-        # print("experience.prompt_length: ", experience.prompt_length)
-        # print("experience.reward: ", experience.reward)
-        # print("experience.prompt_text: ", experience.prompt_text)
-        # print("experience.response_text: ", experience.response_text, "\n\n\n")
         return [experience]

     def _parse_model_response(self, response: str) -> tuple[str, str]:
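The lambda-to-def rewrites in the render hunk above are what resolve the pre-commit failure mentioned in the commit message: assigning a lambda to a name is flagged by common lint hooks (e.g. flake8's E731). A minimal sketch of the same pattern, with a hypothetical GRID_LOOKUP mapping:

GRID_LOOKUP = {0: "_", 1: "#", 2: "P", 3: "G"}  # hypothetical cell-to-glyph mapping

# Flagged form:
#   lookup = lambda cell: GRID_LOOKUP.get(cell, "?")

# Preferred form: a named function behaves identically and satisfies the linter.
def lookup(cell):
    return GRID_LOOKUP.get(cell, "?")

row = [0, 1, 2, 3]
print(" ".join(lookup(cell) for cell in row))  # prints: _ # P G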

trinity/common/workflows/step_wise_workflow.py

Lines changed: 0 additions & 4 deletions
@@ -3,9 +3,6 @@
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
 from trinity.common.workflows.workflow import Task, Workflow
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)  # TODO: delete this after debugging


 class StepWiseRewardWorkflow(Workflow):
@@ -149,7 +146,6 @@ def run(self) -> list[Experience]:
             experiences.extend(exps)
             if not continue_run:
                 break
-        logger.info(f"Experiences[0]: {experiences[0].response_text}")
         reward = self.reward(experiences)
         for exp in experiences:
             exp.reward = reward

trinity/common/workflows/workflow.py

Lines changed: 0 additions & 5 deletions
@@ -168,11 +168,6 @@ def set_repeat_times(self, repeat_times, run_id_base):
     def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
         converted_experience = self.model.convert_messages_to_experience(messages)

-        if converted_experience.info.get("is_truncated", False):
-            print(f"!!!Debug: a truncation experience with reward {reward} is generated")
-            # TODO: handle this case
-            reward = 0
-
         tokens = converted_experience.tokens
         log_probs = converted_experience.logprobs
         assert converted_experience.action_mask is not None
