Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
a3a2d69
Fix game_segment/weighted_total_loss bugs and refine prompts, compute…
xiongjyu Nov 20, 2025
959a558
Fixed the accumulate_steps bug and added cprofile functionality.
xiongjyu Nov 20, 2025
ecedc5f
Refine the code and fix the bug in data collection.
xiongjyu Nov 22, 2025
2d53d22
Add REINFORCE-style losses and store old_logprob in the buffer.
xiongjyu Nov 23, 2025
c608600
Fix the get_llm_prior bug so that every action receives a logprob
xiongjyu Nov 24, 2025
15e39f6
fixed the history bug in the build_llm_prompt and logs in forward_learn
xiongjyu Nov 24, 2025
7c9acd9
rename advantage_tensor on rft
xiongjyu Nov 24, 2025
738f300
Fixed the action out-of-bounds bug and added a record for forward_col…
xiongjyu Nov 26, 2025
0a166f6
Fixed the misalignment between old_log_prob and log_prob, and correct…
xiongjyu Nov 27, 2025
4f3668e
add some logs for analysying
xiongjyu Nov 27, 2025
2985e60
Polish the code and standardize the format.
xiongjyu Nov 29, 2025
ff98006
Add kL divergence in rft and llm_prior_entropy in collect
xiongjyu Dec 2, 2025
7e43e45
polish config and format
xiongjyu Dec 3, 2025
d6555e5
delete unused files
xiongjyu Dec 3, 2025
b7d42ee
Decouple the training of world_model and LLM.
xiongjyu Dec 9, 2025
95e2347
add cache in the jericho
xiongjyu Dec 10, 2025
9682486
Separate sync and async entry points to simplify the program.
xiongjyu Dec 10, 2025
0a38197
Reference OpenRLHF’s implementation to update vLLM weights in real ti…
xiongjyu Dec 14, 2025
e361039
delete unused orz files
xiongjyu Dec 14, 2025
f957db9
fix a small bug
xiongjyu Dec 15, 2025
628d7d2
Fix action='go' bug; optimize replay buffer with larger capacity; sam…
xiongjyu Dec 16, 2025
c16174f
fix a bug
xiongjyu Dec 16, 2025
35cb4f9
Optimized log-probability computation for the CoT setting.
xiongjyu Dec 17, 2025
2c67a8d
polish and format file
xiongjyu Dec 19, 2025
b16c3e7
Improve single/multi-process LLM training with DeepSpeed
xiongjyu Dec 26, 2025
97c9843
fix the vllm bug when using torchrun
xiongjyu Dec 27, 2025
eb8a4bd
fix the fork bug and polish the vllm about sleep
xiongjyu Dec 27, 2025
9178397
Optimized efficiency and added multiple ways to calculate advantage.
xiongjyu Dec 28, 2025
cb6f7cf
limit the ouput length when using vllm
xiongjyu Dec 28, 2025
19fac8f
polish(pu): add cot-reuse in training, use running-norm in value, pol…
puyuan1996 Dec 29, 2025
88f047b
fix(pu): fix some bugs in reuse-collect-cot in training phase
puyuan1996 Dec 29, 2025
de4b2c0
polish configs and format
xiongjyu Dec 30, 2025
2069e32
delete unuse config
xiongjyu Dec 30, 2025
3ff091e
fix not found go bug
xiongjyu Dec 30, 2025
3888d8e
fix the misalignment bug when reusing cot
xiongjyu Jan 6, 2026
da0d0fd
make the prompt more compact
xiongjyu Jan 6, 2026
5f88151
add lr warmup
xiongjyu Jan 7, 2026
ed89062
add warmup for training world model before training llm and AdaptiveV…
xiongjyu Jan 7, 2026
96dc250
polish the implementation of profile
xiongjyu Jan 7, 2026
66ac376
fix a small bug
xiongjyu Jan 8, 2026
a9593cd
add profile of forward_collect
xiongjyu Jan 8, 2026
5d0f359
add format reward option and fix the cot gradient
xiongjyu Jan 8, 2026
e15e66c
rename kl/clip-ratio metrics
xiongjyu Jan 14, 2026
55e66ed
Optimize the use of format rewards
xiongjyu Jan 14, 2026
0cef8b8
polish the format
xiongjyu Jan 14, 2026
f88989b
Add a complete DDP program, including the process of collect and worl…
xiongjyu Jan 14, 2026
ffcf348
fix some small bugs
xiongjyu Jan 14, 2026
591c420
add the result of valid_actions's prob in vllm_output and samples shu…
xiongjyu Jan 15, 2026
091cc80
add the policy model for reload/offload func
xiongjyu Jan 16, 2026
e512707
fix a bug
xiongjyu Jan 16, 2026
a6bdc0e
Fixed a bug in the advantage feature and used target_value-pred_value…
xiongjyu Jan 17, 2026
250ba63
Optimize parameter definitions and off-policy's implementation
xiongjyu Jan 18, 2026
c3af1c2
fix a small bug
xiongjyu Jan 18, 2026
9d304c6
fix a small bug
xiongjyu Jan 19, 2026
9654f6a
Optimize essential logging; add grad norm metrics; record metrics as …
xiongjyu Feb 1, 2026
1bbcc19
fix a small bug
xiongjyu Feb 2, 2026
335e16c
Fix a bug; refactor LLM prompts into system/user roles; improve CoT o…
xiongjyu Feb 2, 2026
acf5d04
add input/response length metrics and envstep-based tb_logger for lea…
xiongjyu Feb 4, 2026
2ff4a90
add llm_prior_tempearture to apply_temperature_scaling for llm_prior
xiongjyu Feb 5, 2026
ea32a4a
Fixed bugs in prefix_cot & padding and aligned the input contexts of …
xiongjyu Feb 7, 2026
f895ed2
Remove the logits_to_keep parameter and add metrics such as entropy.
xiongjyu Feb 8, 2026
4567190
tmp
xiongjyu Feb 8, 2026
5098395
refine some log
xiongjyu Feb 9, 2026
b865dcb
polish some config
xiongjyu Feb 23, 2026
82e1d29
Fixed the bug caused by fmt_weight being 1
xiongjyu Feb 24, 2026
7c9c922
delete unused file
xiongjyu Feb 24, 2026
2dcc4ae
fix score, action_str, and timestep_dict bug, and add 3 modes of Prio…
xiongjyu Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 25 additions & 226 deletions lzero/mcts/buffer/game_buffer_priorzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,158 +18,6 @@
from typing import List, Any, Union, Tuple
from lzero.mcts.buffer.game_buffer_unizero import UniZeroGameBuffer


