feature(xjy): Refine PriorZero Implementation #441
base: dev-multitask-balance-clean-rft
Conversation
…_llm_prior, and SFT loss
…lect to cprofile.
…ed the REINFORCE-series loss computation.
…me. Single-GPU works; multi-GPU not tested yet.
```python
for i in range(num_engines):
    bundle_indices = None
    if tensor_parallel_size > 1:
        bundle_indices = get_bundle_indices(shared_pg, i, tensor_parallel_size)
```
Is this adapted from the official Ray implementation?
This vllm_engine is basically the same as the corresponding part of OpenRLHF; for now we only use one vLLM engine with tensor_parallel_size = 1, since GPU memory is sufficient.
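For readers unfamiliar with the OpenRLHF-style engine allocation, below is a minimal sketch of the pattern the quoted loop follows. `get_bundle_indices` is re-implemented here only for illustration, the bundle layout is an assumption, and running it requires a GPU-equipped Ray cluster.

```python
import ray
from ray.util.placement_group import placement_group

def get_bundle_indices(pg, engine_idx, tp_size):
    # Sketch of the helper used in the diff: the contiguous bundle slice owned by one engine.
    return list(range(engine_idx * tp_size, (engine_idx + 1) * tp_size))

num_engines = 1            # this PR currently runs a single vLLM engine
tensor_parallel_size = 1   # TP = 1 because a single GPU has enough memory

ray.init(ignore_reinit_error=True)
bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)]
shared_pg = placement_group(bundles, strategy="PACK")
ray.get(shared_pg.ready())

for i in range(num_engines):
    bundle_indices = None
    if tensor_parallel_size > 1:
        # Only needed when one engine spans several GPU bundles.
        bundle_indices = get_bundle_indices(shared_pg, i, tensor_parallel_size)
    print(f"engine {i}: bundle_indices={bundle_indices}")
```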
…ple for world-model training; train LLM only on latest trajectories
```python
for action in actions:
    prior.append(llm_prior_logprob[idx][action])
policy_priors.append(prior)
policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9)
```
Note: check that the ordering of valid_actions_list here matches the corresponding action_mask.
I checked this; it is fine.
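To guard against future regressions, a small consistency check like the following could make the reviewed invariant explicit; `valid_actions_list` and `action_mask` are the names from this discussion, everything else is illustrative.

```python
import numpy as np

def check_valid_actions_alignment(valid_actions_list, action_mask, action_space_size):
    # Assert that the i-th entry of valid_actions_list matches the i-th row of action_mask.
    for i, valid_actions in enumerate(valid_actions_list):
        expected = np.zeros(action_space_size, dtype=bool)
        expected[list(valid_actions)] = True
        assert np.array_equal(expected, np.asarray(action_mask[i], dtype=bool)), \
            f"valid_actions_list[{i}] is inconsistent with action_mask[{i}]"

# Toy usage:
check_valid_actions_alignment(
    valid_actions_list=[[0, 2], [1]],
    action_mask=[[1, 0, 1], [0, 1, 0]],
    action_space_size=3,
)
```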
```python
# ============ LLM Loss Metrics ============
'llm_sft_loss',  # Supervised fine-tuning loss
'llm_rft_loss',  # Reinforcement fine-tuning loss
```
_forward_learn does not currently compute these statistics, does it?
I have changed this part: these statistics are now computed inside the LLM; they are indeed not computed here.
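As a hedged illustration of how such LLM-side losses could be surfaced under the monitored-variable names above, assuming scalar PyTorch losses (the helper name and call site are hypothetical, not the PR's code):

```python
import torch

def collect_llm_loss_metrics(llm_sft_loss: torch.Tensor, llm_rft_loss: torch.Tensor) -> dict:
    # Illustrative helper: expose LLM-side losses under the monitored-variable names above.
    return {
        'llm_sft_loss': llm_sft_loss.detach().item(),  # supervised fine-tuning loss
        'llm_rft_loss': llm_rft_loss.detach().item(),  # reinforcement fine-tuning loss
    }

# Merge into whatever log dict the LLM trainer (or _forward_learn) already returns:
log_dict = {}
log_dict.update(collect_llm_loss_metrics(torch.tensor(0.7), torch.tensor(0.3)))
```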
```python
return samples
T = len(raw_obs_list[0])

for b in range(B):
```
TODO: a more efficient construction method.
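Since the surrounding builder code is not shown in full, the following is only a generic sketch of replacing a nested B x T Python loop with a batched NumPy construction; the shapes and the `raw_obs_list` layout are assumptions.

```python
import numpy as np

# Hypothetical shapes: raw_obs_list is a length-B list of length-T per-step observations.
B, T, obs_dim = 4, 8, 16
raw_obs_list = [[np.random.randn(obs_dim) for _ in range(T)] for _ in range(B)]

# Nested-loop construction (roughly what the quoted code does):
samples_loop = []
for b in range(B):
    for t in range(T):
        samples_loop.append(raw_obs_list[b][t])

# Batched construction: two stack calls instead of B * T Python-level appends.
samples_batched = np.stack([np.stack(traj, axis=0) for traj in raw_obs_list], axis=0)  # (B, T, obs_dim)
assert samples_batched.reshape(B * T, obs_dim).shape == (len(samples_loop), obs_dim)
```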
```python
if num_of_transitions >= replay_buffer.replay_buffer_size:
    all_data = replay_buffer.sample(batch_size=replay_buffer.replay_buffer_size, policy=policy)
    replay_buffer._clear()
    trainer.train_rft_from_priorzero_batch(all_data)
```
TODO: make the degree of off-policy training controllable.
TODO: is LLM training too slow?
The world model is also trained by sampling from the replay buffer, right? Only the data used for LLM training should be cleared, not the data used for world-model training; world-model training needs a fairly large buffer size.
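One possible direction, sketched under the assumption of a simple list-backed buffer rather than the actual GameBuffer API: track an LLM consumption cursor instead of clearing the whole buffer, so the world model keeps its large sample pool while the LLM only trains on fresh transitions.

```python
import numpy as np

class DualConsumerBuffer:
    """Illustrative sketch: the world model keeps sampling from the full buffer,
    while the LLM only consumes transitions it has not seen yet (no global clear)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = []
        self.llm_cursor = 0  # index of the first transition not yet used for LLM training

    def push(self, transitions):
        self.data.extend(transitions)
        overflow = len(self.data) - self.capacity
        if overflow > 0:
            # Drop the oldest transitions and keep the LLM cursor consistent.
            self.data = self.data[overflow:]
            self.llm_cursor = max(0, self.llm_cursor - overflow)

    def sample_for_world_model(self, batch_size, rng):
        # The world model samples uniformly from the whole (large) buffer.
        idx = rng.choice(len(self.data), size=min(batch_size, len(self.data)), replace=False)
        return [self.data[i] for i in idx]

    def pop_for_llm(self):
        # The LLM trains only on transitions collected since its last update,
        # which also bounds how off-policy the LLM update can get (see the TODO above).
        fresh = self.data[self.llm_cursor:]
        self.llm_cursor = len(self.data)
        return fresh

buffer = DualConsumerBuffer(capacity=1000)
buffer.push(list(range(10)))
wm_batch = buffer.sample_for_world_model(4, np.random.default_rng(0))
llm_batch = buffer.pop_for_llm()   # all 10 fresh transitions; the buffer itself is kept
```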
```python
if coordinator.can_collect():
    logger.info(f"\n[Iter {learner.train_iter}] Starting async collect...")

    async def collect_fn():
```
Has the asynchronous collect-and-train path passed testing yet?
I have never used the asynchronous path; shouldn't training wait until collection finishes before sampling? I don't think this part can be made asynchronous.
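To make the trade-off concrete, here is a hedged sketch of both orderings with placeholder `collect_fn` / `train_fn`; the overlapped variant is only sound if the training step samples from previously collected data rather than the segment still being collected.

```python
import asyncio

async def collect_fn():
    # Placeholder for the actual data-collection coroutine in this PR.
    await asyncio.sleep(0.01)
    return ["segment"]

def train_fn(replay_buffer):
    # Placeholder for one world-model / LLM training step over the buffer.
    pass

async def synchronous_loop(num_iters, replay_buffer):
    # The ordering described above: train only after collect has finished.
    for _ in range(num_iters):
        replay_buffer.extend(await collect_fn())
        train_fn(replay_buffer)

async def overlapped_loop(num_iters, replay_buffer):
    # Overlap is only sound if training samples previously collected data,
    # not the segment that is still being collected.
    pending = asyncio.create_task(collect_fn())
    for _ in range(num_iters):
        replay_buffer.extend(await pending)          # wait for the previous collect
        pending = asyncio.create_task(collect_fn())  # start the next collect
        train_fn(replay_buffer)                      # train while it runs
    replay_buffer.extend(await pending)

asyncio.run(synchronous_loop(2, []))
```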
```python
        self.log_state_to_tb()

    def _broadcast_to_vllm(self):
        use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
```
Weight updates are currently serial; they could be changed to asynchronous updates with @ray.remote. Also, it should not be necessary to update after every LLM training round; a sync frequency could be configured.
OK.
But right now each LLM training round already covers 256 * 10 samples, which should be quite a lot; with batch_size = 64 that is 40 parameter updates. Shouldn't the weights be synchronized after that?
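For reference, a minimal sketch of the throttled, Ray-based broadcast suggested above; `VLLMEngineStub`, `WeightSyncer`, and `sync_every` are illustrative names, not the PR's API.

```python
import ray

@ray.remote
class VLLMEngineStub:
    # Stand-in for the actual vLLM Ray actor; only the update entry point is shown.
    def update_weights(self, state_dict):
        return len(state_dict)

class WeightSyncer:
    """Illustrative: broadcast to vLLM only every `sync_every` LLM training rounds,
    and do it asynchronously so training does not block on the transfer."""

    def __init__(self, engines, sync_every: int = 4):
        self.engines = engines
        self.sync_every = sync_every
        self.round = 0
        self.pending = []

    def maybe_broadcast(self, state_dict):
        self.round += 1
        if self.round % self.sync_every != 0:
            return
        if self.pending:
            ray.get(self.pending)  # make sure the previous broadcast has landed
        self.pending = [eng.update_weights.remote(state_dict) for eng in self.engines]

ray.init(ignore_reinit_error=True)
syncer = WeightSyncer([VLLMEngineStub.remote()], sync_every=4)
for step in range(8):
    syncer.maybe_broadcast({"dummy.weight": step})
```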
…ish sys-template and max-gen-length, use k3 kl and batch_size=128
This PR mainly refines the PriorZero implementation and development workflow, fixes several critical issues affecting training correctness and stability, and systematically strengthens the training logic, loss computation, and data collection.
Work completed in this PR
• Fixed multiple critical bugs in the PriorZero training pipeline, including errors in game segment construction, loss computation, log-prob alignment, and action handling.
• Improved the REINFORCE / RFT-style policy optimization: old_logprob is now correctly stored in the buffer and used during updates, ensuring correct policy updates (see the loss sketch after this list).
• Added and standardized training statistics, including KL divergence and policy entropy, for better monitoring of training status.
• Optimized the data flow between the collector and the replay buffer, improving data consistency and sampling stability and reducing silent errors.
• Introduced and verified the vLLM weight-synchronization mechanism in the single-GPU setting.
Still to do: vLLM weight synchronization and stability verification in multi-GPU / multi-node settings.
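For context on the old_logprob and k3-KL items above, here is a hedged sketch of an RFT-style objective of this general shape; the tensor names, weighting, reference model, and entropy estimate are illustrative and not the PR's exact implementation.

```python
import torch

def rft_loss_with_k3_kl(
    logprob: torch.Tensor,       # log pi_theta(a|s) under the current policy
    old_logprob: torch.Tensor,   # log-probability stored in the buffer at collection time
    ref_logprob: torch.Tensor,   # log-probability under a frozen reference model (assumption)
    advantage: torch.Tensor,     # per-sample advantage / return signal
    kl_coef: float = 0.1,
):
    """REINFORCE/RFT-style objective with an importance ratio from the stored
    old_logprob and the k3 KL estimator exp(x) - x - 1, where x = ref - current."""
    ratio = torch.exp(logprob - old_logprob).detach()   # importance weight w.r.t. behavior policy
    pg_loss = -(ratio * advantage * logprob).mean()

    x = ref_logprob - logprob
    k3_kl = (torch.exp(x) - x - 1).mean()               # low-variance estimate of KL(pi_theta || pi_ref)

    entropy = -logprob.mean()                           # Monte Carlo estimate of policy entropy
    stats = {"kl": k3_kl.item(), "policy_entropy": entropy.item()}
    return pg_loss + kl_coef * k3_kl, stats

# Toy usage:
n = 8
loss, stats = rft_loss_with_k3_kl(
    logprob=torch.randn(n), old_logprob=torch.randn(n),
    ref_logprob=torch.randn(n), advantage=torch.randn(n),
)
```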