Commits
78 commits
a3a2d69
Fix game_segment/weighted_total_loss bugs and refine prompts, compute…
xiongjyu Nov 20, 2025
959a558
Fixed the accumulate_steps bug and added cprofile functionality.
xiongjyu Nov 20, 2025
ecedc5f
Refine the code and fix the bug in data collection.
xiongjyu Nov 22, 2025
2d53d22
Add REINFORCE-style losses and store old_logprob in the buffer.
xiongjyu Nov 23, 2025
c608600
Fix the get_llm_prior bug so that every action receives a logprob
xiongjyu Nov 24, 2025
15e39f6
fixed the history bug in the build_llm_prompt and logs in forward_learn
xiongjyu Nov 24, 2025
7c9acd9
rename advantage_tensor on rft
xiongjyu Nov 24, 2025
738f300
Fixed the action out-of-bounds bug and added a record for forward_col…
xiongjyu Nov 26, 2025
0a166f6
Fixed the misalignment between old_log_prob and log_prob, and correct…
xiongjyu Nov 27, 2025
4f3668e
add some logs for analysying
xiongjyu Nov 27, 2025
2985e60
Polish the code and standardize the format.
xiongjyu Nov 29, 2025
ff98006
Add kL divergence in rft and llm_prior_entropy in collect
xiongjyu Dec 2, 2025
7e43e45
polish config and format
xiongjyu Dec 3, 2025
d6555e5
delete unused files
xiongjyu Dec 3, 2025
b7d42ee
Decouple the training of world_model and LLM.
xiongjyu Dec 9, 2025
95e2347
add cache in the jericho
xiongjyu Dec 10, 2025
9682486
Separate sync and async entry points to simplify the program.
xiongjyu Dec 10, 2025
0a38197
Reference OpenRLHF’s implementation to update vLLM weights in real ti…
xiongjyu Dec 14, 2025
e361039
delete unused orz files
xiongjyu Dec 14, 2025
f957db9
fix a small bug
xiongjyu Dec 15, 2025
628d7d2
Fix action='go' bug; optimize replay buffer with larger capacity; sam…
xiongjyu Dec 16, 2025
c16174f
fix a bug
xiongjyu Dec 16, 2025
35cb4f9
Optimized log-probability computation for the CoT setting.
xiongjyu Dec 17, 2025
2c67a8d
polish and format file
xiongjyu Dec 19, 2025
b16c3e7
Improve single/multi-process LLM training with DeepSpeed
xiongjyu Dec 26, 2025
97c9843
fix the vllm bug when using torchrun
xiongjyu Dec 27, 2025
eb8a4bd
fix the fork bug and polish the vllm about sleep
xiongjyu Dec 27, 2025
9178397
Optimized efficiency and added multiple ways to calculate advantage.
xiongjyu Dec 28, 2025
cb6f7cf
limit the ouput length when using vllm
xiongjyu Dec 28, 2025
19fac8f
polish(pu): add cot-reuse in training, use running-norm in value, pol…
puyuan1996 Dec 29, 2025
88f047b
fix(pu): fix some bugs in reuse-collect-cot in training phase
puyuan1996 Dec 29, 2025
de4b2c0
polish configs and format
xiongjyu Dec 30, 2025
2069e32
delete unuse config
xiongjyu Dec 30, 2025
3ff091e
fix not found go bug
xiongjyu Dec 30, 2025
3888d8e
fix the misalignment bug when reusing cot
xiongjyu Jan 6, 2026
da0d0fd
make the prompt more compact
xiongjyu Jan 6, 2026
5f88151
add lr warmup
xiongjyu Jan 7, 2026
ed89062
add warmup for training world model before training llm and AdaptiveV…
xiongjyu Jan 7, 2026
96dc250
polish the implementation of profile
xiongjyu Jan 7, 2026
66ac376
fix a small bug
xiongjyu Jan 8, 2026
a9593cd
add profile of forward_collect
xiongjyu Jan 8, 2026
5d0f359
add format reward option and fix the cot gradient
xiongjyu Jan 8, 2026
e15e66c
rename kl/clip-ratio metrics
xiongjyu Jan 14, 2026
55e66ed
Optimize the use of format rewards
xiongjyu Jan 14, 2026
0cef8b8
polish the format
xiongjyu Jan 14, 2026
f88989b
Add a complete DDP program, including the process of collect and worl…
xiongjyu Jan 14, 2026
ffcf348
fix some small bugs
xiongjyu Jan 14, 2026
591c420
add the result of valid_actions's prob in vllm_output and samples shu…
xiongjyu Jan 15, 2026
091cc80
add the policy model for reload/offload func
xiongjyu Jan 16, 2026
e512707
fix a bug
xiongjyu Jan 16, 2026
a6bdc0e
Fixed a bug in the advantage feature and used target_value-pred_value…
xiongjyu Jan 17, 2026
250ba63
Optimize parameter definitions and off-policy's implementation
xiongjyu Jan 18, 2026
c3af1c2
fix a small bug
xiongjyu Jan 18, 2026
9d304c6
fix a small bug
xiongjyu Jan 19, 2026
9654f6a
Optimize essential logging; add grad norm metrics; record metrics as …
xiongjyu Feb 1, 2026
1bbcc19
fix a small bug
xiongjyu Feb 2, 2026
335e16c
Fix a bug; refactor LLM prompts into system/user roles; improve CoT o…
xiongjyu Feb 2, 2026
acf5d04
add input/response length metrics and envstep-based tb_logger for lea…
xiongjyu Feb 4, 2026
2ff4a90
add llm_prior_tempearture to apply_temperature_scaling for llm_prior
xiongjyu Feb 5, 2026
ea32a4a
Fixed bugs in prefix_cot & padding and aligned the input contexts of …
xiongjyu Feb 7, 2026
f895ed2
Remove the logits_to_keep parameter and add metrics such as entropy.
xiongjyu Feb 8, 2026
4567190
tmp
xiongjyu Feb 8, 2026
5098395
refine some log
xiongjyu Feb 9, 2026
b865dcb
polish some config
xiongjyu Feb 23, 2026
82e1d29
Fixed the bug caused by fmt_weight being 1
xiongjyu Feb 24, 2026
7c9c922
delete unused file
xiongjyu Feb 24, 2026
2dcc4ae
fix score, action_str, and timestep_dict bug, and add 3 modes of Prio…
xiongjyu Feb 25, 2026
dbec27c
fix self.history_buffer bug and refine logs
xiongjyu Feb 26, 2026
e71eb61
fix some bug when running unizero
xiongjyu Feb 28, 2026
e3f7cfd
fix a bug in evaluator and add enable_rft/wm to control what models n…
xiongjyu Feb 28, 2026
7377220
add bash scripts to run priorzero and format files
xiongjyu Feb 28, 2026
d5c4923
tmp
xiongjyu Feb 28, 2026
3caf28f
add priorzero README.md
xiongjyu Feb 28, 2026
cfc000a
refine cot prefix and add mcts_root_logits_dict
xiongjyu Feb 28, 2026
8bec5f7
tmp
xiongjyu Feb 28, 2026
f49c9d9
add user_prompt_dict to control user_prompt about reward/valid_action…
xiongjyu Feb 28, 2026
9a459f8
fix the bug of pretrained_model and some small bugs
xiongjyu Mar 4, 2026
38efe3d
fix the evaluator_env_num when steping into the _init_eval
xiongjyu Mar 4, 2026
215 changes: 30 additions & 185 deletions zoo/jericho/priorzero/priorzero_collector.py
@@ -13,7 +13,7 @@
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, allreduce_data
from vllm import AsyncLLMEngine, SamplingParams
from vllm import SamplingParams
import os

# Import from local LightZero
@@ -74,26 +74,24 @@ def extract_raw_obs_text(obs_dict: Dict[str, Any]) -> str:
class PriorZeroCollector(OriginalCollector):
"""
[PRIORZERO-MODIFIED]
Async collector that integrates LLM priors into MCTS-based data collection.

Features:
- Async LLM inference with vLLM engine
- History buffer for each environment (sliding window)
- Robust error handling with retries
- Detailed logging of LLM prior statistics
"""

def __init__(
self,
vllm_engine: AsyncLLMEngine,
llm_prior_generator,
policy_config: Dict,
**kwargs
):
"""
Initialize PriorZeroCollector.

Args:
vllm_engine: vLLM async engine for LLM inference
llm_prior_generator: Generator used to compute LLM action priors
policy_config: Policy configuration (contains llm_policy_cfg)
**kwargs: Additional arguments for parent class
"""
@@ -103,9 +101,7 @@ def __init__(

super().__init__(**kwargs)

self.vllm_engine = vllm_engine
self._vllm_tokenizer = None
# self.policy_config already set by parent class from kwargs
self.llm_prior_generator = llm_prior_generator
self.llm_policy_cfg = policy_config.llm_policy_cfg

# [PRIORZERO-NEW] History buffer for each environment
@@ -195,124 +191,37 @@ def pad_and_save_last_trajectory(
last_game_segments[i] = None
last_game_priorities[i] = None

async def _get_tokenizer(self):
"""
Fetch the tokenizer reference already loaded by the vLLM engine.
Only the first call incurs a tiny async overhead; afterwards the in-memory reference is returned directly.
"""
if self._vllm_tokenizer is None:
self._vllm_tokenizer = await self.vllm_engine.get_tokenizer()
return self._vllm_tokenizer

async def _async_get_llm_prior(
def _get_llm_prior(
self,
states: List[str],
request_ids: List[str],
valid_actions_list: List[List[str]],
histories: Optional[List[List[Tuple[str, str, float]]]] = None,
timeout: float = 30.0
) -> List[Any]:
"""
[PRIORZERO-SEQUENCE-SCORING]
Async call to calculate the log-probability of full action sequences.
Ensures every action has a logprob by retrying and falling back if needed.
"""

assert self.vllm_engine is not None, "vLLM engine is not initialized."
tokenizer = await self._get_tokenizer()

max_retry = 3
fallback_lp = -1e3

async def run_once(target_missing: List[set], retry_idx: int):
all_prompts_data = []
for i, state in enumerate(states):
if len(target_missing[i]) == 0:
continue
history = histories[i]
instruction = build_llm_prompt(
current_obs=state,
history=history,
use_cot=self.llm_policy_cfg.use_cot
)
context_text = tokenizer.apply_chat_template(
[{"role": "user", "content": instruction}],
tokenize=False,
add_generation_prompt=True
)
context_tokens = tokenizer.encode(context_text)
context_len = len(context_tokens)

actions = list(target_missing[i])

for act_idx, action in enumerate(actions):
formatted_action = f"{action}{tokenizer.eos_token}"
full_text = context_text + formatted_action
unique_req_id = f"{request_ids[i]}_act_{act_idx}_retry{retry_idx}"
all_prompts_data.append({
"idx": i,
"action_str": action,
"full_text": full_text,
"context_len": context_len,
"req_id": unique_req_id
})

sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1,
prompt_logprobs=1,
)

async def get_sequence_score(item):
results_generator = self.vllm_engine.generate(item["full_text"], sampling_params, item["req_id"])
final_output = None
async for request_output in results_generator:
final_output = request_output

action_logprobs_list = final_output.prompt_logprobs[item["context_len"]:]
total_score, valid_tokens = 0.0, 0
for token_dict in action_logprobs_list:
if token_dict:
lp_obj = next(iter(token_dict.values()))
total_score += lp_obj.logprob
valid_tokens += 1
if valid_tokens == 0:
return item["idx"], item["action_str"], None
return item["idx"], item["action_str"], total_score / valid_tokens

tasks = [get_sequence_score(item) for item in all_prompts_data]
results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=timeout)
final_priors = [{} for _ in range(len(states))]
for i, action_str, score in results:
if score is not None:
final_priors[i][action_str] = score
return final_priors

priors = [{} for _ in range(len(states))]
missing = [set(actions) for actions in valid_actions_list]

for retry_idx in range(max_retry + 1):
try:
new_priors = await run_once(missing, retry_idx)
except Exception as e:
self._logger.error(f"Batch LLM critical error (retry {retry_idx}): {e}")
new_priors = [{} for _ in range(len(states))]

for i in range(len(states)):
priors[i].update(new_priors[i])
missing[i] -= set(new_priors[i].keys())

if all(len(m) == 0 for m in missing):
break

# Fill any remaining missing actions with fallback
for i, remaining in enumerate(missing):
if remaining:
self._logger.warning(f"[LLM prior] missing actions after retries, fill fallback: {remaining}")
for act in remaining:
priors[i][act] = fallback_lp

return priors
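The removed async path scored each action by the mean logprob of its tokens under the prompt (the slice of `prompt_logprobs` past `context_len`), and after exhausting retries filled any still-missing action with a fallback value of `-1e3`. The two kernels of that logic can be sketched as follows (function names here are illustrative, not part of the original code):

```python
from typing import Dict, List, Optional

FALLBACK_LOGPROB = -1e3  # mirrors fallback_lp in the removed code


def mean_action_logprob(token_logprobs: List[Optional[float]]) -> Optional[float]:
    """Average the per-token logprobs of the action suffix.

    None entries (tokens vLLM returned no logprob for) are skipped;
    if no token has a logprob, the caller treats the action as missing.
    """
    total, valid = 0.0, 0
    for lp in token_logprobs:
        if lp is not None:
            total += lp
            valid += 1
    return None if valid == 0 else total / valid


def fill_missing(priors: Dict[str, float], actions: List[str]) -> Dict[str, float]:
    """Last resort after retries: guarantee every valid action has a score."""
    for act in actions:
        priors.setdefault(act, FALLBACK_LOGPROB)
    return priors
```

Averaging (rather than summing) token logprobs keeps scores comparable across actions of different token lengths, which matters when priors are compared within one valid-action set.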
assert self.llm_prior_generator is not None, "llm_prior_generator is None."
all_prompts = []
all_labels = []
for i, actions in enumerate(valid_actions_list):
state = states[i]
history = histories[i]
prompt = build_llm_prompt(current_obs=state, history=history, use_cot=self.llm_policy_cfg.use_cot)
for action in actions:
all_prompts.append(prompt)
all_labels.append(action)

all_prior_scores = self.llm_prior_generator._generate_vllm(all_prompts, all_labels, reduction='mean')
llm_prior, idx = [], 0
for env_id in range(len(states)):
tmp_dict = {}
for action in valid_actions_list[env_id]:
tmp_dict[action] = all_prior_scores[idx]
idx += 1
llm_prior.append(tmp_dict)
return llm_prior
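The new synchronous `_get_llm_prior` follows a flatten-and-regroup pattern: every (state, action) pair becomes one prompt/label entry, a single batched scoring call runs over the flattened lists, and the flat scores are regrouped into one per-environment dict. A minimal sketch, with a stand-in `score_fn` replacing `llm_prior_generator._generate_vllm`:

```python
from typing import Callable, Dict, List


def regroup_prior_scores(
    states: List[str],
    valid_actions_list: List[List[str]],
    score_fn: Callable[[List[str], List[str]], List[float]],
) -> List[Dict[str, float]]:
    # Flatten every (state, action) pair into parallel prompt/label lists.
    all_prompts, all_labels = [], []
    for state, actions in zip(states, valid_actions_list):
        for action in actions:
            all_prompts.append(state)
            all_labels.append(action)

    # One batched scoring call; regroup scores per environment
    # in exactly the order they were flattened.
    all_scores = score_fn(all_prompts, all_labels)
    llm_prior, idx = [], 0
    for actions in valid_actions_list:
        tmp = {}
        for action in actions:
            tmp[action] = all_scores[idx]
            idx += 1
        llm_prior.append(tmp)
    return llm_prior
```

Because regrouping relies purely on flattening order, the scorer must return one score per label, in input order — a property worth asserting when swapping in a real backend.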

@contextmanager
def _profile_block(self, name: str):
@@ -341,59 +250,8 @@ def _record_profile_time(self, name: str, elapsed: float) -> None:
f"{time.time():.3f}\tname={name}\tcount={self._profile_stats[name]['count']}\t"
f"total_s={self._profile_stats[name]['total']:.4f}\tavg_s={avg:.4f}\tmax_s={self._profile_stats[name]['max']:.4f}\n"
)

async def _log_llm_response(
self,
raw_obs_text: str,
history: List[Tuple[str, str, float]],
valid_actions: List[str],
) -> None:
"""
Periodically log LLM output for a debug prompt and current valid actions.
"""
self._llm_call_count += 1
if self._llm_call_count != 1 and (self._llm_call_count % self.prompt_log_interval != 0):
return
tokenizer = await self._get_tokenizer()
instruction = build_llm_prompt(
current_obs=raw_obs_text,
history=history,
use_cot=self.llm_policy_cfg.use_cot,
)
prompt_text = tokenizer.apply_chat_template(
[{"role": "user", "content": instruction}],
tokenize=False,
add_generation_prompt=True,
)
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=self.llm_policy_cfg.generate_max_len,
top_p=1.0,
)
try:
result_gen = self.vllm_engine.generate(
prompt_text,
sampling_params,
request_id=f"llm_call_count_{self._llm_call_count}",
)
async for request_output in result_gen:
if request_output.finished:
llm_output_text = request_output.outputs[0].text or ""
break
except Exception as e:
llm_output_text = f"[LLM logging error: {repr(e)}]"
llm_output_text = llm_output_text.strip()

with open(self._llm_output_log_path, mode='a', encoding='utf-8') as f:
f.write(
f"llm_call_count={self._llm_call_count}\t"
f"valid_actions={valid_actions}\n"
f"llm_input={prompt_text}\n"
f"llm_output={llm_output_text}\n"
"----\n"
)
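The removed `_log_llm_response` gated its work on a simple predicate: log on the very first LLM call, then on every `prompt_log_interval`-th call. That gate can be isolated as:

```python
def should_log(call_count: int, interval: int) -> bool:
    """Log on the first call, then on every `interval`-th call,
    matching the early-return at the top of _log_llm_response."""
    return call_count == 1 or call_count % interval == 0
```

Logging the first call immediately gives a sanity check at startup without waiting for the full interval to elapse.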

async def collect(

def collect(
self,
num_segments: Optional[int] = None,
train_iter: int = 0,
@@ -406,9 +264,8 @@ async def collect(

Main changes from parent:
1. Extract text observations from environment
2. Async call to LLM to get action priors
3. Pass LLM priors to policy forward pass
4. Update history buffers after each step
2. Pass LLM priors to policy forward pass
3. Update history buffers after each step

Args:
num_segments: Number of segments to collect
@@ -535,24 +392,12 @@ async def collect(
valid_actions_list.append(valid_actions)

if self.policy_config.llm_policy_cfg.enable_llm:
request_ids = []
for _ in range(len(raw_obs_list)):
self._llm_prior_req_counter += 1
request_ids.append(f"collect_{self._llm_prior_req_counter}")

with self._profile_block(name='collect_get_llm_prior_profile'):
llm_prior_logprob = await self._async_get_llm_prior(
llm_prior_logprob = self._get_llm_prior(
states=raw_obs_list,
request_ids=request_ids,
valid_actions_list=valid_actions_list, # [PRIORZERO] Pass valid actions
histories=histories_list
)
if raw_obs_list:
await self._log_llm_response(
raw_obs_text=raw_obs_list[0],
history=histories_list[0],
valid_actions=valid_actions_list[0],
)
else:
llm_prior_logprob = [None for i in range(len(valid_actions_list))]

42 changes: 26 additions & 16 deletions zoo/jericho/priorzero/priorzero_config.py
@@ -55,11 +55,12 @@ def get_priorzero_config(
## LLM parameters
# llm_model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Smaller model for faster iteration
llm_model_name = "/mnt/afs/wanzunian/niuyazhe/xiongjyu/models/Qwen2.5-0.5B-Instruct"
total_batch_size = 128 # Total batch size across all GPUs
train_batch_size = 128 # Total batch size across all GPUs
GPUS = 1
micro_batch_size = 16 # Micro batch size per GPU
gradient_accumulation_steps = total_batch_size // micro_batch_size
gradient_accumulation_steps = train_batch_size // micro_batch_size // GPUS
rft_loss_type = 'reinforce++' # 'reinforce' | 'reinforce++' | 'ppo-simple-adv'
use_cot = True # Whether to use chain-of-thought prompting
use_cot = False # Whether to use chain-of-thought prompting
history_length = 5
llm_learn_num_samples = 512
replay_buffer_size = llm_learn_num_samples
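The corrected accumulation formula divides the global batch by both the per-GPU micro batch and the GPU count, so that micro batches accumulated across steps and devices reproduce `train_batch_size`. With the values above:

```python
# Effective-batch arithmetic from the config above.
train_batch_size = 128   # global batch across all GPUs
micro_batch_size = 16    # per-GPU micro batch
GPUS = 1

gradient_accumulation_steps = train_batch_size // micro_batch_size // GPUS

# Invariant the fixed formula restores:
# micro batches * accumulation steps * GPUs == global batch.
assert micro_batch_size * gradient_accumulation_steps * GPUS == train_batch_size
```

The pre-fix formula omitted the `// GPUS` term, which over-accumulates by a factor of the world size on multi-GPU runs (harmless at `GPUS = 1`, wrong otherwise).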
@@ -179,28 +180,37 @@ def get_priorzero_config(
priority_prob_alpha=0.6,
priority_prob_beta=0.4,
llm_policy_cfg=dict(
# Switches controlling whether the LLM is used
enable_llm=True,
pretrain_llm_path=llm_model_name,
history_length=history_length,
use_cot=use_cot,
llm_learn_num_samples=llm_learn_num_samples,
enable_sft=False,
enable_rft=True,
rft_loss_type=rft_loss_type,
rft_clip_epsilon=0.2,
rft_kl_coef=0.01,

llm_learning_rate=1e-5,
llm_weight_decay=0.01,
sft_loss_weight=1, # Weight of SFT loss in total loss
rft_loss_weight=1,
llm_micro_batch_size=micro_batch_size,

llm_gradient_accumulation_steps=gradient_accumulation_steps,
prompt_log_interval=1000,  # How often (in steps) to log the model's response against the valid actions for comparison

# Model parameters
pretrain_llm_path=llm_model_name,
history_length=history_length,
use_cot=use_cot,
prompt_max_len=2048,
generate_max_len=128,
temperature=1.0,
top_p=1.0,

# Training parameters
zero_stage=0,
train_batch_size=train_batch_size,
micro_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=1e-5,
weight_decay=0.01,

# Loss parameters
rft_loss_type=rft_loss_type,
rft_clip_epsilon=0.2,
rft_kl_coef=0.01,

# vLLM parameters
vllm_tensor_parallel_size=1,
gpu_memory_utilization=0.2,
),