Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ActorModel:
override_config: Dict[str, Any] = field(default_factory=dict)
enable_gradient_checkpointing: bool = True
use_remove_padding: bool = False
use_fused_kernels: bool = False


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/alfworld/alfworld_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task

EXAMPLE_PROMPT = """
Observation
Observation:
-= Welcome to TextWorld, ALFRED! =-

You are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 2, a handtowelholder 1, a sinkbasin 2, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.
Expand Down Expand Up @@ -88,7 +88,7 @@ def parse_action(response):
action = response.split("<action>")[1].split("</action>")[0].strip()
return action
except Exception as e:
print("Error parsing action:", e)
print(f"Error parsing action: {e}, response = {response}")
return ""


Expand Down
1 change: 1 addition & 0 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ async def explore(self) -> str:
self.eval_explore_step_num = None
while True:
try:
self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
if (
self.eval_explore_step_num is None
and self.explore_step_num % self.config.explorer.eval_interval == 0
Expand Down
9 changes: 4 additions & 5 deletions trinity/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def maintain_session_state(self):

def maintain_list_state(prefix, key_list):
last_idx, del_num = 0, 0
for idx in range(st.session_state[f"_{prefix}_num"]):
for idx in range(st.session_state[f"_{prefix}s_num"]):
if st.session_state.get(f"{prefix}_{idx}_del_flag", False):
del_num += 1
continue
Expand All @@ -73,7 +73,7 @@ def maintain_list_state(prefix, key_list):
last_full_key = f"{prefix}_{last_idx}_{key}"
st.session_state[last_full_key] = st.session_state[full_key]
last_idx += 1
st.session_state[f"_{prefix}_num"] -= del_num
st.session_state[f"_{prefix}s_num"] -= del_num

self.eval_dataset_keys = [
"name",
Expand All @@ -86,7 +86,7 @@ def maintain_list_state(prefix, key_list):
"logprobs",
"n",
]
maintain_list_state("eval_tasksets", self.eval_dataset_keys)
maintain_list_state("eval_taskset", self.eval_dataset_keys)

self.inference_model_keys = [
"model_path",
Expand All @@ -103,7 +103,7 @@ def maintain_list_state(prefix, key_list):
"enable_thinking",
"enable_openai_api",
]
maintain_list_state("auxiliary_models", self.inference_model_keys)
maintain_list_state("auxiliary_model", self.inference_model_keys)

def get_configs(self, *config_names: str, columns_spec: List[int] = None):
CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
Expand Down Expand Up @@ -356,7 +356,6 @@ def _generate_verl_config(self):
],
"use_dynamic_bsz": use_dynamic_bsz,
"ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
"kl_loss_type": st.session_state["actor_kl_loss_type"],
"ppo_epochs": st.session_state["ppo_epochs"],
"shuffle": False,
"ulysses_sequence_parallel_size": st.session_state[
Expand Down
9 changes: 0 additions & 9 deletions trinity/manager/config_registry/trainer_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs):
)


@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_actor_kl_loss_type(**kwargs):
    """Render the "KL Loss Type" select box.

    The widget offers the options ``kl``, ``abs``, ``mse`` and
    ``low_var_kl``. The config is registered with ``low_var_kl`` as its
    default value.

    Args:
        **kwargs: Forwarded unchanged to ``st.selectbox``.
    """
    kl_loss_type_options = ["kl", "abs", "mse", "low_var_kl"]
    st.selectbox("KL Loss Type", kl_loss_type_options, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"])
def set_actor_checkpoint(**kwargs):
st.multiselect(
Expand Down
7 changes: 6 additions & 1 deletion trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@ def need_sync(self) -> bool:
# Synchronize model weights through the training engine. When the configured
# sync method is NCCL, the explorer actor is consulted first and the sync is
# skipped if the explorer has already stopped.
# NOTE(review): indentation is not preserved in this view, so whether the final
# self.engine.sync_weight() call sits inside the NCCL branch cannot be
# determined here — confirm against the file.
def sync_weight(self) -> None:
"""Sync the model weight."""
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
self.logger.info(
f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.."
)
# Resolve the explorer actor handle by name once and keep it on the instance.
if self.explorer_ref is None:
self.explorer_ref = ray.get_actor(self.config.explorer.name)
# Ask the explorer actor for its running status and wait for the answer.
explorer_status = ray.get(self.explorer_ref.running_status.remote())
if explorer_status == RunningStatus.STOPPED:
# Return without calling the engine when the explorer reports STOPPED.
self.logger.warning("Explorer has already stopped. Skipping sync weight.")
return
# NOTE(review): this single-line message and the "... end." message right
# after it look like the before/after versions of the same statement in a
# rendered diff — confirm which one the file actually keeps.
self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
# NOTE(review): the "end." message is emitted before self.engine.sync_weight()
# is called below — confirm that this ordering is intended.
self.logger.info(
f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
)
self.engine.sync_weight()

def flush_log(self, step: int) -> None:
Expand Down
17 changes: 13 additions & 4 deletions trinity/trainer/verl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from trinity.common.config import Config
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR


Expand Down Expand Up @@ -146,13 +147,14 @@ def __init__(
ray_worker_group_cls,
)
self.init_workers()
self.logger = MONITOR.get(global_config.monitor.monitor_type)(
self.monitor = MONITOR.get(global_config.monitor.monitor_type)(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
role=global_config.trainer.name,
config=global_config,
)
self.reset_experiences_example_table()
self.logger = get_logger(__name__)

def _validate_config(self): # TODO
algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
Expand Down Expand Up @@ -276,7 +278,7 @@ def prepare(self):
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
self.logger.log(data=val_metrics, step=self.global_steps)
self.monitor.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return

Expand All @@ -286,6 +288,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize

def train_step(self) -> bool: # noqa C901
self.logger.info(f"Training at step {self.global_steps + 1} started.")
metrics = {}
try:
batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
Expand All @@ -294,6 +297,7 @@ def train_step(self) -> bool: # noqa C901
print("No more data to train. Stop training.")
return False
self.global_steps += 1
self.logger.info(f"Sampling at step {self.global_steps} done.")
timing_raw = {}
algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
Expand Down Expand Up @@ -356,8 +360,10 @@ def train_step(self) -> bool: # noqa C901
self.config.trainer.save_freq > 0
and self.global_steps % self.config.trainer.save_freq == 0
):
self.logger.info(f"Saving at step {self.global_steps}.")
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
self.logger.info(f"Saved at step {self.global_steps}.")

# collect metrics
if self.algorithm.use_advantage: # TODO
Expand All @@ -372,16 +378,19 @@ def train_step(self) -> bool: # noqa C901
self._log_experiences(exp_samples)

# TODO: make a canonical logger that supports various backends
self.logger.log(data=metrics, step=self.global_steps)
self.monitor.log(data=metrics, step=self.global_steps)

train_status = self.global_steps < self.total_training_steps
if not train_status or self.algorithm_manager.need_save(self.global_steps):
if (
self.config.trainer.save_freq == 0
or self.global_steps % self.config.trainer.save_freq != 0
):
self.logger.info(f"Saving at step {self.global_steps}.")
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
self.logger.info(f"Saved at step {self.global_steps}.")
self.logger.info(f"Training at step {self.global_steps} finished.")
return train_status

def _log_single_experience(
Expand Down Expand Up @@ -412,7 +421,7 @@ def _log_single_experience(
def _log_experiences(self, samples: List[Dict]) -> None:
self.sample_exps_to_log.extend(samples)
if self.global_steps % self.config.trainer.sync_freq == 0:
self.logger.log_table(
self.monitor.log_table(
"rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
)
self.reset_experiences_example_table()
Expand Down
1 change: 1 addition & 0 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
for key in data:
self.logger.add_scalar(key, data[key], step)
self.console_logger.info(f"Step {step}: {data}")

def close(self) -> None:
self.logger.close()
Expand Down