[Feat] support debug_train mode in RL Trainer #1423
YanhuiDua wants to merge 2 commits into InternLM:main
Conversation
Pull request overview
Adds a debug_train mode to RLTrainer to support training workers using a fixed trajectory file (rather than generating new rollouts each step), and adjusts trajectory loading/validation to better support this workflow.
Changes:
- Add a `debug_train` flag to `RLTrainerConfig`/`RLTrainer` and route `fit()` to a new `_fit_debug_train()` path (see the dispatch sketch after this list).
- Refactor the training loop into `_fit_normal()`, `_fit_debug_rollout()`, and `_fit_debug_train()`, and adjust rollout offload behavior.
- Extend trajectory handling: fall back to tokenizing the prompt when `train_prompt_ids` is absent, and add a new `_load_trajectories()` that reconstructs `RLDataFlowItem` groups from saved trajectory JSONL.
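For orientation, a minimal sketch of what that dispatch could look like. The `debug_rollout` flag is an assumption inferred from the `_fit_debug_rollout()` name; everything else follows the summary above:

```python
class RLTrainer:
    """Sketch only; the real class lives in xtuner/v1/train/rl_trainer.py."""

    debug_train: bool = False
    debug_rollout: bool = False  # assumed counterpart to _fit_debug_rollout()

    def _fit_normal(self) -> None: ...
    def _fit_debug_rollout(self) -> None: ...
    def _fit_debug_train(self) -> None: ...

    def fit(self) -> None:
        if self.debug_train:
            # Train repeatedly on a fixed trajectory file instead of rolling out.
            self._fit_debug_train()
        elif self.debug_rollout:
            self._fit_debug_rollout()
        else:
            self._fit_normal()
```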
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| xtuner/v1/train/rl_trainer.py | Adds debug-train execution path; changes rollout/train loop structure; adds trajectory loading and prompt tokenization fallback. |
| xtuner/v1/data_proto/rl_data.py | Relaxes training-validity checks to allow either response_ids or response text (with updated logging). |
| tests/ray/test_rl_trainer.py | Enables debug_train in the RL trainer integration test to cover the new mode. |
Comments suppressed due to low confidence (1)
xtuner/v1/train/rl_trainer.py:1029
- After changing `_load_trajectories` to produce `list[list[RLDataFlowItem]]`, `_compute_metrics` still assumes `data_groups` is a list of dicts (`data[0]["reward"]`). This is now inconsistent with the trajectory loading path and will break if used. Please update it to use `RLDataFlowItem` fields (e.g. `data[0].env.judger.reward["score"]`) or remove the unused helper.
```python
def _compute_metrics(self, data_groups):
    correctness = [1 if data[0]["reward"] > 0 else 0 for data in data_groups]
    acc = sum(correctness) / len(correctness)
    return acc
```
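A fix along the lines the reviewer suggests might look like this; a sketch, not the PR's code, with the field path taken from the comment above:

```python
def _compute_metrics(self, data_groups):
    # data_groups is now list[list[RLDataFlowItem]], so read the judger score
    # from the first item of each group instead of indexing a dict.
    correctness = [1 if group[0].env.judger.reward["score"] > 0 else 0 for group in data_groups]
    return sum(correctness) / len(correctness)
```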
```python
prompt_ids = group[0].data.extra_info["train_prompt_ids"]
if "train_prompt_ids" not in group[0].data.extra_info:
    prompt_ids = (
        self.tokenizer(group[0].data.extra_info["raw_prompt"], return_tensors="pt")["input_ids"]
```
In the debug_train fallback path, prompt_ids are built via tokenizer(raw_prompt, return_tensors="pt") without disabling special tokens. This diverges from the dataset tokenize path (e.g. RLTokenizeFn uses add_special_tokens=False) and can change prompt length/tokenization, breaking alignment assumptions (e.g. logprob padding and shifted labels) in training. Please tokenize raw_prompt consistently with the dataset (disable special tokens / reuse the same tokenize helper).
Suggested change:
```diff
-self.tokenizer(group[0].data.extra_info["raw_prompt"], return_tensors="pt")["input_ids"]
+self.tokenizer(
+    group[0].data.extra_info["raw_prompt"],
+    return_tensors="pt",
+    add_special_tokens=False,
+)["input_ids"]
```
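To see why this matters, a standalone illustration of how `add_special_tokens` shifts prompt length; the model name here is an arbitrary assumption, chosen only because its tokenizer inserts special tokens:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model only
with_special = tok("hello world", return_tensors="pt")["input_ids"]
without_special = tok("hello world", return_tensors="pt", add_special_tokens=False)["input_ids"]
# [CLS]/[SEP] add extra tokens here, so the prompt length (and any logprob
# padding computed from the dataset's prompt_ids) no longer lines up.
print(with_special.shape[-1], without_special.shape[-1])  # e.g. 4 vs 2
```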
```python
trajectory_files = list(self._work_dir.rglob("*trajectory.jsonl"))
if trajectory_files:
    # Find the latest trajectory under the current worker_log_dir directory
    latest_trajectory_file = max(trajectory_files, key=lambda p: p.stat().st_mtime)
    self.logger.info(f"Debug train mode will use fixed trajectory file at {latest_trajectory_file}")
else:
    self.logger.info(
        f"Trajectory file in {self._work_dir} not found. Running one rollout step to generate it."
```
_fit_debug_train picks the most recently modified *trajectory.jsonl under work_dir, which can include evaluation trajectories (e.g. eval_0_trajectory.jsonl) and inadvertently train on eval data. Consider filtering to rollout trajectories (e.g. rollout_idx_*_trajectory.jsonl) or allowing an explicit trajectory path to be passed in.
Suggested change:
```diff
-trajectory_files = list(self._work_dir.rglob("*trajectory.jsonl"))
+# Only use rollout trajectories (e.g. rollout_idx_*_trajectory.jsonl) to avoid training on eval data
+all_trajectory_files = list(self._work_dir.rglob("*trajectory.jsonl"))
+trajectory_files = [p for p in all_trajectory_files if p.name.startswith("rollout_idx_")]
 if trajectory_files:
     # Find the latest trajectory under the current worker_log_dir directory
     latest_trajectory_file = max(trajectory_files, key=lambda p: p.stat().st_mtime)
     self.logger.info(f"Debug train mode will use fixed trajectory file at {latest_trajectory_file}")
 else:
     self.logger.info(
-        f"Trajectory file in {self._work_dir} not found. Running one rollout step to generate it."
+        f"Rollout trajectory file in {self._work_dir} not found. Running one rollout step to generate it."
```
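The "explicit trajectory path" alternative could look roughly like this; the parameter name is hypothetical:

```python
from pathlib import Path

def _fit_debug_train(self, trajectory_path: Path | None = None):
    # Hypothetical variant: an explicitly supplied file always wins over
    # directory discovery, so eval trajectories can never be picked up.
    if trajectory_path is None:
        candidates = [
            p for p in self._work_dir.rglob("*trajectory.jsonl")
            if p.name.startswith("rollout_idx_")
        ]
        trajectory_path = max(candidates, key=lambda p: p.stat().st_mtime)
    ...
```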
```python
# offload rollout model
ray.get(self._rollout_dataflow.pause.remote())
ray.get(self._rollout_env_controller.offload.remote())
```
In debug_train mode, _rollout_step() already pauses the dataflow and offloads the rollout env. The extra pause/offload immediately afterwards is redundant work and adds Ray + HTTP overhead. Consider removing the duplicate calls or guarding them based on whether a rollout was just executed.
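A possible guard, assuming a local flag that records whether `_rollout_step()` just ran; the helper and flag names are made up for illustration:

```python
import ray

def _maybe_pause_and_offload(self, ran_rollout_this_step: bool) -> None:
    # _rollout_step() already pauses the dataflow and offloads the rollout env,
    # so skip the duplicate Ray + HTTP round-trips right after a rollout.
    if not ran_rollout_this_step:
        ray.get(self._rollout_dataflow.pause.remote())
        ray.get(self._rollout_env_controller.offload.remote())
```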
```python
except json.JSONDecodeError:
    pass
```
_load_trajectories swallows json.JSONDecodeError silently and still resets temp_str, which can hide corrupted/partial trajectory files and lead to training on incomplete data with no signal. At minimum log the decode error (and file position/object snippet), or raise in debug_train mode so failures are visible.
Suggested change:
```diff
-except json.JSONDecodeError:
-    pass
+except json.JSONDecodeError as e:
+    logger = get_logger(__name__)
+    try:
+        position = f.tell()
+    except Exception:
+        position = "unknown"
+    snippet = temp_str.strip()
+    logger.error(
+        "Failed to decode JSON trajectory from %s at position %s: %s; snippet: %r",
+        save_path,
+        position,
+        e,
+        snippet[:200],
+    )
```
```diff
+data_dict: dict[str, list[RLDataFlowItem]] = {}
 data_groups: list[list[RLDataFlowItem]] = []
+brace_count = 0
+temp_str = ""

 with open(save_path, encoding="utf-8") as f:
     for line in f:
-        item = json.loads(line)
-        messages = item["messages"]
-        responses = item["response"]
-        rewards = item["reward"]
-        group = []
-        for response, reward in zip(responses, rewards):
-            group.append(
-                {
-                    "messages": messages,
-                    "response_str": response,
-                    "reward": reward,
-                }
-            )
-        data_groups.append(group)
-    return data_groups
+        temp_str += line
+        stripped = line.strip()
+
+        # If the whole line is a single opening/closing brace, update the counter accordingly
+        if stripped == "{":
+            brace_count += 1
+        elif stripped == "}":
+            brace_count -= 1
+
+        # Process once the counter drops back to zero (a complete object) and the buffer is non-empty
+        if brace_count == 0 and temp_str.strip():
+            try:
+                # Convert the extracted string into a dict
+                obj = json.loads(temp_str.strip())
+                if "action_id" in obj:
+                    # Build an RLDataFlowItem and append it to the list
+                    data_item = RLDataFlowItem(
+                        uid=RLUIDItem(action_id=obj["action_id"]),
+                        data=RLDatasetItem(
+                            extra_info={"raw_prompt": obj["prompt"]},
+                        ),
+                        env=RLEnvDataItem(
+                            rollout=RLRolloutResponseItem(
+                                response=obj["response"],
+                                num_return_tokens=obj["response_len"],
+                                finish_reason=obj["finish_reason"],
+                                state=RolloutState.COMPLETED,
+                                extra_info={},
+                            ),
+                            judger=RLJudgerResponseItem(reward={"score": obj["reward"]}),
+                        ),
+                    )
+                    data_dict.setdefault(obj["action_id"], []).append(data_item)
+            except json.JSONDecodeError:
```
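Aside: the brace-counting approach above, shown in isolation. A minimal self-contained sketch (the function name and sample data are ours, not the PR's); like the PR code, it assumes each object's opening and closing braces sit alone on their own lines:

```python
import io
import json

def iter_pretty_json_objects(f):
    # Reassemble pretty-printed JSON objects whose "{" and "}" occupy whole lines.
    brace_count, buf = 0, ""
    for line in f:
        buf += line
        stripped = line.strip()
        if stripped == "{":
            brace_count += 1
        elif stripped == "}":
            brace_count -= 1
        if brace_count == 0 and buf.strip():
            yield json.loads(buf.strip())
            buf = ""

sample = '{\n  "action_id": 1,\n  "reward": 0.5\n}\n{\n  "action_id": 2,\n  "reward": 0.0\n}\n'
print([obj["action_id"] for obj in iter_pretty_json_objects(io.StringIO(sample))])  # -> [1, 2]
```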
Type hint mismatch: data_dict is declared as dict[str, list[RLDataFlowItem]], but action_id is serialized/deserialized as an int and used directly as the key. Update the annotation to dict[int, ...] or cast the key to str consistently.
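Either direction works; a small self-contained illustration (variable names are ours):

```python
import json

# json round-trips the key's value as an int, matching what the loader sees.
obj = json.loads('{"action_id": 7}')

# Option 1: annotate with the key type actually used.
data_dict_int: dict[int, list[str]] = {}
data_dict_int.setdefault(obj["action_id"], []).append("item")

# Option 2: keep str keys and cast consistently at every access point.
data_dict_str: dict[str, list[str]] = {}
data_dict_str.setdefault(str(obj["action_id"]), []).append("item")
```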
```diff
 if not ids_valid and not response_valid:
     # NOTE: `response_ids` is the critical field for token-in-token-out mode, so we ensure it's not empty.
-    logger.warning(
-        "Invalid dataflow item found during training: no response or response_ids and skip this item."
+    logger.error(
+        "Invalid dataflow item found during training: no response_ids and no response and skip this item."
     )
     return False
 if not response_valid:
     # NOTE: check valid response string for judger inputs
     logger.warning("Invalid dataflow item found during training: empty response string and skip this item.")
     return False
 elif not ids_valid:
     logger.warning(
```
The inline NOTE above this branch says response_ids is the critical field and implies it must be non-empty, but the updated logic now allows training to proceed when response_ids is missing as long as response text exists. Please update the NOTE/docstring to reflect the new accepted cases so the validation intent remains clear.
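For example, the NOTE could be reworded along these lines (a suggestion, not the PR's wording):

```python
# NOTE: an item is trainable if it carries either non-empty `response_ids`
# (token-in-token-out mode) or non-empty `response` text (re-tokenized later);
# only items missing both are rejected here.
```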
```python
                        ),
                    )
                    data_dict.setdefault(obj["action_id"], []).append(data_item)
            except json.JSONDecodeError:
```
'except' clause does nothing but pass and there is no explanatory comment.
No description provided.