class PriorZeroGameBuffer(UniZeroGameBuffer):
"""
[PRIORZERO-MODIFIED]
Enhanced GameBuffer that provides game_segments for LLM policy training.

Modifications:
1. sample() returns game_segments as 4th element
2. Efficient implementation using existing game_segment_list from _make_batch
3. No additional memory overhead (returns references, not copies)
"""

def __init__(self, cfg):
"""Initialize PriorZero Game Buffer."""
super().__init__(cfg)

# [PRIORZERO-NEW] Cache for the last sampled game segments
# This avoids re-sampling when we need game segments
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None

def sample(
self,
batch_size: int,
policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
"""
[PRIORZERO-MODIFIED]
Sample data and prepare current_batch, target_batch, AND game_segments.

Returns:
train_data: [current_batch, target_batch, game_segments]
- current_batch: [obs, action, target_action, mask, indices, weights, make_time, timestep]
- target_batch: [rewards, values, policies]
- game_segments: List of GameSegment objects used in this batch

Note:
game_segments are returned for LLM training (SFT/RFT).
They contain:
- mcts_policy_segment: MCTS visit distributions (for SFT supervision)
- raw_obs_segment: Raw text observations (for LLM prompts)
- reward_segment: Environment rewards (for RFT)
- search_value_segment: MCTS search values (for analysis)
"""
policy._target_model.to(self._cfg.device)
policy._target_model.eval()

# ======================================================================
# [PRIORZERO-KEY] Sample data and extract game_segments
# ======================================================================
# obtain the current_batch and prepare target context
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
batch_size, self._cfg.reanalyze_ratio
)

# [PRIORZERO-NEW] Extract game_segments from the sampling process
# These were already created in _make_batch, we just need to save them
game_segments = self._last_sampled_game_segments

