Commit 8fa6f76

Bug Config Manager and Add more logger info (#106)
1 parent 339d658 commit 8fa6f76

8 files changed: +28 -21 lines changed

trinity/common/verl_config.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ class ActorModel:
     override_config: Dict[str, Any] = field(default_factory=dict)
     enable_gradient_checkpointing: bool = True
     use_remove_padding: bool = False
+    use_fused_kernels: bool = False


 @dataclass

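The only change here is a new ActorModel field. A minimal sketch of what downstream configs see (the dataclass below is trimmed to the fields visible in this hunk, not the full verl_config.py):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ActorModel:
    # Trimmed reproduction of the fields shown in the hunk above.
    override_config: Dict[str, Any] = field(default_factory=dict)
    enable_gradient_checkpointing: bool = True
    use_remove_padding: bool = False
    use_fused_kernels: bool = False  # new flag in this commit, off by default

# Existing configs that never mention the flag keep their old behavior.
model = ActorModel()
assert model.use_fused_kernels is False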
trinity/common/workflows/envs/alfworld/alfworld_workflow.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task

 EXAMPLE_PROMPT = """
-Observation
+Observation:
 -= Welcome to TextWorld, ALFRED! =-

 You are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 2, a handtowelholder 1, a sinkbasin 2, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.
@@ -88,7 +88,7 @@ def parse_action(response):
         action = response.split("<action>")[1].split("</action>")[0].strip()
         return action
     except Exception as e:
-        print("Error parsing action:", e)
+        print(f"Error parsing action: {e}, response = {response}")
         return ""


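The richer error message matters because parse_action swallows exceptions for malformed responses. A small sketch of the parsing behavior this hunk touches, reconstructed only from the visible lines (the surrounding workflow file is not shown, so treat the function body as illustrative):

def parse_action(response):
    # Extract the text between <action> ... </action>; return "" on any failure.
    try:
        action = response.split("<action>")[1].split("</action>")[0].strip()
        return action
    except Exception as e:
        print(f"Error parsing action: {e}, response = {response}")
        return ""

assert parse_action("I will <action>go to cabinet 1</action> now.") == "go to cabinet 1"
assert parse_action("no action tags here") == ""  # prints the error along with the full response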
trinity/explorer/explorer.py

Lines changed: 1 addition & 0 deletions
@@ -182,6 +182,7 @@ async def explore(self) -> str:
         self.eval_explore_step_num = None
         while True:
             try:
+                self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
                 if (
                     self.eval_explore_step_num is None
                     and self.explore_step_num % self.config.explorer.eval_interval == 0

trinity/manager/config_manager.py

Lines changed: 4 additions & 5 deletions
@@ -64,7 +64,7 @@ def maintain_session_state(self):

        def maintain_list_state(prefix, key_list):
            last_idx, del_num = 0, 0
-           for idx in range(st.session_state[f"_{prefix}_num"]):
+           for idx in range(st.session_state[f"_{prefix}s_num"]):
                if st.session_state.get(f"{prefix}_{idx}_del_flag", False):
                    del_num += 1
                    continue
@@ -73,7 +73,7 @@ def maintain_list_state(prefix, key_list):
                    last_full_key = f"{prefix}_{last_idx}_{key}"
                    st.session_state[last_full_key] = st.session_state[full_key]
                last_idx += 1
-           st.session_state[f"_{prefix}_num"] -= del_num
+           st.session_state[f"_{prefix}s_num"] -= del_num

        self.eval_dataset_keys = [
            "name",
@@ -86,7 +86,7 @@ def maintain_list_state(prefix, key_list):
            "logprobs",
            "n",
        ]
-       maintain_list_state("eval_tasksets", self.eval_dataset_keys)
+       maintain_list_state("eval_taskset", self.eval_dataset_keys)

        self.inference_model_keys = [
            "model_path",
@@ -103,7 +103,7 @@ def maintain_list_state(prefix, key_list):
            "enable_thinking",
            "enable_openai_api",
        ]
-       maintain_list_state("auxiliary_models", self.inference_model_keys)
+       maintain_list_state("auxiliary_model", self.inference_model_keys)

    def get_configs(self, *config_names: str, columns_spec: List[int] = None):
        CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
@@ -356,7 +356,6 @@ def _generate_verl_config(self):
                ],
                "use_dynamic_bsz": use_dynamic_bsz,
                "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
-               "kl_loss_type": st.session_state["actor_kl_loss_type"],
                "ppo_epochs": st.session_state["ppo_epochs"],
                "shuffle": False,
                "ulysses_sequence_parallel_size": st.session_state[

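The fix aligns the two key conventions: callers now pass the singular prefix (e.g. "eval_taskset"), per-item keys use that prefix directly, and the counter lives under the plural key f"_{prefix}s_num". A minimal, framework-free sketch of that convention, using a plain dict in place of st.session_state (the example keys and values are invented for illustration, not the UI's actual state):

def maintain_list_state(state: dict, prefix: str, key_list: list) -> None:
    # Compact a numbered list stored in flat keys:
    # counter under the plural key "_{prefix}s_num", items under "{prefix}_{idx}_{key}".
    last_idx, del_num = 0, 0
    for idx in range(state[f"_{prefix}s_num"]):
        if state.get(f"{prefix}_{idx}_del_flag", False):
            del_num += 1          # entry marked for deletion, skip it
            continue
        for key in key_list:      # shift surviving entries down to fill the gap
            state[f"{prefix}_{last_idx}_{key}"] = state[f"{prefix}_{idx}_{key}"]
        last_idx += 1
    state[f"_{prefix}s_num"] -= del_num

# Two eval tasksets, the first flagged for deletion.
state = {
    "_eval_tasksets_num": 2,
    "eval_taskset_0_name": "taskset_a", "eval_taskset_0_del_flag": True,
    "eval_taskset_1_name": "taskset_b",
}
maintain_list_state(state, "eval_taskset", ["name"])
assert state["_eval_tasksets_num"] == 1
assert state["eval_taskset_0_name"] == "taskset_b"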
trinity/manager/config_registry/trainer_config_manager.py

Lines changed: 0 additions & 9 deletions
@@ -265,15 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs):
     )


-@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
-def set_actor_kl_loss_type(**kwargs):
-    st.selectbox(
-        "KL Loss Type",
-        ["kl", "abs", "mse", "low_var_kl"],
-        **kwargs,
-    )
-
-
 @CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"])
 def set_actor_checkpoint(**kwargs):
     st.multiselect(

trinity/trainer/trainer.py

Lines changed: 6 additions & 1 deletion
@@ -57,13 +57,18 @@ def need_sync(self) -> bool:
     def sync_weight(self) -> None:
         """Sync the model weight."""
         if self.config.synchronizer.sync_method == SyncMethod.NCCL:
+            self.logger.info(
+                f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.."
+            )
             if self.explorer_ref is None:
                 self.explorer_ref = ray.get_actor(self.config.explorer.name)
             explorer_status = ray.get(self.explorer_ref.running_status.remote())
             if explorer_status == RunningStatus.STOPPED:
                 self.logger.warning("Explorer has already stopped. Skipping sync weight.")
                 return
-        self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
+        self.logger.info(
+            f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
+        )
         self.engine.sync_weight()

     def flush_log(self, step: int) -> None:

trinity/trainer/verl_trainer.py

Lines changed: 13 additions & 4 deletions
@@ -38,6 +38,7 @@
 from trinity.common.config import Config
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
+from trinity.utils.log import get_logger
 from trinity.utils.monitor import MONITOR


@@ -146,13 +147,14 @@ def __init__(
             ray_worker_group_cls,
         )
         self.init_workers()
-        self.logger = MONITOR.get(global_config.monitor.monitor_type)(
+        self.monitor = MONITOR.get(global_config.monitor.monitor_type)(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
             role=global_config.trainer.name,
             config=global_config,
         )
         self.reset_experiences_example_table()
+        self.logger = get_logger(__name__)

     def _validate_config(self): # TODO
         algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
@@ -276,7 +278,7 @@ def prepare(self):
         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
             val_metrics = self._validate()
             pprint(f"Initial validation metrics: {val_metrics}")
-            self.logger.log(data=val_metrics, step=self.global_steps)
+            self.monitor.log(data=val_metrics, step=self.global_steps)
             if self.config.trainer.get("val_only", False):
                 return

@@ -286,6 +288,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize

     def train_step(self) -> bool: # noqa C901
+        self.logger.info(f"Training at step {self.global_steps + 1} started.")
         metrics = {}
         try:
             batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
@@ -294,6 +297,7 @@ def train_step(self) -> bool: # noqa C901
             print("No more data to train. Stop training.")
             return False
         self.global_steps += 1
+        self.logger.info(f"Sampling at step {self.global_steps} done.")
         timing_raw = {}
         algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
         algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
@@ -356,8 +360,10 @@ def train_step(self) -> bool: # noqa C901
             self.config.trainer.save_freq > 0
             and self.global_steps % self.config.trainer.save_freq == 0
         ):
+            self.logger.info(f"Saving at step {self.global_steps}.")
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
+            self.logger.info(f"Saved at step {self.global_steps}.")

         # collect metrics
         if self.algorithm.use_advantage: # TODO
@@ -372,16 +378,19 @@ def train_step(self) -> bool: # noqa C901
             self._log_experiences(exp_samples)

         # TODO: make a canonical logger that supports various backend
-        self.logger.log(data=metrics, step=self.global_steps)
+        self.monitor.log(data=metrics, step=self.global_steps)

         train_status = self.global_steps < self.total_training_steps
         if not train_status or self.algorithm_manager.need_save(self.global_steps):
             if (
                 self.config.trainer.save_freq == 0
                 or self.global_steps % self.config.trainer.save_freq != 0
             ):
+                self.logger.info(f"Saving at step {self.global_steps}.")
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
+                self.logger.info(f"Saved at step {self.global_steps}.")
+        self.logger.info(f"Training at step {self.global_steps} finished.")
         return train_status

     def _log_single_experience(
@@ -412,7 +421,7 @@ def _log_single_experience(
     def _log_experiences(self, samples: List[Dict]) -> None:
         self.sample_exps_to_log.extend(samples)
         if self.global_steps % self.config.trainer.sync_freq == 0:
-            self.logger.log_table(
+            self.monitor.log_table(
                 "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
             )
             self.reset_experiences_example_table()

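The renames in this file separate two concerns: self.monitor (the MONITOR backend) receives metric dictionaries and tables, while self.logger (a console logger from get_logger) carries the new step-progress messages. A schematic sketch of that split, not the actual trainer class; TinyMonitor and TinyTrainer are invented here for illustration:

import logging

logging.basicConfig(level=logging.INFO)

class TinyMonitor:
    # Stand-in for the MONITOR backend: only records metric dicts per step.
    def log(self, data: dict, step: int) -> None:
        print(f"[monitor] step={step} metrics={data}")

class TinyTrainer:
    def __init__(self) -> None:
        self.monitor = TinyMonitor()                # experiment-tracking backend
        self.logger = logging.getLogger(__name__)   # console logger, as with get_logger(__name__)
        self.global_steps = 0

    def train_step(self, metrics: dict) -> bool:
        self.logger.info(f"Training at step {self.global_steps + 1} started.")
        self.global_steps += 1
        # ... sampling and optimization would happen here ...
        self.monitor.log(data=metrics, step=self.global_steps)  # numbers go to the monitor
        self.logger.info(f"Training at step {self.global_steps} finished.")
        return True

TinyTrainer().train_step({"loss": 0.0})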
trinity/utils/monitor.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""
         for key in data:
             self.logger.add_scalar(key, data[key], step)
+        self.console_logger.info(f"Step {step}: {data}")

     def close(self) -> None:
         self.logger.close()

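The one-line addition mirrors every metrics dict to the console in addition to the scalar backend. A small sketch of the resulting behavior, with StubWriter and ConsoleEchoMonitor invented here so the example runs without the real scalar logger:

import logging

logging.basicConfig(level=logging.INFO)

class StubWriter:
    # Stand-in for the scalar backend (a SummaryWriter-like object).
    def add_scalar(self, key: str, value: float, step: int) -> None:
        pass

    def close(self) -> None:
        pass

class ConsoleEchoMonitor:
    def __init__(self) -> None:
        self.logger = StubWriter()
        self.console_logger = logging.getLogger(__name__)

    def log(self, data: dict, step: int) -> None:
        """Log metrics to the scalar backend and echo them to the console."""
        for key in data:
            self.logger.add_scalar(key, data[key], step)
        self.console_logger.info(f"Step {step}: {data}")  # the behavior added by this commit

ConsoleEchoMonitor().log({"reward/mean": 0.5}, step=1)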