# Defensive check: ensure game_segments match batch_size
if game_segments is None or len(game_segments) != len(current_batch[4]): # current_batch[4] is batch_index_list
# Fallback: create empty list if something went wrong
import logging
logging.warning(
f"[PriorZeroBuffer] game_segments mismatch: "
f"expected {len(current_batch[4])}, got {len(game_segments) if game_segments else None}. "
f"Falling back to empty list (SFT/RFT will be skipped)."
)
game_segments = []

# ======================================================================
# Standard UniZero processing (unchanged)
# ======================================================================
# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
)

# target policy
batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
) # current_batch[1] is batch_action
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
)

# fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies
if 0 < self._cfg.reanalyze_ratio < 1:
batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
elif self._cfg.reanalyze_ratio == 1:
batch_target_policies = batch_target_policies_re
elif self._cfg.reanalyze_ratio == 0:
batch_target_policies = batch_target_policies_non_re

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

# ======================================================================
# [PRIORZERO-KEY] Return current_batch, target_batch, AND game_segments
# ======================================================================
train_data = [current_batch, target_batch, game_segments]
return train_data

def _sample_orig_data(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def _sample_orig_data_episode(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during episode sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data_episode(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def clear(self):
"""
[PRIORZERO-MODIFIED]
Clear buffer and cached game segments.
"""
super().clear()
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None


# ==============================================================================
# Optimized Alternative (Avoids Double Sampling)
# ==============================================================================

class PriorZeroGameBufferOptimized(UniZeroGameBuffer):
"""
[PRIORZERO-OPTIMIZED]
Expand All @@ -195,16 +43,14 @@ def sample(self, batch_size: int, policy) -> List[Any]:
batch_size, self._cfg.reanalyze_ratio
)

# Get cached game segments (set by our overridden _make_batch)
game_segments = self._cached_game_segments or []

obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list = current_batch
# Standard processing
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1]
reward_value_context, policy._target_model, current_batch[2], timestep_list
)

batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
policy_re_context, policy._target_model, current_batch[1], timestep_list
)
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
Expand All @@ -219,7 +65,7 @@ def sample(self, batch_size: int, policy) -> List[Any]:

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

return [current_batch, target_batch, game_segments]
return [current_batch, target_batch]

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
Expand All @@ -243,6 +89,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# Rest of the code is identical to parent's _make_batch
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
raw_obs_list, history_obs_list = [], []
action_logprob_list = []
timestep_list = []
bootstrap_action_list = []

Expand Down Expand Up @@ -272,6 +120,16 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
raw_obs_list.append(game_segment_list[i].get_unroll_raw_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
history_obs_list.append(game_segment_list[i].get_unroll_histroy_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
action_logprob_list.append(game_segment_list[i].get_unroll_action_logprob(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))

action_list.append(actions_tmp)
mask_list.append(mask_tmp)
timestep_list.append(timestep_tmp)
Expand All @@ -291,6 +149,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

current_batch.append(raw_obs_list)
current_batch.append(history_obs_list)
current_batch.append(action_logprob_list)

total_transitions = self.get_num_of_transitions()

Expand Down Expand Up @@ -319,71 +181,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:

return reward_value_context, policy_re_context, policy_non_re_context, current_batch


# ==============================================================================
# Factory Function
# ==============================================================================

def create_priorzero_buffer(cfg, optimized: bool = True):
"""
Factory function to create PriorZero game buffer.

Args:
cfg: Configuration dict
optimized: If True, use optimized version (recommended)

Returns:
buffer: PriorZero game buffer instance
"""
if optimized:
return PriorZeroGameBufferOptimized(cfg)
else:
return PriorZeroGameBuffer(cfg)


if __name__ == "__main__":
print("="*80)
print("PriorZero Game Buffer - Unit Tests")
print("="*80)

# Create mock config
class MockConfig:
def __init__(self):
self.device = 'cpu'
self.env_type = 'not_board_games'
self.game_segment_length = 200
self.num_unroll_steps = 5
self.td_steps = 5
self.batch_size = 32
self.use_priority = False
self.reanalyze_ratio = 0.0
self.sample_type = 'transition'
self.replay_buffer_size = 10000
self.model = type('obj', (object,), {
'model_type': 'mlp',
'action_space_size': 10,
'observation_shape': 128,
})()

cfg = MockConfig()

# Test both versions
for name, buffer_class in [
("Standard", PriorZeroGameBuffer),
("Optimized", PriorZeroGameBufferOptimized)
]:
print(f"\nTesting {name} Buffer:")
print("-" * 40)

buffer = buffer_class(cfg)
print(f"✓ Buffer created: {type(buffer).__name__}")
print(f" - sample_type: {buffer.sample_type}")
print(f" - action_space_size: {buffer.action_space_size}")

# Note: Full testing would require mock GameSegments and Policy
# For now, just verify instantiation
print(f"✓ {name} buffer initialized successfully")

print("\n" + "="*80)
print("✓ All tests passed!")
print("="*80)
def _clear(self):
self.game_pos_priorities = []
self.game_segment_buffer = []
self.game_segment_game_pos_look_up = []

2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,7 +2064,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
value_priority=value_priority,
intermediate_tensor_x=intermediate_tensor_x,
obs_embeddings=detached_obs_embeddings, # <-- 新增
)
), inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, outputs.logits_value.shape[-1])).detach()


# TODO: test correctness
Expand Down
10 changes: 0 additions & 10 deletions lzero/worker/muzero_segment_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,6 @@ def collect(
if self.policy_config.use_ture_chance_label_in_chance_encoder:
append_kwargs['chance'] = self.chance_dict_tmp[env_id]

# [PRIORZERO-NEW] Add raw_obs_text if available in obs (not info!)
# Jericho env puts raw_obs_text in the obs dictionary
if env_id == 0 and collected_step < 5: # Debug first few steps
print(f"[OBS_DEBUG] Step {collected_step} env {env_id}: obs keys = {list(obs.keys())}")
print(f"[OBS_DEBUG] obs type = {type(obs)}")
if 'raw_obs_text' in obs:
print(f"[OBS_DEBUG] Found raw_obs_text: {str(obs['raw_obs_text'])[:100]}...")
else:
print(f"[OBS_DEBUG] NO raw_obs_text in obs!")

if 'raw_obs_text' in obs:
append_kwargs['raw_obs_text'] = obs['raw_obs_text']
elif 'raw_obs_text' in info:
Expand Down
26 changes: 23 additions & 3 deletions zoo/jericho/envs/jericho_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from collections import OrderedDict

import gym
import numpy as np
Expand Down Expand Up @@ -49,12 +50,13 @@ class JerichoEnv(BaseEnv):
'max_seq_len': 512,
'remove_stuck_actions': False,
'add_location_and_inventory': False,
# 'for_unizero': False,
'for_unizero': True,
'save_replay': False,
'save_replay_path': None,
'env_type': "zork1",
'collect_policy_mode': "agent"
'collect_policy_mode': "agent",
'use_cache': True,
'cache_size': 100000,
}

def __init__(self, cfg: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -93,6 +95,12 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
self.for_unizero: bool = self.cfg['for_unizero']

self.use_cache = self.cfg['use_cache']
if self.use_cache:
self.cache_size = self.cfg['cache_size']
self.cache_buffer = OrderedDict()
print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}')

# Initialize the tokenizer once (only in rank 0 process if distributed)
if JerichoEnv.tokenizer is None:
if self.rank == 0:
Expand Down Expand Up @@ -138,7 +146,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
raw_obs_text = obs # Save original text BEFORE any modification

if self._action_list is None:
self._action_list = self._env.get_valid_actions()
if self.use_cache:
cache_key = self._env.get_world_state_hash()
if cache_key in self.cache_buffer:
self.cache_buffer.move_to_end(cache_key)
self._action_list = self.cache_buffer[cache_key]
else:
self._action_list = self._env.get_valid_actions()
self.cache_buffer[cache_key] = self._action_list
if len(self.cache_buffer) > self.cache_size:
self.cache_buffer.popitem(last=False)
else:
self._action_list = self._env.get_valid_actions()

# Filter available actions based on whether stuck actions are removed.
if self.remove_stuck_actions:
Expand Down Expand Up @@ -344,6 +363,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
previous_obs: Optional[str] = self.last_observation if (self.remove_stuck_actions and self.last_observation is not None) else None

observation, reward, done, info = self._env.step(action_str)
info['action_str'] = action_str

self._timestep += 1
if not self.for_unizero:
Expand Down
Loading