diff --git a/jericho_priorzero_1020.yml b/jericho_priorzero_1020.yml deleted file mode 100644 index f4ef0b339..000000000 --- a/jericho_priorzero_1020.yml +++ /dev/null @@ -1,454 +0,0 @@ -name: base -channels: - - pytorch - - nvidia - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - anaconda-anon-usage=0.4.4=py310hc06175d_0 - - archspec=0.2.3=pyhd3eb1b0_0 - - asttokens=2.0.5=pyhd3eb1b0_0 - - attrs=23.1.0=py310h06a4308_0 - - beautifulsoup4=4.12.2=py310h06a4308_0 - - blas=1.0=mkl - - boltons=23.0.0=py310h06a4308_0 - - brotli-python=1.0.9=py310h6a678d5_8 - - bzip2=1.0.8=h5eee18b_6 - - c-ares=1.19.1=h5eee18b_0 - - ca-certificates=2024.3.11=h06a4308_0 - - certifi=2024.2.2=py310h06a4308_0 - - cffi=1.16.0=py310h5eee18b_1 - - chardet=4.0.0=py310h06a4308_1003 - - charset-normalizer=2.0.4=pyhd3eb1b0_0 - - click=8.1.7=py310h06a4308_0 - - cmake=3.26.4=h96355d8_0 - - conda=23.5.2=py310h06a4308_0 - - conda-build=24.3.0=py310h06a4308_0 - - conda-content-trust=0.2.0=py310h06a4308_1 - - conda-index=0.4.0=pyhd3eb1b0_0 - - conda-libmamba-solver=23.7.0=py310h06a4308_0 - - conda-package-handling=2.2.0=py310h06a4308_1 - - conda-package-streaming=0.9.0=py310h06a4308_0 - - cryptography=42.0.5=py310hdda0065_1 - - cuda-cudart=12.1.105=0 - - cuda-cupti=12.1.105=0 - - cuda-libraries=12.1.0=0 - - cuda-nvrtc=12.1.105=0 - - cuda-nvtx=12.1.105=0 - - cuda-opencl=12.5.39=0 - - cuda-runtime=12.1.0=0 - - cuda-version=12.5=3 - - distro=1.9.0=py310h06a4308_0 - - exceptiongroup=1.2.0=py310h06a4308_0 - - executing=0.8.3=pyhd3eb1b0_0 - - expat=2.6.2=h6a678d5_0 - - ffmpeg=4.3=hf484d3e_0 - - fmt=9.1.0=hdb19cb5_1 - - freetype=2.12.1=h4a9f257_0 - - frozendict=2.4.2=py310h5eee18b_0 - - gmp=6.2.1=h295c915_3 - - gmpy2=2.1.2=py310heeb90bb_0 - - gnutls=3.6.15=he1e5248_0 - - icu=73.1=h6a678d5_0 - - idna=3.7=py310h06a4308_0 - - intel-openmp=2023.1.0=hdb19cb5_46306 - - ipython=8.20.0=py310h06a4308_0 - - jedi=0.18.1=py310h06a4308_1 - - jpeg=9e=h5eee18b_1 - - 
jsonpatch=1.33=py310h06a4308_1 - - jsonpointer=2.1=pyhd3eb1b0_0 - - jsonschema-specifications=2023.7.1=py310h06a4308_0 - - krb5=1.20.1=h143b758_1 - - lame=3.100=h7b6447c_0 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.38=h1181459_1 - - lerc=3.0=h295c915_0 - - libarchive=3.6.2=h6ac8c49_3 - - libcublas=12.1.0.26=0 - - libcufft=11.0.2.4=0 - - libcufile=1.10.0.4=0 - - libcurand=10.3.6.39=0 - - libcurl=8.7.1=h251f7ec_0 - - libcusolver=11.4.4.55=0 - - libcusparse=12.0.2.55=0 - - libdeflate=1.17=h5eee18b_1 - - libedit=3.1.20230828=h5eee18b_0 - - libev=4.33=h7f8727e_1 - - libffi=3.4.4=h6a678d5_1 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libiconv=1.16=h5eee18b_3 - - libidn2=2.3.4=h5eee18b_0 - - libjpeg-turbo=2.0.0=h9bf148f_0 - - liblief=0.12.3=h6a678d5_0 - - libmamba=1.5.8=hfe524e5_2 - - libmambapy=1.5.8=py310h2dafd23_2 - - libnghttp2=1.57.0=h2d74bed_0 - - libnpp=12.0.2.50=0 - - libnvjitlink=12.1.105=0 - - libnvjpeg=12.1.1.14=0 - - libpng=1.6.39=h5eee18b_0 - - libsolv=0.7.24=he621ea3_1 - - libssh2=1.11.0=h251f7ec_0 - - libstdcxx-ng=11.2.0=h1234567_1 - - libtasn1=4.19.0=h5eee18b_0 - - libtiff=4.5.1=h6a678d5_0 - - libunistring=0.9.10=h27cfd23_0 - - libuuid=1.41.5=h5eee18b_0 - - libuv=1.44.2=h5eee18b_0 - - libwebp-base=1.3.2=h5eee18b_0 - - libxml2=2.10.4=hfdd30dd_2 - - llvm-openmp=14.0.6=h9e868ea_0 - - lz4-c=1.9.4=h6a678d5_1 - - markupsafe=2.1.3=py310h5eee18b_0 - - matplotlib-inline=0.1.6=py310h06a4308_0 - - menuinst=2.1.0=py310h06a4308_0 - - mkl=2023.1.0=h213fc3f_46344 - - mkl-service=2.4.0=py310h5eee18b_1 - - mkl_fft=1.3.8=py310h5eee18b_0 - - mkl_random=1.2.4=py310hdb19cb5_0 - - more-itertools=10.1.0=py310h06a4308_0 - - mpc=1.1.0=h10f8cd9_1 - - mpfr=4.0.2=hb69a4c5_1 - - mpmath=1.3.0=py310h06a4308_0 - - ncurses=6.4=h6a678d5_0 - - nettle=3.7.3=hbbd107a_1 - - numpy=1.26.4=py310h5f9d8c6_0 - - numpy-base=1.26.4=py310hb5e798b_0 - - openh264=2.1.1=h4ff587b_0 - - openjpeg=2.4.0=h3ad879b_0 - - openssl=3.0.13=h7f8727e_2 - - packaging=23.2=py310h06a4308_0 - - 
parso=0.8.3=pyhd3eb1b0_0 - - patch=2.7.6=h7b6447c_1001 - - patchelf=0.17.2=h6a678d5_0 - - pcre2=10.42=hebb0a14_1 - - pexpect=4.8.0=pyhd3eb1b0_3 - - pillow=10.3.0=py310h5eee18b_0 - - pkginfo=1.10.0=py310h06a4308_0 - - platformdirs=3.10.0=py310h06a4308_0 - - prompt-toolkit=3.0.43=py310h06a4308_0 - - prompt_toolkit=3.0.43=hd3eb1b0_0 - - psutil=5.9.0=py310h5eee18b_0 - - ptyprocess=0.7.0=pyhd3eb1b0_2 - - pure_eval=0.2.2=pyhd3eb1b0_0 - - py-lief=0.12.3=py310h6a678d5_0 - - pybind11-abi=4=hd3eb1b0_1 - - pycosat=0.6.6=py310h5eee18b_1 - - pycparser=2.21=pyhd3eb1b0_0 - - pygments=2.15.1=py310h06a4308_1 - - pyopenssl=24.0.0=py310h06a4308_0 - - pysocks=1.7.1=py310h06a4308_0 - - python=3.10.14=h955ad1f_1 - - python-libarchive-c=2.9=pyhd3eb1b0_1 - - pytorch-cuda=12.1=ha16c6d3_5 - - pytorch-mutex=1.0=cuda - - pytz=2024.1=py310h06a4308_0 - - pyyaml=6.0.1=py310h5eee18b_0 - - readline=8.2=h5eee18b_0 - - referencing=0.30.2=py310h06a4308_0 - - reproc=14.2.4=h6a678d5_2 - - reproc-cpp=14.2.4=h6a678d5_2 - - requests=2.32.2=py310h06a4308_0 - - rhash=1.4.3=hdbd6064_0 - - rpds-py=0.10.6=py310hb02cf49_0 - - ruamel.yaml=0.17.21=py310h5eee18b_0 - - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 - - six=1.16.0=pyhd3eb1b0_1 - - soupsieve=2.5=py310h06a4308_0 - - sqlite=3.45.3=h5eee18b_0 - - stack_data=0.2.0=pyhd3eb1b0_0 - - tbb=2021.8.0=hdb19cb5_0 - - tk=8.6.14=h39e8969_0 - - tomli=2.0.1=py310h06a4308_0 - - toolz=0.12.0=py310h06a4308_0 - - tqdm=4.66.4=py310h2f386ee_0 - - traitlets=5.7.1=py310h06a4308_0 - - truststore=0.8.0=py310h06a4308_0 - - urllib3=2.2.1=py310h06a4308_0 - - wcwidth=0.2.5=pyhd3eb1b0_0 - - wheel=0.43.0=py310h06a4308_0 - - xz=5.4.6=h5eee18b_1 - - yaml=0.2.5=h7b6447c_0 - - yaml-cpp=0.8.0=h6a678d5_1 - - zlib=1.2.13=h5eee18b_1 - - zstandard=0.22.0=py310h2c38b39_0 - - zstd=1.5.5=hc292b87_2 - - pip: - - absl-py==2.1.0 - - accelerate==1.10.1 - - aiohappyeyeballs==2.4.0 - - aiohttp==3.10.5 - - aiosignal==1.3.1 - - ale-py==0.8.1 - - annotated-types==0.7.0 - - anyio==4.11.0 - - astor==0.8.1 - - 
astunparse==1.6.3 - - async-timeout==4.0.3 - - av==12.3.0 - - beartype==0.18.5 - - bitmath==1.3.3.1 - - blake3==1.0.8 - - blis==1.3.0 - - box2d-py==2.3.5 - - cachetools==6.2.1 - - catalogue==2.0.10 - - cbor2==5.7.0 - - cloudpathlib==0.23.0 - - cloudpickle==3.0.0 - - comm==0.2.2 - - compressed-tensors==0.11.0 - - confection==0.1.5 - - contourpy==1.2.1 - - cupy-cuda12x==13.6.0 - - cycler==0.12.1 - - cymem==2.0.11 - - cython==0.29.37 - - datasets==4.2.0 - - debugpy==1.8.5 - - decorator==4.4.2 - - deprecation==2.1.0 - - depyf==0.19.0 - - di-engine==0.5.3 - - di-toolkit==0.3.0 - - di-treetensor==0.4.1 - - diffusers==0.30.0 - - dill==0.3.8 - - diskcache==5.6.3 - - dm-control==1.0.22 - - dm-env==1.6 - - dm-tree==0.1.8 - - dnspython==2.6.1 - - docker-pycreds==0.4.0 - - docstring-parser==0.17.0 - - easydict==1.9 - - einops==0.8.1 - - email-validator==2.3.0 - - en-core-web-sm==3.8.0 - - enum-tools==0.12.0 - - etils==1.7.0 - - expecttest==0.2.1 - - farama-notifications==0.0.4 - - fastapi==0.119.1 - - fastapi-cli==0.0.13 - - fastapi-cloud-cli==0.3.1 - - fasteners==0.19 - - fastrlock==0.8.3 - - filelock==3.20.0 - - flask==2.0.3 - - fonttools==4.53.1 - - frozenlist==1.4.1 - - fsspec==2024.6.0 - - gguf==0.17.1 - - gitdb==4.0.11 - - gitpython==3.1.43 - - glfw==2.7.0 - - grpcio==1.75.1 - - gym==0.25.1 - - gym-notices==0.0.8 - - gymnasium==0.28.0 - - h11==0.16.0 - - h5py==3.11.0 - - hbutils==0.10.0 - - hf-xet==1.1.10 - - hickle==5.0.3 - - httpcore==1.0.9 - - httptools==0.7.1 - - httpx==0.28.1 - - huggingface-hub==0.35.3 - - hypothesis==6.103.0 - - imageio==2.35.1 - - imageio-ffmpeg==0.5.1 - - importlib-metadata==8.4.0 - - importlib-resources==6.4.4 - - iniconfig==2.3.0 - - interegular==0.3.3 - - ipykernel==6.29.5 - - ipywidgets==8.1.3 - - itsdangerous==2.2.0 - - jax-jumpy==1.0.0 - - jericho==3.3.0 - - jinja2==3.1.6 - - jiter==0.11.1 - - joblib==1.4.2 - - jsonschema==4.25.1 - - jupyter-client==8.6.2 - - jupyter-core==5.7.2 - - jupyterlab-widgets==3.0.11 - - kiwisolver==1.4.5 - - 
labmaze==1.0.6 - - langcodes==3.5.0 - - language-data==1.3.0 - - lark==1.2.2 - - lightning-utilities==0.11.6 - - lightzero==0.2.0 - - line-profiler==5.0.0 - - llguidance==0.7.30 - - llvmlite==0.44.0 - - lm-format-enforcer==0.11.3 - - lockfile==0.12.2 - - loguru==0.7.3 - - lxml==5.3.0 - - marisa-trie==1.3.1 - - markdown==3.9 - - markdown-it-py==3.0.0 - - matplotlib==3.9.2 - - mdurl==0.1.2 - - minigrid==2.2.1 - - mistral-common==1.8.5 - - mjrl==1.0.0 - - moviepy==1.0.3 - - mpire==2.10.2 - - msgpack==1.1.2 - - msgspec==0.19.0 - - mujoco==3.2.2 - - mujoco-py==2.1.2.14 - - multidict==6.0.5 - - multiprocess==0.70.16 - - murmurhash==1.0.13 - - nest-asyncio==1.6.0 - - networkx==3.3 - - ninja==1.13.0 - - nltk==3.9.2 - - numba==0.61.2 - - nvidia-cublas-cu12==12.8.4.1 - - nvidia-cuda-cupti-cu12==12.8.90 - - nvidia-cuda-nvrtc-cu12==12.8.93 - - nvidia-cuda-runtime-cu12==12.8.90 - - nvidia-cudnn-cu12==9.10.2.21 - - nvidia-cufft-cu12==11.3.3.83 - - nvidia-cufile-cu12==1.13.1.3 - - nvidia-curand-cu12==10.3.9.90 - - nvidia-cusolver-cu12==11.7.3.90 - - nvidia-cusparse-cu12==12.5.8.93 - - nvidia-cusparselt-cu12==0.7.1 - - nvidia-ml-py==13.580.82 - - nvidia-nccl-cu12==2.27.3 - - nvidia-nvjitlink-cu12==12.8.93 - - nvidia-nvtx-cu12==12.8.90 - - nvitop==1.5.3 - - openai==2.5.0 - - openai-harmony==0.0.4 - - opencv-python==4.10.0.84 - - opencv-python-headless==4.12.0.88 - - optree==0.11.0 - - orjson==3.10.7 - - outlines-core==0.2.11 - - pandas==2.3.3 - - partial-json-parser==0.2.1.1.post6 - - pastel==0.2.1 - - peft==0.17.1 - - pip==24.2 - - pluggy==1.6.0 - - poethepoet==0.10.0 - - pot==0.9.4 - - preshed==3.0.10 - - proglog==0.1.10 - - prometheus-client==0.23.1 - - prometheus-fastapi-instrumentator==7.1.0 - - protobuf==5.27.3 - - py-cpuinfo==9.0.0 - - pyarrow==21.0.0 - - pybase64==1.4.2 - - pybullet==3.2.6 - - pycountry==24.6.1 - - pydantic==2.12.3 - - pydantic-core==2.41.4 - - pydantic-extra-types==2.10.6 - - pygame==2.6.1 - - pympler==1.1 - - pynng==0.8.1 - - pyopengl==3.1.7 - - 
pyparsing==3.1.2 - - pytest==8.4.2 - - python-dateutil==2.9.0.post0 - - python-dotenv==1.1.1 - - python-etcd==0.4.5 - - python-graphviz==0.20.3 - - python-json-logger==4.0.0 - - python-multipart==0.0.20 - - pytimeparse==1.1.8 - - pytorch-lightning==2.4.0 - - pyzmq==26.2.0 - - ray==2.50.1 - - redis==6.4.0 - - regex==2024.7.24 - - responses==0.25.8 - - rich==13.7.1 - - rich-toolkit==0.15.1 - - rignore==0.7.1 - - safetensors==0.4.4 - - scikit-learn==1.5.1 - - scipy==1.14.1 - - seaborn==0.13.2 - - sentencepiece==0.2.1 - - sentry-sdk==2.42.0 - - setproctitle==1.3.3 - - setuptools==66.1.1 - - shellingham==1.5.4 - - shimmy==0.2.1 - - simple-parsing==0.1.7 - - smart-open==7.4.0 - - smmap==5.0.1 - - sniffio==1.3.1 - - sortedcontainers==2.4.0 - - soundfile==0.13.1 - - soxr==1.0.0 - - spacy==3.8.7 - - spacy-legacy==3.0.12 - - spacy-loggers==1.0.5 - - srsly==2.5.1 - - starlette==0.48.0 - - sympy==1.14.0 - - tabulate==0.9.0 - - tensorboard==2.20.0 - - tensorboard-data-server==0.7.2 - - tensorboardx==2.6.4 - - tensordict==0.5.0 - - termcolor==2.4.0 - - thinc==8.3.6 - - threadpoolctl==3.5.0 - - tiktoken==0.12.0 - - tokenizers==0.22.1 - - tomlkit==0.13.2 - - torch==2.8.0 - - torchaudio==2.8.0 - - torchcde==0.2.5 - - torchdiffeq==0.2.4 - - torchelastic==0.2.2 - - torchmetrics==1.4.1 - - torchsde==0.2.6 - - torchvision==0.23.0 - - tornado==6.4.1 - - trampoline==0.1.2 - - transformers==4.57.1 - - treevalue==1.4.12 - - triton==3.4.0 - - trueskill==0.4.5 - - typer==0.19.2 - - types-dataclasses==0.6.6 - - typing-extensions==4.15.0 - - typing-inspection==0.4.2 - - tzdata==2025.2 - - urlobject==3.0.0 - - uvicorn==0.38.0 - - uvloop==0.22.1 - - vllm==0.11.0 - - wandb==0.17.7 - - wasabi==1.1.3 - - watchfiles==1.1.1 - - weasel==0.4.1 - - websockets==15.0.1 - - werkzeug==2.0.3 - - widgetsnbextension==4.0.11 - - wrapt==2.0.0 - - xformers==0.0.32.post1 - - xgrammar==0.1.25 - - xxhash==3.6.0 - - yapf==0.29.0 - - yarl==1.9.4 - - yattag==1.16.1 - - zipp==3.20.0 -prefix: /opt/conda diff --git 
a/lzero/entry/utils.py b/lzero/entry/utils.py index 99b22b852..0ec97a12c 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -528,9 +528,11 @@ def calculate_update_per_collect( collected_transitions_tensor ).item() updates = int(total_collected_transitions * cfg.policy.replay_ratio) + print(f"\ntotal_collected_transitions={total_collected_transitions}\tupdates={updates}\n") else: # In a single-process setup. updates = int(collected_transitions_num * cfg.policy.replay_ratio) + print(f"collected_transitions_num={collected_transitions_num}\tupdates={updates}") return max(1, updates) # Ensure at least one update. diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index d7ccb0678..541dd35c5 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -8,3 +8,4 @@ from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer from .game_buffer_rezero_mz import ReZeroMZGameBuffer from .game_buffer_rezero_ez import ReZeroEZGameBuffer +from .game_buffer_priorzero import PriorZeroGameBufferOptimized diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 253935652..d1a854988 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -145,14 +145,6 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) - - # print(f'len(game_segment)=:len(game_segment.action_segment): {len(game_segment)}') - # print(f'len(game_segment.obs_segment): {game_segment.obs_segment.shape[0]}') - - # In the reanalysis phase, `pos_in_game_segment` should be a multiple of `num_unroll_steps`. - # Indices exceeding `game_segment_length` are padded with the next segment and are not updated - # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within - # [0, game_segment_length - num_unroll_steps] to avoid padded data. 
if self._cfg.action_type == 'varied_action_space': # For some environments (e.g., Jericho), the action space size may be different. diff --git a/lzero/mcts/buffer/game_buffer_priorzero.py b/lzero/mcts/buffer/game_buffer_priorzero.py index c9dda2cf5..d4cdd6c62 100644 --- a/lzero/mcts/buffer/game_buffer_priorzero.py +++ b/lzero/mcts/buffer/game_buffer_priorzero.py @@ -1,210 +1,66 @@ -# game_buffer_priorzero.py -""" -[PRIORZERO] Enhanced Game Buffer for PriorZero - -This module extends UniZeroGameBuffer to support LLM policy training (SFT + RFT). - -Key Features: -- Returns game_segments in sample() for LLM training data extraction -- Efficient indexing to avoid duplicating large observation data -- Robust handling of edge cases (partial batches, variable-length segments) -- Minimal memory overhead (only stores references, not copies) - -Author: PriorZero Team -Date: 2025-01-21 -""" - import numpy as np from typing import List, Any, Union, Tuple from lzero.mcts.buffer.game_buffer_unizero import UniZeroGameBuffer +from lzero.policy import to_detach_cpu_numpy, concat_output_value, inverse_scalar_transform +from lzero.mcts.utils import prepare_observation +import torch -class PriorZeroGameBuffer(UniZeroGameBuffer): - """ - [PRIORZERO-MODIFIED] - Enhanced GameBuffer that provides game_segments for LLM policy training. - - Modifications: - 1. sample() returns game_segments as 4th element - 2. Efficient implementation using existing game_segment_list from _make_batch - 3. 
No additional memory overhead (returns references, not copies) - """ +class PriorZeroGameBufferOptimized(UniZeroGameBuffer): def __init__(self, cfg): - """Initialize PriorZero Game Buffer.""" super().__init__(cfg) - - # [PRIORZERO-NEW] Cache for the last sampled game segments - # This avoids re-sampling when we need game segments - self._last_sampled_game_segments = None - self._last_sampled_batch_indices = None - - def sample( - self, - batch_size: int, - policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: + self.last_pos_in_transition = 0 + + def fetch_latest_batch(self, batch_size: int, policy) -> List[Any]: """ - [PRIORZERO-MODIFIED] - Sample data and prepare current_batch, target_batch, AND game_segments. + Fetch latest batch for LLM training. Returns: - train_data: [current_batch, target_batch, game_segments] - - current_batch: [obs, action, target_action, mask, indices, weights, make_time, timestep] - - target_batch: [rewards, values, policies] - - game_segments: List of GameSegment objects used in this batch - - Note: - game_segments are returned for LLM training (SFT/RFT). - They contain: - - mcts_policy_segment: MCTS visit distributions (for SFT supervision) - - raw_obs_segment: Raw text observations (for LLM prompts) - - reward_segment: Environment rewards (for RFT) - - search_value_segment: MCTS search values (for analysis) + [raw_obs_list, history_obs_list, llm_prior_per_tok_list, batch_target_values, cot_prefix_list, llm_action] + CoT prefix list is added for CoT reuse optimization. 
""" policy._target_model.to(self._cfg.device) policy._target_model.eval() - # ====================================================================== - # [PRIORZERO-KEY] Sample data and extract game_segments - # ====================================================================== - # obtain the current_batch and prepare target context reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio + batch_size, self._cfg.reanalyze_ratio, fetch_latest=True ) - # [PRIORZERO-NEW] Extract game_segments from the sampling process - # These were already created in _make_batch, we just need to save them - game_segments = self._last_sampled_game_segments - - # Defensive check: ensure game_segments match batch_size - if game_segments is None or len(game_segments) != len(current_batch[4]): # current_batch[4] is batch_index_list - # Fallback: create empty list if something went wrong - import logging - logging.warning( - f"[PriorZeroBuffer] game_segments mismatch: " - f"expected {len(current_batch[4])}, got {len(game_segments) if game_segments else None}. " - f"Falling back to empty list (SFT/RFT will be skipped)." 
- ) - game_segments = [] - - # ====================================================================== - # Standard UniZero processing (unchanged) - # ====================================================================== - # current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, llm_prior_per_tok_list, cot_prefix_list, llm_action_list = current_batch - # target reward, target value - batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action + # Standard processing + batch_rewards, batch_target_values, batch_pred_values = self._compute_target_reward_value_and_pred_value( + reward_value_context, policy._target_model, action_list, bootstrap_action_list, timestep_list ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed( - policy_re_context, policy._target_model, current_batch[1], current_batch[-1] - ) # current_batch[1] is batch_action - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + batch_target_policies = self._compute_target_policy_non_reanalyzed( policy_non_re_context, self.action_space_size ) - # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_rewards, batch_target_values, batch_target_policies] - - # 
====================================================================== - # [PRIORZERO-KEY] Return current_batch, target_batch, AND game_segments - # ====================================================================== - train_data = [current_batch, target_batch, game_segments] - return train_data - - def _sample_orig_data(self, batch_size: int) -> Tuple[Any]: - """ - [PRIORZERO-MODIFIED] - Override to cache game_segments during sampling. - - This avoids double sampling by caching the result for sample() to use. - """ - # Call parent implementation - result = super()._sample_orig_data(batch_size) - - # Cache the game_segment_list (first element of result tuple) - game_segment_list = result[0] - self._last_sampled_game_segments = game_segment_list - self._last_sampled_batch_indices = result[2] # batch_index_list - - return result - - def _sample_orig_data_episode(self, batch_size: int) -> Tuple[Any]: - """ - [PRIORZERO-MODIFIED] - Override to cache game_segments during episode sampling. - - This avoids double sampling by caching the result for sample() to use. - """ - # Call parent implementation - result = super()._sample_orig_data_episode(batch_size) - - # Cache the game_segment_list (first element of result tuple) - game_segment_list = result[0] - self._last_sampled_game_segments = game_segment_list - self._last_sampled_batch_indices = result[2] # batch_index_list + # CoT reuse optimization: return cot_prefix_list + # IMPORTANT: Validate return value before returning to ensure broadcast compatibility + result = [raw_obs_list, history_obs_list, llm_prior_per_tok_list, batch_target_values, batch_pred_values, cot_prefix_list, llm_action_list] return result - - def clear(self): - """ - [PRIORZERO-MODIFIED] - Clear buffer and cached game segments. 
- """ - super().clear() - self._last_sampled_game_segments = None - self._last_sampled_batch_indices = None - - -# ============================================================================== -# Optimized Alternative (Avoids Double Sampling) -# ============================================================================== - -class PriorZeroGameBufferOptimized(UniZeroGameBuffer): - """ - [PRIORZERO-OPTIMIZED] - More efficient version that avoids double sampling by modifying _make_batch minimally. - - This version uses a monkey-patch approach to intercept orig_data during parent's _make_batch call. - """ - - def __init__(self, cfg): - super().__init__(cfg) - self._cached_game_segments = None - + def sample(self, batch_size: int, policy) -> List[Any]: """Sample data with game_segments (optimized version).""" policy._target_model.to(self._cfg.device) policy._target_model.eval() - # Reset cache - self._cached_game_segments = None - - # Call parent's _make_batch (which will trigger our hook) reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( batch_size, self._cfg.reanalyze_ratio ) - # Get cached game segments (set by our overridden _make_batch) - game_segments = self._cached_game_segments or [] - + obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, llm_prior_per_tok_list, cot_prefix_list, llm_action_list = current_batch # Standard processing batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[2], current_batch[-1] + reward_value_context, policy._target_model, current_batch[2], timestep_list ) batch_target_policies_re = self._compute_target_policy_reanalyzed( - policy_re_context, policy._target_model, current_batch[1], current_batch[-1] + policy_re_context, policy._target_model, current_batch[1], timestep_list ) batch_target_policies_non_re = 
self._compute_target_policy_non_reanalyzed( policy_non_re_context, self.action_space_size @@ -219,30 +75,31 @@ def sample(self, batch_size: int, policy) -> List[Any]: target_batch = [batch_rewards, batch_target_values, batch_target_policies] - return [current_batch, target_batch, game_segments] + return [current_batch, target_batch] - def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: - """ - [PRIORZERO-OPTIMIZED] - Minimally modified to cache game_segment_list during sampling. + def _make_batch(self, batch_size: int, reanalyze_ratio: float, fetch_latest: bool = False) -> Tuple[Any]: - This is a full override of parent's _make_batch to avoid double sampling. - Code is mostly copied from parent, with one key addition: caching game_segments. - """ # Sample original data - if self.sample_type == 'transition': - orig_data = self._sample_orig_data(batch_size) - elif self.sample_type == 'episode': - orig_data = self._sample_orig_data_episode(batch_size) + if not fetch_latest: + if self.sample_type == 'transition': + orig_data = self._sample_orig_data(batch_size) + elif self.sample_type == 'episode': + orig_data = self._sample_orig_data_episode(batch_size) + else: + if self.sample_type == 'transition': + orig_data = self._fetch_latest_orig_data(batch_size) + elif self.sample_type == 'episode': + raise ValueError("fetch_latest with episode sampling not supported.") game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data - # [PRIORZERO-KEY] Cache game_segments for sample() to use - self._cached_game_segments = game_segment_list - # Rest of the code is identical to parent's _make_batch batch_size = len(batch_index_list) obs_list, action_list, mask_list = [], [], [] + raw_obs_list, history_obs_list = [], [] + llm_prior_per_tok_list = [] + cot_prefix_list = [] # CoT reuse optimization + llm_action_list = [] timestep_list = [] bootstrap_action_list = [] @@ -272,6 +129,22 @@ def _make_batch(self, 
batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True ) ) + raw_obs_list.append(game_segment_list[i].get_unroll_raw_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + history_obs_list.append(game_segment_list[i].get_unroll_histroy_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + llm_prior_per_tok_list.append(game_segment_list[i].get_unroll_llm_prior_per_tok( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + cot_prefix_list.append(game_segment_list[i].get_unroll_cot_prefix( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + llm_action_list.append(game_segment_list[i].get_unroll_llm_action( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + action_list.append(actions_tmp) mask_list.append(mask_tmp) timestep_list.append(timestep_tmp) @@ -291,12 +164,40 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) + # 检查 vllm和policy_model的输入上下文是否一致 + assert len(raw_obs_list) == len(history_obs_list) == len(llm_prior_per_tok_list) == len(cot_prefix_list) == len(llm_action_list) + B, T = len(raw_obs_list), len(raw_obs_list[0]) + for b in range(B): + for t in range(T - 1): + current_obs = raw_obs_list[b][t] + current_hist = history_obs_list[b][t] + + old_prefix_cot = llm_prior_per_tok_list[b][t+1]['prefix_cot'] + old_current_obs = llm_prior_per_tok_list[b][t+1]['current_obs'] + old_history = llm_prior_per_tok_list[b][t+1]['history'] + old_logprob = llm_prior_per_tok_list[b][t+1]['old_action_logprob'] + cot_prefix = 
cot_prefix_list[b][t+1] + llm_action = llm_action_list[b][t+1] + + assert llm_action in old_logprob + assert old_current_obs == current_obs and old_history == current_hist and old_prefix_cot == cot_prefix + + current_batch.append(raw_obs_list) + current_batch.append(history_obs_list) + current_batch.append(llm_prior_per_tok_list) + current_batch.append(cot_prefix_list) # CoT reuse optimization + current_batch.append(llm_action_list) total_transitions = self.get_num_of_transitions() - reward_value_context = self._prepare_reward_value_context( - batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions - ) + if not fetch_latest: + reward_value_context = self._prepare_reward_value_context( + batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions + ) + else: + reward_value_context = self._prepare_reward_value_context_and_pred_values( + batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions + ) reanalyze_num = max(int(batch_size * reanalyze_ratio), 1) if reanalyze_ratio > 0 else 0 self.reanalyze_num = reanalyze_num @@ -319,71 +220,329 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: return reward_value_context, policy_re_context, policy_non_re_context, current_batch + def _clear(self): + self.game_pos_priorities = [] + self.game_segment_buffer = [] + self.game_segment_game_pos_look_up = [] + + + def _fetch_latest_orig_data(self, batch_size: int) -> Tuple: + """ + Overview: + Sample original data which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). + Arguments: + - batch_size (:obj:`int`): The size of the batch. 
+ - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. + """ + assert self._beta > 0, "Beta should be greater than 0" + num_of_transitions = self.get_num_of_transitions() + + probs = self.game_pos_priorities ** self._alpha + 1e-6 + probs /= probs.sum() -# ============================================================================== -# Factory Function -# ============================================================================== - -def create_priorzero_buffer(cfg, optimized: bool = True): - """ - Factory function to create PriorZero game buffer. - - Args: - cfg: Configuration dict - optimized: If True, use optimized version (recommended) - - Returns: - buffer: PriorZero game buffer instance - """ - if optimized: - return PriorZeroGameBufferOptimized(cfg) - else: - return PriorZeroGameBuffer(cfg) - - -if __name__ == "__main__": - print("="*80) - print("PriorZero Game Buffer - Unit Tests") - print("="*80) - - # Create mock config - class MockConfig: - def __init__(self): - self.device = 'cpu' - self.env_type = 'not_board_games' - self.game_segment_length = 200 - self.num_unroll_steps = 5 - self.td_steps = 5 - self.batch_size = 32 - self.use_priority = False - self.reanalyze_ratio = 0.0 - self.sample_type = 'transition' - self.replay_buffer_size = 10000 - self.model = type('obj', (object,), { - 'model_type': 'mlp', - 'action_space_size': 10, - 'observation_shape': 128, - })() - - cfg = MockConfig() - - # Test both versions - for name, buffer_class in [ - ("Standard", PriorZeroGameBuffer), - ("Optimized", PriorZeroGameBufferOptimized) - ]: - print(f"\nTesting {name} Buffer:") - print("-" * 40) - - buffer = buffer_class(cfg) - print(f"✓ Buffer created: {type(buffer).__name__}") - print(f" - sample_type: {buffer.sample_type}") - print(f" - action_space_size: {buffer.action_space_size}") - - # Note: Full testing would require mock GameSegments and Policy - # For now, just verify instantiation - print(f"✓ {name} 
buffer initialized successfully") - - print("\n" + "="*80) - print("✓ All tests passed!") - print("="*80) + # 主要改动: 由sample改成了确定的取最后batch_size个样本 + if batch_size == -1: + batch_index_list = list(range(num_of_transitions))[self.last_pos_in_transition:] + else: + batch_index_list = list(range(num_of_transitions))[-batch_size:] + self.last_pos_in_transition = num_of_transitions + + if self._cfg.reanalyze_outdated: + batch_index_list.sort() + + weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) + weights_list /= weights_list.max() # Normalize weights + + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] + game_segment_idx -= self.base_idx # Adjust index based on base index + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + assert len(game_segment.obs_segment) == len(game_segment.raw_obs_segment) == len(game_segment.cot_prefix_segment) + if pos_in_game_segment + self._cfg.num_unroll_steps + self._cfg.model.frame_stack_num > len(game_segment.obs_segment): + max_safe_pos = max(0, len(game_segment.obs_segment) - self._cfg.num_unroll_steps - self._cfg.model.frame_stack_num) + pos_in_game_segment = np.random.randint(0, max_safe_pos + 1) + + # print(f'len(game_segment)=:len(game_segment.action_segment): {len(game_segment)}') + # print(f'len(game_segment.obs_segment): {game_segment.obs_segment.shape[0]}') + + # In the reanalysis phase, `pos_in_game_segment` should be a multiple of `num_unroll_steps`. + # Indices exceeding `game_segment_length` are padded with the next segment and are not updated + # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within + # [0, game_segment_length - num_unroll_steps] to avoid padded data. 
+ + if self._cfg.action_type == 'varied_action_space': + # For some environments (e.g., Jericho), the action space size may be different. + # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), + # we avoid sampling from the last `num_unroll_steps` steps of the game segment. + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item() + + segment_len = len(game_segment.action_segment) + if pos_in_game_segment >= segment_len - 1: + # If the segment is very short (length 0 or 1), we can't randomly sample a position + # before the last one. The only safe position is 0. + if segment_len > 1: + # If the segment has at least 2 actions, we can safely sample from [0, len-2]. + # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct. + pos_in_game_segment = np.random.choice(segment_len - 1, 1).item() + else: + # If segment length is 0 or 1, the only valid/safe position is 0. + pos_in_game_segment = 0 + + else: + # For environments with a fixed action space (e.g., Atari), + # we can safely sample from the entire game segment range. + if pos_in_game_segment >= self._cfg.game_segment_length: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + + segment_len = len(game_segment.action_segment) + if pos_in_game_segment >= segment_len - 1: + # If the segment is very short (length 0 or 1), we can't randomly sample a position + # before the last one. The only safe position is 0. + if segment_len > 1: + # If the segment has at least 2 actions, we can safely sample from [0, len-2]. + # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct. 
+ pos_in_game_segment = np.random.choice(segment_len - 1, 1).item() + else: + # If segment length is 0 or 1, the only valid/safe position is 0. + pos_in_game_segment = 0 + + pos_in_game_segment_list.append(pos_in_game_segment) + + + # make_time = [time.time() for _ in range(len(batch_index_list))] + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + return orig_data + + # 从原来的_prepare_reward_value_context函数修改得到 + def _prepare_reward_value_context_and_pred_values( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], + total_transitions: int + ) -> List[Any]: + """ + Overview: + prepare the context of rewards and values for calculating TD value target in reanalyzing part. + Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment + - total_transitions (:obj:`int`): number of collected transitions + Returns: + - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, + td_steps_list, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + + pred_obs_list = [] + pred_mask = [] + + value_obs_list = [] + # the value is valid or not (out of game_segment) + value_mask = [] + rewards_list = [] + game_segment_lens = [] + # for board games + action_mask_segment, to_play_segment = [], [] + + root_values = [] + + td_steps_list = [] + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + # original buffer 
td-steps + td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) + + # prepare the corresponding observations for bootstrapped values o_{t+k} + # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] + # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] + game_obs_pred = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) + game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + + rewards_list.append(game_segment.reward_segment) + + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + truncation_length = game_segment_len + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + # get the bootstrapped target obs + td_steps_list.append(td_steps) + # index of bootstrapped obs o_{t+td_steps} + bootstrap_index = current_index + td_steps + + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + + if bootstrap_index < truncation_length: + value_mask.append(1) + # the stacked obs in time t + obs = game_obs[beg_index:end_index] + else: + value_mask.append(0) + obs = zero_obs + + if current_index < truncation_length: + pred_mask.append(1) + obs_pred = game_obs_pred[beg_index:end_index] + else: + pred_mask.append(0) + obs_pred = zero_obs + + value_obs_list.append(obs) + pred_obs_list.append(obs_pred) + + reward_value_context = [ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list, + action_mask_segment, to_play_segment, pred_obs_list, pred_mask + ] + return reward_value_context + + # 从原来的_compute_target_reward_value函数修改得到 + def _compute_target_reward_value_and_pred_value(self, reward_value_context: List[Any], model: Any, batch_action_pred, batch_action, batch_timestep) -> Tuple[Any, Any]: + """ + Overview: + prepare reward and value targets from the 
context of rewards and values. + Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment, pred_obs_list, pred_mask = reward_value_context # noqa + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + + batch_target_values, batch_rewards, batch_pred_values = [], [], [] + with torch.no_grad(): + value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + pred_obs_list = prepare_observation(pred_obs_list, self._cfg.model.model_type) + + network_output = [] + network_output_pred = [] + + batch_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) + batch_obs_pred = torch.from_numpy(pred_obs_list).to(self._cfg.device) + + # =============== NOTE: The key difference with MuZero ================= + # calculate the bootstrapped value and target value + # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + m_output_pred = model.initial_inference(batch_obs_pred, batch_action_pred, task_id=self.task_id) + + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + m_output_pred = model.initial_inference(batch_obs_pred, batch_action_pred, start_pos=batch_timestep) + + # ====================================================================== + + # if not in training, obtain the scalars of the value/reward 
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self.value_support), + m_output.policy_logits + ] + ) + [m_output_pred.latent_state, m_output_pred.value, m_output_pred.policy_logits] = to_detach_cpu_numpy( + [ + m_output_pred.latent_state, + inverse_scalar_transform(m_output_pred.value, self.value_support), + m_output_pred.policy_logits + ] + ) + + network_output.append(m_output) + network_output_pred.append(m_output_pred) + + if self._cfg.use_root_value: + value_numpy = np.array(root_values) + raise ValueError("error!!!") + else: + # use the predicted values + value_numpy = concat_output_value(network_output) + pred_numpy = concat_output_value(network_output_pred) + + # 不考虑 board_games的情况 + value_numpy = value_numpy.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) + pred_numpy = pred_numpy.reshape(-1) + + value_numpy= value_numpy * np.array(value_mask) + value_list = value_numpy.tolist() + + pred_numpy = pred_numpy * np.array(pred_mask) + pred_list = pred_numpy.tolist() + + + horizon_id, value_index = 0, 0 + + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_rewards = [] + pred_values = [] + base_index = state_index + + # =========== NOTE =============== + # if game_segment_len_non_re < self._cfg.game_segment_length: + # # The last segment of one episode, the target value of excess part should be 0 + # truncation_length = game_segment_len_non_re + # else: + # # game_segment_len is game_segment.action_segment.shape[0] + # # action_segment.shape[0] = reward_segment.shape[0] or action_segment.shape[0] = reward_segment.shape[0] + 1 + # truncation_length = game_segment_len_non_re + # assert reward_list.shape[0] + 1 == game_segment_len_non_re or 
reward_list.shape[0] == game_segment_len_non_re + + truncation_length = game_segment_len_non_re + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + # 不考虑 board_games的情况 + value_list[value_index] += reward * self._cfg.discount_factor ** i + horizon_id += 1 + + # TODO: check the boundary condition + target_values.append(value_list[value_index]) + pred_values.append(pred_list[value_index]) + + if current_index < len(reward_list): + target_rewards.append(reward_list[current_index]) + else: + target_rewards.append(np.array(0.)) + + value_index += 1 + + batch_rewards.append(target_rewards) + batch_target_values.append(target_values) + batch_pred_values.append(pred_values) + + batch_rewards = np.asarray(batch_rewards) + batch_target_values = np.asarray(batch_target_values) + batch_pred_values = np.asarray(batch_pred_values) + + return batch_rewards, batch_target_values, batch_pred_values \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 3bb9bf2ca..03180a24b 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -540,8 +540,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: return batch_target_policies_re - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep) -> Tuple[ - Any, Any]: + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep) -> Tuple[Any, Any]: """ Overview: prepare reward and value targets from the context of rewards and values. 
@@ -609,6 +608,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A value_numpy= value_numpy * np.array(value_mask) value_list = value_numpy.tolist() + horizon_id, value_index = 0, 0 for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index d69671ac5..b2a9d7f5a 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -2064,7 +2064,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar value_priority=value_priority, intermediate_tensor_x=intermediate_tensor_x, obs_embeddings=detached_obs_embeddings, # <-- 新增 - ) + ), inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, outputs.logits_value.shape[-1])).detach() # TODO: test correctness diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 06fa3b580..733c1b6a8 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -535,7 +535,7 @@ def collect( # --- Episode Termination Handling --- if done: collected_episode += 1 - reward = info['eval_episode_return'] + reward = info['score'] log_info = {'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step']} if not collect_with_pure_policy: log_info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 01fabd38c..c3440f064 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -92,12 +92,10 @@ def __init__( f'./{self._exp_name}/log/{self._instance_name}', self._instance_name ) else: - # TODO(username): Refine logger setup for UniZero multitask with DDP v2. 
- if tb_logger is not None: - self._logger, _ = build_logger( - f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False - ) - self._tb_logger = tb_logger + self._logger, _ = build_logger( + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger self._rank = get_rank() print(f'rank {self._rank}, self.task_id: {self.task_id}') @@ -199,7 +197,7 @@ def eval( envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, - ) -> Tuple[bool, Dict[str, Any]]: + ) -> Dict[str, Any]: """ Overview: Run a full evaluation process. It will evaluate the current policy, log the results, @@ -358,8 +356,8 @@ def eval( dones[env_id] = done if episode_timestep.done: self._policy.reset([env_id]) - reward = episode_timestep.info['eval_episode_return'] - saved_info = {'eval_episode_return': episode_timestep.info['eval_episode_return']} + reward = episode_timestep.info['score'] + saved_info = {'eval_episode_return': episode_timestep.info['score']} if 'episode_info' in episode_timestep.info: saved_info.update(episode_timestep.info['episode_info']) eval_monitor.update_info(env_id, saved_info) @@ -407,64 +405,10 @@ def eval( duration = self._timer.value episode_return = eval_monitor.get_episode_return() info = { - 'train_iter': train_iter, - 'ckpt_name': f'iteration_{train_iter}.pth.tar', - 'episode_count': n_episode, - 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, - 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, - 'avg_time_per_episode': n_episode / duration if duration > 0 else 0, 'reward_mean': np.mean(episode_return), 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), 'reward_min': np.min(episode_return), } - episode_info = eval_monitor.get_episode_info() - if episode_info is not None: - info.update(episode_info) - - print(f'rank 
{self._rank}, self.task_id: {self.task_id}') - self._logger.info(self._logger.get_tabulate_vars_hor(info)) - - # Log to TensorBoard and WandB. - for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward'] or not np.isscalar(v): - continue - if self.task_id is None: - self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) - self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, envstep) - else: - self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) - self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, envstep) - if self.policy_config.use_wandb: - wandb.log({f'{self._instance_name}_step/{k}': v}, step=envstep) - - # Check for new best performance and save checkpoint. - mean_episode_return = np.mean(episode_return) - if mean_episode_return > self._max_episode_return: - if save_ckpt_fn: - save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = mean_episode_return - - # Check if the stop condition is met. - stop_flag = mean_episode_return >= self._stop_value and train_iter > 0 - if stop_flag: - self._logger.info( - f"[LightZero serial pipeline] Current episode_return: {mean_episode_return} is greater than " - f"stop_value: {self._stop_value}. The agent is considered converged." - ) - - # TODO(username): Finalize DDP synchronization for evaluation results. 
- # if get_world_size() > 1: - # objects = [stop_flag, episode_info] - # print(f'rank {self._rank}, self.task_id: {self.task_id}') - # print('before broadcast_object_list') - # broadcast_object_list(objects, src=0) - # print('evaluator after broadcast_object_list') - # stop_flag, episode_info = objects - - episode_info = to_item(episode_info) - if return_trajectory: - episode_info['trajectory'] = game_segments - return stop_flag, episode_info \ No newline at end of file + return info \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 7c265630b..39b154774 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -477,16 +477,6 @@ def collect( if self.policy_config.use_ture_chance_label_in_chance_encoder: append_kwargs['chance'] = self.chance_dict_tmp[env_id] - # [PRIORZERO-NEW] Add raw_obs_text if available in obs (not info!) - # Jericho env puts raw_obs_text in the obs dictionary - if env_id == 0 and collected_step < 5: # Debug first few steps - print(f"[OBS_DEBUG] Step {collected_step} env {env_id}: obs keys = {list(obs.keys())}") - print(f"[OBS_DEBUG] obs type = {type(obs)}") - if 'raw_obs_text' in obs: - print(f"[OBS_DEBUG] Found raw_obs_text: {str(obs['raw_obs_text'])[:100]}...") - else: - print(f"[OBS_DEBUG] NO raw_obs_text in obs!") - if 'raw_obs_text' in obs: append_kwargs['raw_obs_text'] = obs['raw_obs_text'] elif 'raw_obs_text' in info: @@ -566,7 +556,7 @@ def collect( self._total_episode_count += 1 info = { - 'reward': episode_timestep.info['eval_episode_return'], + 'reward': episode_timestep.info['score'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], } diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index e6ac44a2b..553db9f68 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -4,6 +4,7 @@ import json from datetime import datetime 
from typing import Any, Dict, List, Optional, Union +from collections import OrderedDict import gym import numpy as np @@ -49,12 +50,13 @@ class JerichoEnv(BaseEnv): 'max_seq_len': 512, 'remove_stuck_actions': False, 'add_location_and_inventory': False, - # 'for_unizero': False, 'for_unizero': True, 'save_replay': False, 'save_replay_path': None, 'env_type': "zork1", - 'collect_policy_mode': "agent" + 'collect_policy_mode': "agent", + 'use_cache': True, + 'cache_size': 100000, } def __init__(self, cfg: Dict[str, Any]) -> None: @@ -93,6 +95,12 @@ def __init__(self, cfg: Dict[str, Any]) -> None: self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory'] self.for_unizero: bool = self.cfg['for_unizero'] + self.use_cache = self.cfg['use_cache'] + if self.use_cache: + self.cache_size = self.cfg['cache_size'] + self.cache_buffer = OrderedDict() + print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}') + # Initialize the tokenizer once (only in rank 0 process if distributed) if JerichoEnv.tokenizer is None: if self.rank == 0: @@ -138,7 +146,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]: raw_obs_text = obs # Save original text BEFORE any modification if self._action_list is None: - self._action_list = self._env.get_valid_actions() + if self.use_cache: + cache_key = self._env.get_world_state_hash() + if cache_key in self.cache_buffer: + self.cache_buffer.move_to_end(cache_key) + self._action_list = self.cache_buffer[cache_key] + else: + self._action_list = self._env.get_valid_actions() + self.cache_buffer[cache_key] = self._action_list + if len(self.cache_buffer) > self.cache_size: + self.cache_buffer.popitem(last=False) + else: + self._action_list = self._env.get_valid_actions() # Filter available actions based on whether stuck actions are removed. 
if self.remove_stuck_actions: @@ -245,7 +264,7 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]: self.finished = False self._init_flag = True self._action_list = None - self.episode_return = 0.0 + self.episode_return = info['score'] if 'score' in info else 0.0 self._timestep = 0 self.episode_history = [] if self.collect_policy_mode == 'expert': @@ -344,6 +363,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) -> previous_obs: Optional[str] = self.last_observation if (self.remove_stuck_actions and self.last_observation is not None) else None observation, reward, done, info = self._env.step(action_str) + info['action_str'] = action_str self._timestep += 1 if not self.for_unizero: diff --git a/zoo/jericho/priorzero/async_training_coordinator.py b/zoo/jericho/priorzero/async_training_coordinator.py deleted file mode 100644 index 46a7a36bd..000000000 --- a/zoo/jericho/priorzero/async_training_coordinator.py +++ /dev/null @@ -1,390 +0,0 @@ -# async_training_coordinator.py -""" -[PRIORZERO] Async Training Coordinator - -This module implements async coordination for collect/train/eval tasks. - -Key Features: -- Configurable off-policy degree to control async level -- Automatic fallback to synchronous mode (off_policy_degree=0) -- Independent async evaluation -- Thread-safe buffer access control - -Author: PriorZero Team -Date: 2025-01-21 -""" - -import asyncio -import time -from typing import Optional, Dict, Any, Callable, Awaitable -from loguru import logger - - -class AsyncTrainingCoordinator: - """ - Coordinates async execution of collect, train, and eval tasks. - - The coordinator manages the async execution based on off_policy_degree: - - off_policy_degree = 0: Synchronous mode (collect -> train -> eval) - - off_policy_degree > 0: Async mode with bounded lag - - The off_policy_degree controls how many batches the training can lag - behind the collection. 
Higher values allow more async execution but - increase off-policy bias. - """ - - def __init__( - self, - off_policy_degree: int = 0, - enable_async_eval: bool = False, - buffer_size: int = 10000, - batch_size: int = 32, - ): - """ - Initialize AsyncTrainingCoordinator. - - Args: - off_policy_degree: Degree of async between collect and train - - 0: Synchronous mode - - >0: Max number of batches train can lag behind collect - - -1: Auto-tune based on buffer_size and batch_size - enable_async_eval: Whether to run eval asynchronously - buffer_size: Replay buffer size (for auto-tuning) - batch_size: Training batch size (for auto-tuning) - """ - self.off_policy_degree = off_policy_degree - self.enable_async_eval = enable_async_eval - self.buffer_size = buffer_size - self.batch_size = batch_size - - # Auto-tune off_policy_degree if set to -1 - if self.off_policy_degree == -1: - # Auto-tune: allow lag up to 10% of buffer capacity - self.off_policy_degree = max(1, (buffer_size // batch_size) // 10) - logger.info(f"Auto-tuned off_policy_degree to {self.off_policy_degree}") - - # Synchronization primitives - self._collect_count = 0 # Number of collect iterations completed - self._train_count = 0 # Number of train iterations completed - self._eval_task: Optional[asyncio.Task] = None - - # Locks for thread-safe access - self._lock = asyncio.Lock() - - # Performance tracking - self._collect_times = [] - self._train_times = [] - self._eval_times = [] - - logger.info(f"AsyncTrainingCoordinator initialized:") - logger.info(f" - off_policy_degree: {self.off_policy_degree}") - logger.info(f" - enable_async_eval: {self.enable_async_eval}") - logger.info(f" - mode: {'SYNCHRONOUS' if self.is_synchronous else 'ASYNCHRONOUS'}") - - @property - def is_synchronous(self) -> bool: - """Check if coordinator is in synchronous mode.""" - return self.off_policy_degree == 0 - - @property - def collect_train_lag(self) -> int: - """Get current lag between collect and train iterations.""" - return 
self._collect_count - self._train_count - - def can_train(self) -> bool: - """ - Check if training is allowed based on off_policy_degree. - - In synchronous mode (off_policy_degree=0), training must wait for collect. - In async mode, training can proceed as long as lag is within bounds. - """ - if self.is_synchronous: - # Synchronous: train only after collect - return self._collect_count > self._train_count - else: - # Async: train can proceed if there's data and lag is acceptable - # We allow training as long as there's collected data - return self._collect_count > 0 - - def can_collect(self) -> bool: - """ - Check if collection is allowed based on off_policy_degree. - - In synchronous mode, collection must wait for train to finish. - In async mode, collection can proceed as long as lag doesn't exceed limit. - """ - if self.is_synchronous: - # Synchronous: collect only after train - return self._train_count >= self._collect_count - else: - # Async: collect can proceed if lag is within bounds - lag = self.collect_train_lag - return lag < self.off_policy_degree - - async def run_collect( - self, - collect_fn: Callable[[], Awaitable[Any]], - ) -> Any: - """ - Run collection with coordination. - - Args: - collect_fn: Async collection function - - Returns: - Collection result - """ - # Wait if needed (for sync mode or if lag is too high) - while not self.can_collect(): - logger.debug(f"Collect waiting (lag={self.collect_train_lag}, limit={self.off_policy_degree})") - await asyncio.sleep(0.1) - - # Run collection - start_time = time.time() - result = await collect_fn() - elapsed = time.time() - start_time - - # Update counter - async with self._lock: - self._collect_count += 1 - self._collect_times.append(elapsed) - - logger.debug(f"Collect completed in {elapsed:.2f}s (count={self._collect_count})") - return result - - async def run_train( - self, - train_fn: Callable[[], Awaitable[Any]], - ) -> Any: - """ - Run training with coordination. 
- - Args: - train_fn: Async training function - - Returns: - Training result - """ - # Wait if needed - while not self.can_train(): - logger.debug(f"Train waiting (collect={self._collect_count}, train={self._train_count})") - await asyncio.sleep(0.1) - - # Run training - start_time = time.time() - result = await train_fn() - elapsed = time.time() - start_time - - # Update counter - async with self._lock: - self._train_count += 1 - self._train_times.append(elapsed) - - logger.debug(f"Train completed in {elapsed:.2f}s (count={self._train_count}, lag={self.collect_train_lag})") - return result - - async def run_eval( - self, - eval_fn: Callable[[], Awaitable[Any]], - ) -> Any: - """ - Run evaluation with coordination. - - Args: - eval_fn: Async evaluation function - - Returns: - Evaluation result - """ - start_time = time.time() - - if self.enable_async_eval: - # Cancel previous eval if still running - if self._eval_task is not None and not self._eval_task.done(): - logger.info("Cancelling previous eval task") - self._eval_task.cancel() - try: - await self._eval_task - except asyncio.CancelledError: - pass - - # Run eval in background - self._eval_task = asyncio.create_task(eval_fn()) - logger.info("Started async eval in background") - - # Return immediately (don't wait) - return None - else: - # Synchronous eval - result = await eval_fn() - elapsed = time.time() - start_time - self._eval_times.append(elapsed) - logger.debug(f"Eval completed in {elapsed:.2f}s") - return result - - async def wait_for_eval(self) -> Optional[Any]: - """ - Wait for async eval to complete (if running). 
- - Returns: - Eval result if eval was running, None otherwise - """ - if self._eval_task is not None and not self._eval_task.done(): - logger.info("Waiting for async eval to complete...") - try: - result = await self._eval_task - return result - except asyncio.CancelledError: - logger.warning("Eval task was cancelled") - return None - return None - - def get_statistics(self) -> Dict[str, Any]: - """ - Get performance statistics. - - Returns: - Dictionary with timing statistics - """ - stats = { - 'collect_count': self._collect_count, - 'train_count': self._train_count, - 'collect_train_lag': self.collect_train_lag, - 'mode': 'synchronous' if self.is_synchronous else 'asynchronous', - } - - if self._collect_times: - stats['collect_avg_time'] = sum(self._collect_times) / len(self._collect_times) - stats['collect_total_time'] = sum(self._collect_times) - - if self._train_times: - stats['train_avg_time'] = sum(self._train_times) / len(self._train_times) - stats['train_total_time'] = sum(self._train_times) - - if self._eval_times: - stats['eval_avg_time'] = sum(self._eval_times) / len(self._eval_times) - stats['eval_total_time'] = sum(self._eval_times) - - return stats - - def reset_counters(self): - """Reset all counters (useful for testing).""" - self._collect_count = 0 - self._train_count = 0 - self._collect_times.clear() - self._train_times.clear() - self._eval_times.clear() - logger.info("AsyncTrainingCoordinator counters reset") - - -async def run_async_training_loop( - coordinator: AsyncTrainingCoordinator, - collect_fn: Callable[[], Awaitable[Any]], - train_fn: Callable[[], Awaitable[Any]], - eval_fn: Callable[[], Awaitable[Any]], - eval_interval: int, - max_iterations: int, -): - """ - Main async training loop that coordinates collect/train/eval. 
- - Args: - coordinator: AsyncTrainingCoordinator instance - collect_fn: Async collection function - train_fn: Async training function - eval_fn: Async evaluation function - eval_interval: How often to run eval (in iterations) - max_iterations: Maximum training iterations - """ - logger.info(f"Starting async training loop (max_iter={max_iterations})") - - if coordinator.is_synchronous: - # ======================================================================== - # SYNCHRONOUS MODE: Original serial execution - # ======================================================================== - logger.info("Running in SYNCHRONOUS mode") - - for iteration in range(max_iterations): - # 1. Collect - logger.info(f"[Iter {iteration}] Collecting...") - await coordinator.run_collect(collect_fn) - - # 2. Train - logger.info(f"[Iter {iteration}] Training...") - await coordinator.run_train(train_fn) - - # 3. Eval (if needed) - if iteration % eval_interval == 0: - logger.info(f"[Iter {iteration}] Evaluating...") - await coordinator.run_eval(eval_fn) - - else: - # ======================================================================== - # ASYNCHRONOUS MODE: Concurrent execution with bounded lag - # ======================================================================== - logger.info(f"Running in ASYNCHRONOUS mode (off_policy_degree={coordinator.off_policy_degree})") - - # Create tasks for collect and train - collect_task = None - train_tasks = [] - - iteration = 0 - while iteration < max_iterations: - tasks_to_wait = [] - - # Start collect if allowed - if coordinator.can_collect() and (collect_task is None or collect_task.done()): - logger.debug(f"[Iter {iteration}] Starting collect task") - collect_task = asyncio.create_task(coordinator.run_collect(collect_fn)) - tasks_to_wait.append(collect_task) - - # Start train if allowed and there's data - if coordinator.can_train(): - logger.debug(f"[Iter {iteration}] Starting train task") - train_task = 
asyncio.create_task(coordinator.run_train(train_fn)) - train_tasks.append(train_task) - tasks_to_wait.append(train_task) - iteration += 1 - - # Eval (if needed) - if iteration % eval_interval == 0 and iteration > 0: - logger.info(f"[Iter {iteration}] Triggering eval") - await coordinator.run_eval(eval_fn) - - # Wait for at least one task to complete - if tasks_to_wait: - done, pending = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED) - logger.debug(f"Tasks completed: {len(done)}, pending: {len(pending)}") - else: - # No tasks ready, wait a bit - await asyncio.sleep(0.1) - - # Clean up completed train tasks - train_tasks = [t for t in train_tasks if not t.done()] - - # Wait for all remaining tasks - logger.info("Waiting for remaining tasks to complete...") - if collect_task and not collect_task.done(): - await collect_task - for task in train_tasks: - if not task.done(): - await task - - # Wait for eval if running - await coordinator.wait_for_eval() - - # Print statistics - stats = coordinator.get_statistics() - logger.info("="*80) - logger.info("Training Loop Statistics:") - logger.info(f" Mode: {stats['mode']}") - logger.info(f" Collect count: {stats['collect_count']}") - logger.info(f" Train count: {stats['train_count']}") - logger.info(f" Final lag: {stats['collect_train_lag']}") - if 'collect_avg_time' in stats: - logger.info(f" Avg collect time: {stats['collect_avg_time']:.2f}s") - if 'train_avg_time' in stats: - logger.info(f" Avg train time: {stats['train_avg_time']:.2f}s") - if 'eval_avg_time' in stats: - logger.info(f" Avg eval time: {stats['eval_avg_time']:.2f}s") - logger.info("="*80) diff --git a/zoo/jericho/priorzero/ensure_local_lightzero.py b/zoo/jericho/priorzero/ensure_local_lightzero.py deleted file mode 100644 index 7a697176b..000000000 --- a/zoo/jericho/priorzero/ensure_local_lightzero.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Utility module to ensure local LightZero is used across all PriorZero modules. 
- -This ensures PriorZero uses the local LightZero installation at: -/mnt/nfs/zhangjinouwen/puyuan/LightZero - -Usage: - Import this at the beginning of any PriorZero module: - - from ensure_local_lightzero import ensure_local_lightzero - ensure_local_lightzero() -""" - -import sys -from pathlib import Path - - -def ensure_local_lightzero(): - """ - Ensures the local LightZero path is first in sys.path. - - This allows PriorZero to use a LightZero version that has been - specifically adapted for PriorZero, rather than a globally installed version. - - Also adds the PriorZero directory to sys.path to ensure PriorZero modules - can be imported. - """ - LIGHTZERO_ROOT = Path("/mnt/nfs/zhangjinouwen/puyuan/LightZero").resolve() - PRIORZERO_DIR = Path(__file__).parent.resolve() - - if not LIGHTZERO_ROOT.exists(): - print(f"⚠️ Warning: LightZero root not found at {LIGHTZERO_ROOT}") - return False - - lightzero_str = str(LIGHTZERO_ROOT) - priorzero_str = str(PRIORZERO_DIR) - - # Remove any existing LightZero paths from sys.path - sys.path = [p for p in sys.path if 'LightZero' not in p or p == lightzero_str] - - # Insert local LightZero at the beginning - if lightzero_str not in sys.path: - sys.path.insert(0, lightzero_str) - - # Also ensure PriorZero directory is in sys.path for module imports - if priorzero_str not in sys.path: - sys.path.insert(0, priorzero_str) - - # Verify - try: - import lzero - lzero_path = Path(lzero.__file__).parent.parent - - if lzero_path == LIGHTZERO_ROOT: - print(f"✓ Using local LightZero: {lzero_path}") - print(f"✓ PriorZero modules path: {priorzero_str}") - return True - else: - print(f"⚠️ Warning: Using LightZero from {lzero_path}") - print(f" Expected: {LIGHTZERO_ROOT}") - return False - except ImportError as e: - print(f"⚠️ Warning: Could not import lzero: {e}") - return False - - -# Auto-ensure on import -ensure_local_lightzero() diff --git a/zoo/jericho/priorzero/fix_environment.sh b/zoo/jericho/priorzero/fix_environment.sh deleted file 
mode 100644 index 8876f54df..000000000 --- a/zoo/jericho/priorzero/fix_environment.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# fix_environment.sh -# Fix numpy version conflicts and other dependency issues - -echo "==========================================" -echo "Fixing PriorZero Environment Dependencies" -echo "==========================================" - -# 1. Fix numpy version (downgrade to 1.26.4 for compatibility) -echo "" -echo "1. Fixing numpy version..." -pip install "numpy<2,>=1.24.1" --force-reinstall --no-deps - -# 2. Reinstall conflicting packages -echo "" -echo "2. Reinstalling di-engine and lightzero..." -pip install di-engine==0.5.3 --no-deps -pip install lightzero==0.2.0 --no-deps - -# 3. Verify installations -echo "" -echo "3. Verifying installations..." -python -c "import numpy; print(f'numpy version: {numpy.__version__}')" -python -c "import torch; print(f'torch version: {torch.__version__}')" -python -c "import vllm; print(f'vllm version: {vllm.__version__}')" - -echo "" -echo "==========================================" -echo "Environment fix complete!" -echo "==========================================" -echo "" -echo "Now you can run:" -echo " python priorzero_config.py" -echo " python game_segment_priorzero.py" -echo " python priorzero_entry.py --quick_test" diff --git a/zoo/jericho/priorzero/game_segment_priorzero.py b/zoo/jericho/priorzero/game_segment_priorzero.py index 654b93e5c..7ae62d701 100644 --- a/zoo/jericho/priorzero/game_segment_priorzero.py +++ b/zoo/jericho/priorzero/game_segment_priorzero.py @@ -1,36 +1,9 @@ -# game_segment_priorzero.py -""" -[PRIORZERO] Enhanced Game Segment for PriorZero - -This module extends the standard GameSegment to store additional information -needed for LLM policy training (SFT + RFT). 
- -Key Features: -- Store MCTS policy distributions for SFT training -- Store raw text observations for LLM prompt construction -- Store LLM generated priors for analysis and debugging -- Store search values for priority calculation - -Author: PriorZero Team -Date: 2025-01-20 -""" - import numpy as np from typing import Optional, List, Any from lzero.mcts.buffer.game_segment import GameSegment as OriginalGameSegment class GameSegment(OriginalGameSegment): - """ - [PRIORZERO-MODIFIED] - Enhanced GameSegment that stores additional data for PriorZero training. - - New attributes: - - mcts_policy_segment: List of MCTS visit count distributions (for SFT) - - raw_obs_segment: List of raw text observations (for LLM prompts) - - llm_prior_segment: List of LLM generated text (for debugging) - - search_value_segment: List of MCTS search values (for priority) - """ def __init__( self, @@ -39,38 +12,37 @@ def __init__( config: Optional[Any] = None, task_id: Optional[int] = None ): - """ - Initialize enhanced GameSegment. 
- - Args: - action_space: Action space from environment - game_segment_length: Maximum length of the segment - config: Policy configuration - task_id: Task ID for multi-task learning - """ super().__init__(action_space, game_segment_length, config, task_id) - # [PRIORZERO-NEW] Additional segments for LLM training - self.mcts_policy_segment = [] # MCTS visit count distributions self.raw_obs_segment = [] # Raw text observations - self.llm_prior_segment = [] # LLM generated priors (for debugging) - self.search_value_segment = [] # MCTS search values + self.history_obs_segment = [] + self.llm_prior_per_tok_segment = [] # LLM prior per token (for debugging) + self.cot_prefix_segment = [] # CoT prefixes for reuse (optimization) + self.llm_action_segment = [] # Actions selected by LLM - def reset(self, init_observations: List[np.ndarray]) -> None: + def reset(self, init_observations: List[np.ndarray], init_raw_obs, init_history_obs) -> None: """ [PRIORZERO-MODIFIED] Reset the segment with initial observations. 
Args: init_observations: List of initial frame stack observations + init_raw_obs: Initial raw text observation + init_history_obs: Initial history observations """ super().reset(init_observations) - - # Clear PriorZero-specific segments - self.mcts_policy_segment.clear() self.raw_obs_segment.clear() - self.llm_prior_segment.clear() - self.search_value_segment.clear() + self.history_obs_segment.clear() + self.llm_prior_per_tok_segment.clear() + self.cot_prefix_segment.clear() # Clear CoT prefix segment + self.llm_action_segment.clear() + + # 以下结果均是第 t 时刻的结果 + self.raw_obs_segment.append(init_raw_obs) + self.history_obs_segment.append(init_history_obs) + self.llm_prior_per_tok_segment.append(None) + self.cot_prefix_segment.append(None) + self.llm_action_segment.append(None) def append( self, @@ -79,383 +51,152 @@ def append( reward: float, action_mask: np.ndarray, to_play: int, + timestep: int = 0, + chance: int = 0, + raw_obs_text: Optional[str] = None, + history_obs: Optional[List[str]] = None, + llm_prior_per_tok = None, + cot_prefix: Optional[str] = None, + llm_action: Optional[str] = None, **kwargs ) -> None: - """ - [PRIORZERO-MODIFIED] - Append a new transition to the segment. 
- - Args: - action: Action taken - obs: Observation received - reward: Reward received - action_mask: Valid action mask - to_play: Player ID (for multi-agent) - **kwargs: Additional arguments (timestep, chance, raw_obs_text, llm_prior_text) - """ - # [PRIORZERO-NEW] Extract PriorZero-specific kwargs before passing to parent - raw_obs_text = kwargs.pop('raw_obs_text', None) - llm_prior_text = kwargs.pop('llm_prior_text', None) - - # [DEBUG] Log first few appends to see what's being passed - if len(self.raw_obs_segment) < 3: - print(f"[SEGMENT_DEBUG] append() called: kwargs keys = {list(kwargs.keys())}") - print(f"[SEGMENT_DEBUG] raw_obs_text = {raw_obs_text[:50] if raw_obs_text else 'None'}...") - - # Call parent append with remaining kwargs - super().append(action, obs, reward, action_mask, to_play, **kwargs) - - # [PRIORZERO-NEW] Initialize placeholders for new segments - # These will be filled in by store_search_stats() - self.mcts_policy_segment.append(None) - self.search_value_segment.append(None) - - # [PRIORZERO-NEW] Store raw text observation if provided + + super().append(action, obs, reward, action_mask, to_play, timestep, chance) self.raw_obs_segment.append(raw_obs_text) + self.history_obs_segment.append(history_obs) + self.llm_prior_per_tok_segment.append(llm_prior_per_tok) + self.cot_prefix_segment.append(cot_prefix) + self.llm_action_segment.append(llm_action) - # [PRIORZERO-NEW] Store LLM prior text if provided (for debugging) - self.llm_prior_segment.append(llm_prior_text) - - def store_search_stats( - self, - root_visit_dist: List[float], - value: float, - *args, - **kwargs - ) -> None: - """ - [PRIORZERO-MODIFIED] - Store MCTS search statistics. - - This method is called after MCTS search to store the visit count - distribution and search value. 
These will be used for: - - SFT training: MCTS policy as supervision signal for LLM - - Priority calculation: Search value for prioritized replay - - Args: - root_visit_dist: Visit count distribution from MCTS - value: Search value from MCTS - *args: Additional positional arguments (for compatibility) - **kwargs: Additional keyword arguments (improved_policy, etc.) - """ - # [FIX] Handle NaN values - import numpy as np - if value is None or (isinstance(value, float) and np.isnan(value)): - # Use 0.0 as default for NaN values - value = 0.0 - - # Call parent method to store standard statistics - super().store_search_stats(root_visit_dist, value, *args, **kwargs) - - # [PRIORZERO-NEW] Store MCTS policy distribution - # Convert to numpy array and normalize to probability distribution - policy_array = np.array(root_visit_dist, dtype=np.float32) - - if policy_array.sum() > 0: - policy_array = policy_array / policy_array.sum() - else: - # If no visits (shouldn't happen), use uniform distribution - policy_array = np.ones_like(policy_array) / len(policy_array) - - # Update the most recent position (corresponding to last append) - if len(self.mcts_policy_segment) > 0: - self.mcts_policy_segment[-1] = policy_array - - # [PRIORZERO-NEW] Store search value - if len(self.search_value_segment) > 0: - self.search_value_segment[-1] = float(value) + def store_search_stats(self, visit_counts: List, root_value: List) -> None: + super().store_search_stats(visit_counts, root_value) def game_segment_to_array(self) -> None: - """ - [PRIORZERO-MODIFIED] - Convert all segment lists to numpy arrays for efficient storage. - - This is called when the segment is full and ready to be stored in - the replay buffer. 
- """ - # Call parent method to convert standard segments super().game_segment_to_array() - - # [PRIORZERO-NEW] Convert PriorZero-specific segments to arrays - # Use object dtype to handle variable-length arrays and None values - self.mcts_policy_segment = np.array(self.mcts_policy_segment, dtype=object) - self.search_value_segment = np.array(self.search_value_segment, dtype=np.float32) - - # For text data, keep as list (more flexible for variable-length strings) - # self.raw_obs_segment and self.llm_prior_segment remain as lists - - def get_stats(self) -> dict: - """ - [PRIORZERO-NEW] - Get statistics about this game segment. - - Returns: - stats: Dictionary of statistics - """ - stats = { - 'segment_length': len(self.reward_segment) if hasattr(self, 'reward_segment') else 0, - 'total_reward': sum(self.reward_segment) if hasattr(self, 'reward_segment') else 0, - 'num_mcts_policies': sum(1 for p in self.mcts_policy_segment if p is not None), - 'num_raw_obs': sum(1 for o in self.raw_obs_segment if o is not None), - 'num_llm_priors': sum(1 for p in self.llm_prior_segment if p is not None), - 'avg_search_value': np.mean([v for v in self.search_value_segment if v is not None]) if any(v is not None for v in self.search_value_segment) else 0.0, - } - return stats - - def get_mcts_policy_for_training(self, index: int) -> Optional[np.ndarray]: - """ - [PRIORZERO-NEW] - Get MCTS policy at a specific index for training. - - Args: - index: Index in the segment - - Returns: - policy: MCTS policy distribution, or None if not available - """ - if 0 <= index < len(self.mcts_policy_segment): - return self.mcts_policy_segment[index] - return None - - def get_raw_obs_for_training(self, index: int) -> Optional[str]: - """ - [PRIORZERO-NEW] - Get raw text observation at a specific index for training. 
+ + def pad_over( + self, next_segment_observations: List, next_segment_rewards: List, next_segment_actions: List, next_segment_root_values: List, + next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None, + next_segment_raw_obs: List = None, next_segment_history_obs: List = None, next_segment_llm_prior_per_tok: List = None, + next_segment_cot_prefix: List = None, next_segment_llm_action: List = None + ) -> None: + super().pad_over( + next_segment_observations, next_segment_rewards, next_segment_actions, next_segment_root_values, + next_segment_child_visits, next_segment_improved_policy, next_chances + ) + assert len(next_segment_raw_obs) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_history_obs) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_llm_prior_per_tok) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_cot_prefix) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_llm_action) <= self.num_unroll_steps + self.td_steps + + import copy + if len(next_segment_history_obs) > 0: + assert self.raw_obs_segment[-1] == next_segment_llm_prior_per_tok[0]['current_obs'] + assert self.history_obs_segment[-1] == next_segment_llm_prior_per_tok[0]['history'] + assert self.history_obs_segment[-1][-1][1] == self.llm_action_segment[-1] + assert next_segment_history_obs[0][-1][1] == next_segment_llm_action[0] + + for raw_obs in next_segment_raw_obs: + self.raw_obs_segment.append(copy.deepcopy(raw_obs)) + for history_obs in next_segment_history_obs: + self.history_obs_segment.append(copy.deepcopy(history_obs)) + for lp in next_segment_llm_prior_per_tok: + self.llm_prior_per_tok_segment.append(copy.deepcopy(lp)) + for action in next_segment_llm_action: + self.llm_action_segment.append(copy.deepcopy(action)) + + # Handle CoT prefix padding (optimization for CoT reuse) + if next_segment_cot_prefix is not None: + for cot_prefix in next_segment_cot_prefix: + 
self.cot_prefix_segment.append(copy.deepcopy(cot_prefix)) + + def get_unroll_raw_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Overview: + Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps]. + Arguments: + - timestep (int): The time step. + - num_unroll_steps (int): The extra length of the observation frames. + - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. + """ + stacked_raw_obs = self.raw_obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_raw_obs) + if pad_len > 0: + stacked_raw_obs = stacked_raw_obs[:-1] + pad_frames = [stacked_raw_obs[-1] for _ in range(pad_len + 1)] + stacked_raw_obs = stacked_raw_obs + pad_frames + return stacked_raw_obs + + def get_unroll_histroy_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Overview: + Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps]. + Arguments: + - timestep (int): The time step. + - num_unroll_steps (int): The extra length of the observation frames. + - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. + """ + stacked_histroy_obs = self.history_obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_histroy_obs) + if pad_len > 0: + stacked_histroy_obs = stacked_histroy_obs[:-1] + pad_frames = [stacked_histroy_obs[-1] for _ in range(pad_len + 1)] + stacked_histroy_obs = stacked_histroy_obs + pad_frames + return stacked_histroy_obs + + def get_unroll_llm_prior_per_tok(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Return LLM prior per token aligned with actions for unroll window. 
+ """ + stacked_prior = list(self.llm_prior_per_tok_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]) + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_prior) + if pad_len > 0: + pad_frames = [stacked_prior[-1] for _ in range(pad_len)] + stacked_prior = stacked_prior + pad_frames + return stacked_prior + + def get_unroll_cot_prefix(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> List[str]: + """ + Return CoT prefixes aligned with observations for unroll window (CoT reuse optimization). Args: - index: Index in the segment + timestep: The time step + num_unroll_steps: The extra length of the CoT prefix frames + padding: If True, pad frames if outside of trajectory Returns: - raw_obs: Raw text observation, or None if not available + List of CoT prefix strings """ - if 0 <= index < len(self.raw_obs_segment): - return self.raw_obs_segment[index] - return None + stacked_cot_prefix = list(self.cot_prefix_segment[timestep:timestep + self.frame_stack_num +num_unroll_steps]) + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_cot_prefix) + if pad_len > 0: + # Pad with empty strings or last prefix + pad_frames = [stacked_cot_prefix[-1] for _ in range(pad_len)] + stacked_cot_prefix = stacked_cot_prefix + pad_frames + return stacked_cot_prefix - def get_history_for_training(self, index: int, history_length: int = 5) -> List[tuple]: + def get_unroll_llm_action(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> List[str]: """ - [PRIORZERO-NEW] - Get history context for LLM prompting. + Return LLM actions aligned with observations for unroll window. 
Args: - index: Current index in the segment - history_length: Number of past transitions to include + timestep: The time step + num_unroll_steps: The extra length of the CoT prefix frames + padding: If True, pad frames if outside of trajectory Returns: - history: List of (obs, action, reward) tuples - """ - history = [] - - # Get recent transitions - start_idx = max(0, index - history_length) - for i in range(start_idx, index): - if i < len(self.raw_obs_segment) and i < len(self.action_segment) and i < len(self.reward_segment): - obs_text = self.raw_obs_segment[i] - action_id = self.action_segment[i] - reward = self.reward_segment[i] - - # Only add if observation is available - if obs_text is not None: - history.append((obs_text, action_id, reward)) - - return history - - def __repr__(self) -> str: - """ - [PRIORZERO-MODIFIED] - String representation with PriorZero statistics. - """ - base_repr = super().__repr__() - stats = self.get_stats() - - priorzero_info = ( - f"\n MCTS policies: {stats['num_mcts_policies']}" - f"\n Raw observations: {stats['num_raw_obs']}" - f"\n LLM priors: {stats['num_llm_priors']}" - f"\n Avg search value: {stats['avg_search_value']:.3f}" - ) - - return base_repr + priorzero_info - - -# ============================================================================== -# Utility Functions -# ============================================================================== - -def create_priorzero_game_segment( - action_space, - game_segment_length: int = 200, - config: Optional[Any] = None, - task_id: Optional[int] = None -) -> GameSegment: - """ - Factory function to create a PriorZero GameSegment. 
- - Args: - action_space: Action space from environment - game_segment_length: Maximum length of the segment - config: Policy configuration - task_id: Task ID for multi-task learning - - Returns: - segment: PriorZero GameSegment instance - """ - return GameSegment(action_space, game_segment_length, config, task_id) - - -def validate_game_segment(segment: GameSegment) -> bool: - """ - Validate that a GameSegment has consistent data. - - Args: - segment: GameSegment to validate - - Returns: - is_valid: True if segment is valid, False otherwise - """ - try: - # Check basic lengths - if not hasattr(segment, 'obs_segment'): - return False - - base_length = len(segment.obs_segment) - - # Check that all segments have compatible lengths - if hasattr(segment, 'action_segment'): - if len(segment.action_segment) != base_length: - return False - - if hasattr(segment, 'reward_segment'): - if len(segment.reward_segment) != base_length: - return False - - # Check PriorZero-specific segments - if len(segment.mcts_policy_segment) != base_length: - return False - - if len(segment.raw_obs_segment) != base_length: - return False - - # Check that MCTS policies are valid when present - for policy in segment.mcts_policy_segment: - if policy is not None: - if not isinstance(policy, np.ndarray): - return False - if policy.sum() < 0.99 or policy.sum() > 1.01: # Should sum to ~1.0 - return False - if np.any(policy < 0): # Should be non-negative - return False - - return True - - except Exception as e: - print(f"Validation error: {e}") - return False - - -# ============================================================================== -# Example Usage and Testing -# ============================================================================== - -if __name__ == "__main__": - print("="*80) - print("Testing PriorZero GameSegment") - print("="*80) - - # Create a mock action space - class MockActionSpace: - def __init__(self, n): - self.n = n - - # Create a mock config with all required 
attributes - class MockConfig: - def __init__(self): - self.num_unroll_steps = 10 - self.td_steps = 5 - self.discount_factor = 0.99 - self.gray_scale = False - self.transform2string = False - self.sampled_algo = False - self.gumbel_algo = False - self.use_ture_chance_label_in_chance_encoder = False - self.model = type('obj', (object,), { - 'frame_stack_num': 4, - 'action_space_size': 10, - 'observation_shape': (84, 84, 3), - 'image_channel': 3 - })() - - action_space = MockActionSpace(n=10) - mock_config = MockConfig() - - # Create a game segment - segment = GameSegment(action_space, game_segment_length=100, config=mock_config) - - # Reset with initial observations - init_obs = [np.zeros((84, 84, 3)) for _ in range(4)] - segment.reset(init_obs) - - print("\n1. Empty segment:") - print(f" Length: {len(segment.obs_segment)}") - print(f" MCTS policies: {len(segment.mcts_policy_segment)}") - - # Simulate some transitions - print("\n2. Adding transitions...") - for i in range(5): - obs = np.random.rand(84, 84, 3) - action = np.random.randint(0, 10) - reward = np.random.randn() - action_mask = np.ones(10) - - # Append transition - segment.append( - action, obs, reward, action_mask, to_play=0, - raw_obs_text=f"You see a room. Step {i}.", - llm_prior_text=f"Top actions: go north, take key" - ) - - # Store MCTS stats - visit_dist = np.random.dirichlet([1.0] * 10).tolist() - value = np.random.randn() - segment.store_search_stats(visit_dist, value) - - print(f" Added {len(segment.obs_segment)} transitions") - - # Get statistics - print("\n3. Segment statistics:") - stats = segment.get_stats() - for key, value in stats.items(): - print(f" {key}: {value}") - - # Test retrieval functions - print("\n4. 
Testing retrieval functions:") - mcts_policy = segment.get_mcts_policy_for_training(2) - print(f" MCTS policy at index 2: {mcts_policy is not None}") - if mcts_policy is not None: - print(f" Shape: {mcts_policy.shape}") - print(f" Sum: {mcts_policy.sum():.3f}") - - raw_obs = segment.get_raw_obs_for_training(2) - print(f" Raw obs at index 2: {raw_obs}") - - history = segment.get_history_for_training(4, history_length=3) - print(f" History for index 4: {len(history)} transitions") - - # Validate segment - print("\n5. Validating segment:") - is_valid = validate_game_segment(segment) - print(f" Is valid: {is_valid}") - - # Convert to array - print("\n6. Converting to array:") - segment.game_segment_to_array() - print(f" MCTS policy type: {type(segment.mcts_policy_segment)}") - print(f" Search value type: {type(segment.search_value_segment)}") - - # Print representation - print("\n7. Segment representation:") - print(segment) - - print("\n" + "="*80) - print("✓ All tests passed!") - print("="*80) + List of LLM action strings + """ + stacked_llm_action = list(self.llm_action_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]) + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_llm_action) + if pad_len > 0: + # Pad with empty strings or last action + pad_frames = [stacked_llm_action[-1] for _ in range(pad_len)] + stacked_llm_action = stacked_llm_action + pad_frames + return stacked_llm_action \ No newline at end of file diff --git a/zoo/jericho/priorzero/models/actor.py b/zoo/jericho/priorzero/models/actor.py new file mode 100644 index 000000000..1d93ef17b --- /dev/null +++ b/zoo/jericho/priorzero/models/actor.py @@ -0,0 +1,520 @@ +from typing import Optional, Union, List, Dict +from collections import defaultdict +import os +import math +from tqdm import tqdm +import numpy as np +import deepspeed +from torch.optim import Optimizer +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers import 
AutoModelForCausalLM, BitsAndBytesConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig +from transformers.trainer import get_scheduler + +from utils import compute_approx_kl, compute_entropy, masked_mean, torch_dist_barrier_and_cuda_sync, log_probs_from_logits + +class Actor(nn.Module): + """ + Base class for Actor models in reinforcement learning. + + This class serves as a foundation for implementing various actor models, which are responsible for selecting actions based on the policy learned from the environment. + + Args: + pretrain_or_model (nn.Module): A pretrained model or a new model instance to be used as the actor. + attn_implementation (str, optional): Attention mechanism implementation to use. Defaults to "flash_attention_2". + bf16 (bool, optional): Enable bfloat16 precision for model computations. Defaults to True. + ds_config (dict, optional): Configuration for DeepSpeed, enabling model partitioning across multiple GPUs. Defaults to None. + device_map (dict, optional): Device mapping for loading the model onto specific devices. Defaults to None. + temperature (float, optional): Temperature for action selection. Defaults to 1.0. 
+ """ + + def __init__( + self, + pretrain_or_model: str, + attn_implementation="flash_attention_2", + bf16=True, + ds_config=None, + device_map=None, + temperature=1.0, + **kwargs, + ) -> None: + super().__init__() + + self.temperature = temperature + attn_impl = attn_implementation + + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + _ = HfDeepSpeedConfig(ds_config) + else: + _ = None + + self.model = AutoModelForCausalLM.from_pretrained( + pretrain_or_model, + trust_remote_code=True, + attn_implementation=attn_impl, + torch_dtype=torch.bfloat16 if bf16 else "auto", + device_map=device_map, + ) + self.model.config.use_cache = False + + def forward( + self, + sequences: torch.LongTensor, + action_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_output=False, + return_entropy=False, + ) -> torch.Tensor: + + foward_attention_mask = attention_mask + rolled_sequences = torch.roll(sequences, shifts=-1, dims=1) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + output = self.model(sequences, attention_mask=foward_attention_mask, position_ids=position_ids) + output["logits"] = output["logits"].to(torch.float32) + + if return_entropy: + assert return_output + entropy = compute_entropy(output["logits"]) + setattr(output, "entropy", entropy[:, :-1]) + + log_probs = log_probs_from_logits(output["logits"], rolled_sequences, temperature=self.temperature) + + log_probs = log_probs[:, :-1] + + action_log_probs = log_probs[:, -action_mask.shape[1] :] * action_mask.float() + return (action_log_probs, output) if return_output else action_log_probs + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + + def gradient_checkpointing_disable(self): + self.model.gradient_checkpointing_disable() + + def 
print_trainable_parameters(self): + self.model.print_trainable_parameters() + +class ReferenceModel: + def __init__(self, strategy, pretrain): + self.strategy = strategy + model = Actor( + pretrain, + attn_implementation=strategy.args.attn_implementation, + bf16=strategy.args.bf16, + ds_config=strategy.get_ds_eval_config( + offload=False + ), + temperature=strategy.args.temperature, + ) + self.model = strategy.prepare(model, is_rlhf=True) + self.model.eval() + self.micro_train_batch_size = self.strategy.args.micro_train_batch_size + + @torch.no_grad() + def forward( + self, + sequences: torch.LongTensor, + action_mask: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Return: action_log_probs [B, T_action] + """ + device = torch.cuda.current_device() + B = sequences.size(0) + outs = [] + chunk_size = max(self.micro_train_batch_size, 32) + + sequences = sequences.to(device) + attention_mask = attention_mask.to(device) + action_mask = action_mask.to(device) + for i in range(0, B, chunk_size): + s = sequences[i : i + chunk_size].to(device) + am = action_mask[i : i + chunk_size].to(device) + attn = attention_mask[i : i + chunk_size].to(device) + + out = self.model( + s, + action_mask=am, + attention_mask=attn, + ) + outs.append(out) + return torch.cat(outs, dim=0) + +class BatchPPOTrainer: + def __init__( + self, + strategy, + actor, + actor_optim, + actor_scheduler, + micro_train_batch_size: int = 8, + vllm_engine = None + ): + self.strategy = strategy + self.args = strategy.args + + self.actor = actor + self.actor_optim = actor_optim + self.actor_scheduler = actor_scheduler + self.vllm_engine = vllm_engine + self.use_cuda_ipc = self.args.use_cuda_ipc + + self.micro_train_batch_size = micro_train_batch_size + from models.loss import PolicyLoss + self.policy_loss = PolicyLoss( + clip_eps_low=self.args.eps_clip_low_high[0], + clip_eps_high=self.args.eps_clip_low_high[1], + policy_loss_type=self.args.policy_loss_type, + ) + self.train_iter = 0 + 
+ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_idx: int = 0) -> Dict[str, float]: + device = torch.cuda.current_device() + for k, v in batch_data.items(): + if torch.is_tensor(v): + batch_data[k] = v.to(device) + + all_samples_size = batch_data["input_ids"].size(0) + status_list = [] + pbar = tqdm( + range(0, all_samples_size, self.micro_train_batch_size), + desc=f"PPO batch step={step_idx}", + disable=not self.strategy.is_rank_0(), + ) + acc_grad_steps = self.strategy.accumulated_gradient + metrics_buffer = defaultdict(list) # 用于累积 micro_step 指标的缓冲区 + + for micro_step, start_idx in enumerate(pbar): + end_idx = min(start_idx + self.micro_train_batch_size, all_samples_size) + micro_batch = { + 'input_ids': batch_data['input_ids'][start_idx:end_idx], + "attention_mask": batch_data['attention_mask'][start_idx:end_idx], + "action_mask": batch_data['action_mask'][start_idx:end_idx], + "advantages": batch_data['advantages'][start_idx:end_idx], + "old_action_logprob": batch_data['old_action_logprob'][start_idx:end_idx], + "log_status": batch_data['log_status'][start_idx:end_idx] + } + micro_batch['ref_action_log_probs'] = batch_data['ref_action_log_probs'][start_idx:end_idx] if batch_data['ref_action_log_probs'] is not None else None + + action_log_probs, output = self.actor( + micro_batch['input_ids'], + micro_batch['action_mask'], + attention_mask=micro_batch['attention_mask'], + return_output=True, + return_entropy=True, + ) + actor_loss, clipfrac, clip_ratio, approx_kl, vllm_kl = self.policy_loss( + action_log_probs, + micro_batch['old_action_logprob'], + micro_batch['advantages'], + action_mask=micro_batch['action_mask'], + ) + + if self.args.rft_kl_coef > 0 and micro_batch['ref_action_log_probs'] is not None: + kl = compute_approx_kl( + action_log_probs, + micro_batch['ref_action_log_probs'], + kl_estimator=self.args.kl_estimator + ) + kl_loss = masked_mean(kl, micro_batch["action_mask"]) + else: + kl_loss = torch.tensor(0.0, 
device=device) + + loss = actor_loss + kl_loss * float(kl_ctl.value) + + if self.args.entropy_loss_coef is not None: + entropy_loss = masked_mean(output.entropy[:, -micro_batch["action_mask"].shape[1] :], micro_batch["action_mask"]) + if self.args.entropy_loss_coef != 0: + loss -= entropy_loss * self.args.entropy_loss_coef + + self.strategy.backward(loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor") + + policy_loss_item = actor_loss.detach().float().item() + clipfrac_item = clipfrac.detach().float().item() + clip_ratio_item = clip_ratio.detach().float().item() + approx_kl_item = approx_kl.detach().float().item() + kl_loss_item = kl_loss.detach().float().item() + input_response_length_item = micro_batch["attention_mask"].sum().detach().float().item() / micro_batch["attention_mask"].shape[0] + response_length_item = micro_batch["action_mask"].sum().detach().float().item() / micro_batch["action_mask"].shape[0] + input_length_item = input_response_length_item - response_length_item + entropy_loss_item = entropy_loss.detach().float().item() if self.args.entropy_loss_coef is not None else None + + pbar.set_postfix({ + "policy_loss": policy_loss_item, + "clipfrac": clipfrac_item, + "approx_kl": approx_kl_item, + "iter": self.train_iter, + }) + + metrics_buffer["policy_loss"].append(policy_loss_item) + metrics_buffer["clipfrac"].append(clipfrac_item) + metrics_buffer["clip_ratio"].append(clip_ratio_item) + metrics_buffer["approx_kl"].append(approx_kl_item) + metrics_buffer["ref_kl"].append(kl_loss_item) + metrics_buffer["input_length"].append(input_length_item) + metrics_buffer["response_length"].append(response_length_item) + metrics_buffer['entropy'].append(entropy_loss_item) + + log_status = micro_batch["log_status"] + other_status = {k: [item[k] for item in log_status] for k in log_status[0].keys()} + for k, v in other_status.items(): + metrics_buffer[k] = v + + if ((micro_step + 1) % 
acc_grad_steps == 0) or ((micro_step + 1) == pbar.total): + self.train_iter += 1 + status = { + "policy_loss": np.mean(metrics_buffer['policy_loss']), + "clipfrac": np.mean(metrics_buffer['clipfrac']), + "clip_ratio": np.mean(metrics_buffer['clip_ratio']), + "approx_kl": np.mean(metrics_buffer['approx_kl']), + "ref_kl": np.mean(metrics_buffer['ref_kl']), + "entropy": np.mean(metrics_buffer['entropy']) if self.args.entropy_loss_coef is not None else None, + + "iter": self.train_iter, + "lr": self.actor_scheduler.get_last_lr()[0], + "global_grad_norm": self.actor_optim._global_grad_norm, + + "input_length_max": np.max(metrics_buffer['input_length']), + "input_length_mean": np.mean(metrics_buffer['input_length']), + "input_length_min": np.min(metrics_buffer['input_length']), + + "response_length_max": np.max(metrics_buffer['response_length']), + "response_length_mean": np.mean(metrics_buffer['response_length']), + "response_length_min": np.min(metrics_buffer['response_length']), + + "fmt_rewards": np.mean(metrics_buffer['fmt_rewards']) if "fmt_rewards" in metrics_buffer else None, + "value_advantage_max": np.max(metrics_buffer['value_advantage']), + "value_advantage_mean": np.mean(metrics_buffer['value_advantage']), + "value_advantage_min": np.min(metrics_buffer['value_advantage']), + "final_advantage_max": np.max(metrics_buffer['final_advantage']), + "final_advantage_mean": np.mean(metrics_buffer['final_advantage']), + "final_advantage_min": np.min(metrics_buffer['final_advantage']), + } + metrics_buffer.clear() + + status = self.strategy.all_reduce(status) + status_list.append(status) + + return status_list + + def _deepspeed_broadcast(self): + use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) + if use_prefix_cache: + self.vllm_engine.reset_prefix_cache() + + torch.cuda.empty_cache() + model = self.actor.model.module + count, num_params = 0, len(list(model.named_parameters())) + for name, param in model.named_parameters(): + count += 1 # 
empty_cache at last param + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params)) + + def _broadcast_to_vllm(self): + use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) + if use_prefix_cache and torch.distributed.get_rank() == 0: + self.vllm_engine.reset_prefix_cache() + + torch.cuda.empty_cache() + model = self.actor.model + count, num_params = 0, len(list(model.named_parameters())) + + def _broadcast_param(param, count, num_params): + if torch.distributed.get_rank() == 0: + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params) + + self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream()) + + def _handle_cuda_ipc(param, count, num_params): + from torch.multiprocessing.reductions import reduce_tensor + + weight = param.data.clone() + ipc_handle = reduce_tensor(weight) + + from vllm_utils.vllm_engine import get_physical_gpu_id + ipc_handle = {get_physical_gpu_id(): ipc_handle} + ipc_handle_list = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(ipc_handle_list, ipc_handle) + + if torch.distributed.get_rank() == 0: + ipc_handles = {} + for d in ipc_handle_list: + ipc_handles.update(d) + + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + self.vllm_engine.update_weight_cuda_ipc( + name, + dtype=param.dtype, + shape=shape, + ipc_handles=ipc_handles, + empty_cache=count == num_params, + ) + + torch_dist_barrier_and_cuda_sync() + + for name, param in model.named_parameters(): + count += 1 # 
empty_cache at last param + + # broadcast + if not self.use_cuda_ipc: + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + if self.strategy.args.ds_tensor_parallel_size > 1: + with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True): + _broadcast_param(param, count, num_params) + else: + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + _broadcast_param(param, count, num_params) + else: + if self.strategy.args.ds_tensor_parallel_size > 1: + with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True): + _handle_cuda_ipc(param, count, num_params) + else: + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + _handle_cuda_ipc(param, count, num_params) + + torch.cuda.empty_cache() + torch_dist_barrier_and_cuda_sync() + + +class PolicyModel: + def __init__( + self, + strategy, + pretrain: str, + max_steps: Optional[int] = None, + vllm_engine=None, + ): + self.strategy = strategy + args = strategy.args + + self.vllm_engine = vllm_engine + self.max_steps = max_steps + + if getattr(args, "vllm_num_engines", 0) > 0: + if getattr(args, "vllm_sync_backend", "nccl") == "nccl": + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + actor = Actor( + pretrain, + attn_implementation=args.attn_implementation, + bf16=args.bf16, + ds_config=strategy.get_ds_train_config(is_actor=True), + temperature=args.temperature, + ) + strategy.print(actor) + + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + pretrain, trust_remote_code=True, padding_side="left" + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + actor_optim = strategy.create_optimizer( + actor, + lr=args.learning_rate, + betas=args.adam_betas, + weight_decay=args.weight_decay, + ) + + if max_steps is None: + max_steps = int(getattr(args, "max_steps", 1_000_000)) + 
+ actor_scheduler = get_scheduler( + args.lr_scheduler, + actor_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + self.actor, self.actor_optim, self.actor_scheduler = strategy.prepare( + (actor, actor_optim, actor_scheduler), + is_rlhf=True, + ) + + if strategy.args.deepspeed_enable_sleep: + from strategy.deepspeed import offload_deepspeed_states + offload_deepspeed_states(self.actor.model) + + self.trainer = BatchPPOTrainer( + strategy, + self.actor, + actor_optim=self.actor_optim, + actor_scheduler=self.actor_scheduler, + micro_train_batch_size=args.micro_train_batch_size, + vllm_engine = vllm_engine, + ) + + def fit(self, batch_data, kl_ctl: float = 0.0): + torch.cuda.empty_cache() + self.actor.train() + status = self.trainer.train_batch(batch_data, kl_ctl) + torch.cuda.empty_cache() + torch.cuda.synchronize() + return status + + @torch.no_grad() + def forward( + self, + sequences: torch.LongTensor, + action_mask: Optional[Union[int, list[int], torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + to_cpu: bool = False, + ) -> torch.Tensor: + self.actor.eval() + + if action_mask is None: + raise ValueError("action_mask is required for returning action_log_probs") + + device = torch.cuda.current_device() + sequences = sequences.to(device, non_blocking=True) + attention_mask = attention_mask.to(device, non_blocking=True) if attention_mask is not None else None + action_mask = action_mask.to(device, non_blocking=True) if torch.is_tensor(action_mask) else action_mask + + action_log_probs = self.actor( + sequences, + action_mask=action_mask, + attention_mask=attention_mask, + ring_attn_group=self.strategy.ring_attn_group, + 
packed_seq_lens=packed_seq_lens, + ) + + self.actor.train() + return action_log_probs.to("cpu") if to_cpu else action_log_probs + + def broadcast_to_vllm(self): + # self.trainer._broadcast_to_vllm() + self.trainer._deepspeed_broadcast() + + def save_model(self): + args = self.strategy.args + self.strategy.save_model( + self.actor, + self.tokenizer, + args.save_path, + ) + @property + def train_iter(self): + return self.trainer.train_iter + + def reload_states(self): + from strategy.deepspeed import reload_deepspeed_states + reload_deepspeed_states(self.actor.model) + + def offload_states(self): + from strategy.deepspeed import offload_deepspeed_states + offload_deepspeed_states(self.actor.model) \ No newline at end of file diff --git a/zoo/jericho/priorzero/models/loss.py b/zoo/jericho/priorzero/models/loss.py new file mode 100644 index 000000000..42e798780 --- /dev/null +++ b/zoo/jericho/priorzero/models/loss.py @@ -0,0 +1,109 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from utils import masked_mean + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__( + self, + clip_eps_low: float = 0.2, + clip_eps_high: float = 0.2, + dual_clip: float = None, + token_level_loss: bool = True, + policy_loss_type: str = "ppo", + enable_vllm_is_correction: bool = False, + vllm_is_truncated_threshold: list = None, + use_icepop: bool = False, + ) -> None: + super().__init__() + self.clip_eps_low = clip_eps_low + self.clip_eps_high = clip_eps_high + self.token_level_loss = token_level_loss + self.dual_clip = dual_clip + self.policy_loss_type = policy_loss_type + self.enable_vllm_is_correction = enable_vllm_is_correction + self.vllm_is_truncated_threshold = vllm_is_truncated_threshold + self.use_icepop = use_icepop + + # GSPO requires sequence-level loss + if policy_loss_type == "gspo": + self.token_level_loss = False + + # Dual-clip PPO: 
https://arxiv.org/pdf/1912.09729 + if dual_clip is not None: + assert dual_clip > 1.0, f"dual_clip must be > 1.0, got {dual_clip}" + + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + rollout_log_probs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.policy_loss_type == "ppo": + log_ratio = log_probs - old_log_probs + ratio = log_ratio.exp() + elif self.policy_loss_type == "gspo": + # GSPO: https://arxiv.org/pdf/2507.18071 + if self.enable_vllm_is_correction: + log_ratio = log_probs - rollout_log_probs + else: + log_ratio = log_probs - old_log_probs + ratio = (log_ratio * action_mask).sum(dim=-1) / action_mask.sum(dim=-1) + ratio = ratio.exp().unsqueeze(-1) * action_mask + else: + raise ValueError(f"Invalid policy loss type: {self.policy_loss_type}") + if advantages.dim() == 1: + advantages = advantages.unsqueeze(-1) + + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages + + if self.dual_clip is None: + # Standard PPO + loss = -torch.min(surr1, surr2) + else: + # Standard PPO clipping + clip1 = torch.min(surr1, surr2) + # Dual-clip: additional lower bound for negative advantages + clip2 = torch.max(clip1, self.dual_clip * advantages) + # Apply dual-clip: use clip2 for negative advantages, clip1 for positive advantages + loss = -torch.where(advantages < 0, clip2, clip1) + + # Your Efficient RL Framework Secretly Brings You Off-Policy RL Training: https://fengyao.notion.site/off-policy-rl + vllm_kl = None + if self.enable_vllm_is_correction and self.policy_loss_type == "ppo": + low_threshold, high_threshold = self.vllm_is_truncated_threshold + if self.use_icepop: + # ICEPOP: set coefficients outside the interval to 0 + vllm_is = torch.exp(old_log_probs - rollout_log_probs).detach() + mask = (vllm_is >= low_threshold) & (vllm_is <= high_threshold) + vllm_is = vllm_is * mask + else: + # 
Standard clamp with low and high thresholds + vllm_is = ( + torch.exp(old_log_probs - rollout_log_probs).clamp(min=low_threshold, max=high_threshold).detach() + ) + loss = vllm_is * loss + vllm_kl = masked_mean(rollout_log_probs - old_log_probs, action_mask, dim=None) + + loss = ( + masked_mean(loss, action_mask, dim=None) + if self.token_level_loss + else masked_mean(loss, action_mask, dim=-1).mean() + ) + clipped = ratio.gt(1 + self.clip_eps_high) | ratio.lt(1 - self.clip_eps_low) + clipfrac = masked_mean(clipped, action_mask, dim=None) + + clip_ratio = masked_mean(torch.lt(surr2, surr1).float(), action_mask, dim=None) + approx_kl = masked_mean(-log_ratio.detach(), action_mask, dim=None) + return loss, clipfrac, clip_ratio, approx_kl, vllm_kl \ No newline at end of file diff --git a/zoo/jericho/priorzero/models/stability_optimizer.py b/zoo/jericho/priorzero/models/stability_optimizer.py new file mode 100644 index 000000000..a05a0cb84 --- /dev/null +++ b/zoo/jericho/priorzero/models/stability_optimizer.py @@ -0,0 +1,145 @@ +import logging +from collections import deque +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch + + +class AdaptiveValueNormalizer: + """ + 作用:把 value/return/advantage 变成稳定尺度(近似零均值、单位方差),并支持 soft(log-sym)/hard(percentile) 抑制极端值。 + 核心:batch 统计(只看当前) + EMA 运行统计(全局追踪非平稳) + 可选裁剪/压缩。 + """ + + def __init__( + self, + init_momentum: float = 0.9, + final_momentum: float = 0.99, + warmup_steps: int = 100, + clip_method: str = "soft", # "soft" | "hard" | "none" + clip_percentile: float = 0.95, # hard clip 中间保留比例,如 0.95 => 保留 [2.5%, 97.5%] + min_std: float = 1e-6, + hard_clip_start_updates: int = 10, # hard clip 前几次不启用 + history_size: int = 1000, + ): + self.init_momentum = init_momentum + self.final_momentum = final_momentum + self.warmup_steps = warmup_steps + self.clip_method = clip_method + self.clip_percentile = clip_percentile + self.min_std = min_std + self.hard_clip_start_updates = hard_clip_start_updates + + 
self.running_mean = 0.0 + self.running_std = 1.0 + self.update_count = 0 + + self.value_history = deque(maxlen=history_size) + + def _momentum(self) -> float: + if self.update_count >= self.warmup_steps: + return self.final_momentum + p = self.update_count / max(self.warmup_steps, 1) + return self.init_momentum + (self.final_momentum - self.init_momentum) * p + + @staticmethod + def _log_sym(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + # f(x)=sign(x)*log(1+|x|) + significant = int((x.abs() > 10).sum()) + y = torch.sign(x) * torch.log1p(torch.abs(x)) + return y, significant + + def _hard_percentile_clip(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + if self.update_count < self.hard_clip_start_updates: + return x, 0 + q = self.clip_percentile + lo = (1 - q) / 2 + hi = 1 - lo + + xf = x.flatten() + lb = torch.quantile(xf, lo) + ub = torch.quantile(xf, hi) + y = torch.clamp(x, lb, ub) + + clipped = int((y != x).sum()) + return y, clipped + + def _batch_mean_std(self, x: torch.Tensor) -> Tuple[float, float]: + xf = x.flatten() + n = xf.numel() + if n == 0: + return 0.0, 1.0 + if n == 1: + mean = float(xf.item()) + return mean, self.min_std + + xf64 = xf.to(torch.float64) + mean = float(xf64.mean().item()) + var = float(xf64.var(unbiased=True).item()) + std = max(var ** 0.5, self.min_std) + return mean, std + + def normalize( + self, + values: torch.Tensor, + clip_values: bool = True, + return_stats: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]: + x = values.detach() + + clipped_count = 0 + if clip_values: + if self.clip_method == "soft": + x, clipped_count = self._log_sym(x) + elif self.clip_method == "hard": + x, clipped_count = self._hard_percentile_clip(x) + else: + raise ValueError(f"Unknown clip_method: {self.clip_method}") + + batch_mean, batch_std = self._batch_mean_std(x) + + m = self._momentum() + if self.update_count == 0: + self.running_mean = batch_mean + self.running_std = batch_std + else: + self.running_mean = m * 
self.running_mean + (1 - m) * batch_mean + self.running_std = m * self.running_std + (1 - m) * batch_std + + self.update_count += 1 + self.value_history.extend(x.flatten().float().cpu().tolist()) + + + y = (x.to(values.dtype) - self.running_mean) / (self.running_std + self.min_std) + + if not return_stats: + return y + + stats = { + "batch_mean": batch_mean, + "batch_std": batch_std, + "running_mean": self.running_mean, + "running_std": self.running_std, + "momentum": m, + "clip_method": self.clip_method, + "clipped_count": clipped_count, + "total_count": int(x.numel()), + } + return y, stats + + def summary(self) -> Dict: + if self.update_count == 0: + return {} + recent = list(self.value_history)[-min(100, len(self.value_history)) :] + return { + "total_updates": self.update_count, + "current_mean": float(self.running_mean), + "current_std": float(self.running_std), + "recent_mean": float(np.mean(recent)) if recent else 0.0, + "recent_std": float(np.std(recent)) if recent else 1.0, + "recent_min": float(np.min(recent)) if recent else 0.0, + "recent_max": float(np.max(recent)) if recent else 0.0, + "clip_method": self.clip_method, + } + diff --git a/zoo/jericho/priorzero/priorzero_collector.py b/zoo/jericho/priorzero/priorzero_collector.py index 1fb6e53c7..5f6a8c653 100644 --- a/zoo/jericho/priorzero/priorzero_collector.py +++ b/zoo/jericho/priorzero/priorzero_collector.py @@ -1,44 +1,25 @@ -# priorzero_collector.py -""" -[PRIORZERO] PriorZero Collector Implementation - -This module implements async data collection with LLM prior integration. 
- -Key Features: -- Async LLM inference using vLLM for efficient batch generation -- History buffer management for context-aware prompting -- Error handling and retry logic for robust LLM calls -- Full alignment with UniZero collector architecture - -Author: PriorZero Team -Date: 2025-01-20 -""" - import asyncio import logging import sys import time + from collections import deque, defaultdict from pathlib import Path from typing import Optional, Any, List, Dict, Tuple -# [CRITICAL] Ensure local LightZero is used -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - import numpy as np import torch from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY -from vllm import AsyncLLMEngine, SamplingParams +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, allreduce_data +from vllm import SamplingParams +import os # Import from local LightZero from lzero.worker.muzero_segment_collector import MuZeroSegmentCollector as OriginalCollector from lzero.mcts.utils import prepare_observation from game_segment_priorzero import GameSegment - # ============================================================================== # Helper Functions # ============================================================================== @@ -91,10 +72,8 @@ def extract_raw_obs_text(obs_dict: Dict[str, Any]) -> str: class PriorZeroCollector(OriginalCollector): """ [PRIORZERO-MODIFIED] - Async collector that integrates LLM priors into MCTS-based data collection. 
Features: - - Async LLM inference with vLLM engine - History buffer for each environment (sliding window) - Robust error handling with retries - Detailed logging of LLM prior statistics @@ -102,236 +81,106 @@ class PriorZeroCollector(OriginalCollector): def __init__( self, - vllm_engine: AsyncLLMEngine, policy_config: Dict, + llm_config: Dict, + data_processor = None, + prof = None, **kwargs ): """ Initialize PriorZeroCollector. Args: - vllm_engine: vLLM async engine for LLM inference - policy_config: Policy configuration (contains llm_policy_cfg) + vllm_engine + policy_config: Policy configuration + llm_config: llm configuration **kwargs: Additional arguments for parent class """ - # [FIX] Set policy_config in kwargs before calling super().__init__ - # because parent class needs it kwargs['policy_config'] = policy_config - # Extract debug_mode before passing to parent (parent doesn't accept this parameter) - self.debug_mode = kwargs.pop('debug_mode', False) - super().__init__(**kwargs) - self.vllm_engine = vllm_engine - # self.policy_config already set by parent class from kwargs - self.llm_policy_cfg = policy_config.llm_policy_cfg + self.data_processor = data_processor + self.prof = prof + self.llm_cfg = llm_config - # [PRIORZERO-NEW] History buffer for each environment - # Format: {env_id: deque([(obs_text, action_text, reward), ...])} self.history_buffers = defaultdict( - lambda: deque(maxlen=self.llm_policy_cfg.history_length) + lambda: deque(maxlen=self.llm_cfg.history_length) ) - - # [PRIORZERO-NEW] Statistics for logging - self.llm_stats = { - 'total_calls': 0, - 'successful_calls': 0, - 'failed_calls': 0, - 'retry_count': 0, - 'total_latency': 0.0, - 'llm_prior_top1_match_count': 0, # How often LLM top-1 matches MCTS choice - } + self.llm_prior_temperature = llm_config.llm_prior_temperature self._logger.info("✓ PriorZeroCollector initialized with vLLM engine") - self._logger.info(f" - History length: {self.llm_policy_cfg.history_length}") - 
self._logger.info(f" - Generate max length: {self.llm_policy_cfg.generate_max_len}") - - # [PRIORZERO-NEW] Use custom GameSegment - self.GameSegment = GameSegment - - async def _async_get_llm_prior( - self, - states: List[str], - request_ids: List[str], - histories: Optional[List[List[Tuple[str, str, float]]]] = None, - max_retries: int = 3, - timeout: float = 30.0 - ) -> List[Any]: - """ - [PRIORZERO-NEW] - Async call to LLM to get action ranking priors. - - Args: - states: List of current observation texts - request_ids: List of unique request IDs for tracking - histories: Optional list of history tuples for each state - max_retries: Maximum number of retries on failure - timeout: Timeout in seconds for each request - - Returns: - llm_outputs: List of vLLM output objects - """ - # [FIX] Check if vLLM engine is available - if self.vllm_engine is None: - self._logger.info("INFO: vLLM engine not available, skipping LLM prior") - return [None] * len(states) - - from priorzero_policy import build_llm_prompt - - # Build prompts - prompts = [] - for i, state in enumerate(states): - history = histories[i] if histories is not None else None - - # Build instruction using the helper function from policy - instruction = build_llm_prompt( - current_obs=state, - history=history, - use_cot=self.llm_policy_cfg.use_cot + self._logger.info(f" - History length: {self.llm_cfg.history_length}") + self._logger.info(f" - Generate max length: {self.llm_cfg.generate_max_len}") + + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps + + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_raw_obs_lst = game_segments[i].raw_obs_segment[beg_index:end_index] + pad_history_obs_lst = 
game_segments[i].history_obs_segment[beg_index:end_index] + pad_llm_prior_per_tok_lst = game_segments[i].llm_prior_per_tok_segment[beg_index:end_index] + pad_cot_prefix_lst = game_segments[i].cot_prefix_segment[beg_index:end_index] # CoT reuse + pad_llm_action_lst = game_segments[i].llm_action_segment[beg_index:end_index] + + # NOTE: Specific padding logic for UniZero. + pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps - 1 + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + if self.policy_config.gumbel_algo: + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + + # Pad and finalize the last game segment. 
+ if self.policy_config.gumbel_algo: + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob, + next_segment_cot_prefix=pad_cot_prefix_lst, # CoT reuse + next_segment_llm_action=pad_llm_action_lst ) - - # Apply chat template if policy has tokenizer - if hasattr(self._policy, 'llm_tokenizer'): - prompt = self._policy.llm_tokenizer.apply_chat_template( - [{"role": "user", "content": instruction}], - tokenize=False, - add_generation_prompt=True + else: + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst, next_segment_raw_obs=pad_raw_obs_lst, + next_segment_history_obs=pad_history_obs_lst, next_segment_llm_prior_per_tok=pad_llm_prior_per_tok_lst, + next_segment_cot_prefix=pad_cot_prefix_lst, # CoT reuse + next_segment_llm_action=pad_llm_action_lst ) else: - prompt = instruction - - # [FIX] Ensure prompt is a string - if prompt is None: - self._logger.error(f"[ERROR] Prompt {i} is None! Instruction was: {instruction[:100] if instruction else 'None'}") - prompt = "" # Fallback to empty string - elif not isinstance(prompt, str): - self._logger.error(f"[ERROR] Prompt {i} is not a string! 
Type: {type(prompt)}, Value: {prompt}") - prompt = str(prompt) # Force conversion to string - - prompts.append(prompt) - - # Configure sampling parameters - sampling_params = SamplingParams( - temperature=1.0, - top_p=1.0, - max_tokens=self.llm_policy_cfg.generate_max_len, - skip_special_tokens=False, - ) - - # Retry logic - for attempt in range(max_retries): - try: - start_time = time.time() - - # [DEBUG] Log prompts and parameters before generation - if self.debug_mode and attempt == 0: - self._logger.info(f"[DEBUG] Sending {len(prompts)} prompts to vLLM engine") - for i, prompt in enumerate(prompts[:2]): # Show first 2 prompts - self._logger.info(f"[DEBUG] Prompt {i} (len={len(prompt)}): {prompt[:200]}...") - self._logger.info(f"[DEBUG] Sampling params: temp={sampling_params.temperature}, max_tokens={sampling_params.max_tokens}, top_p={sampling_params.top_p}") - self._logger.info(f"[DEBUG] Request IDs: {request_ids[:2]}...") - - # [FIX] vLLM V1 generate() takes single prompt, not list - # Create generators for each prompt individually - generators = [] - for i, (prompt, req_id) in enumerate(zip(prompts, request_ids)): - gen = self.vllm_engine.generate( - prompt, # Single prompt string - sampling_params, - req_id # Single request_id string - ) - generators.append((i, gen)) - - # Collect results - llm_outputs = [None] * len(prompts) - - try: - # Collect all results concurrently - async def collect_from_generator(idx, gen): - """Collect final result from a generator""" - final_result = None - async for result in gen: - final_result = result - # Check timeout - if time.time() - start_time > timeout: - raise asyncio.TimeoutError(f"LLM generation timeout after {timeout}s") - return idx, final_result - - # Gather all results concurrently - tasks = [collect_from_generator(idx, gen) for idx, gen in generators] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Process results - for result in results: - if isinstance(result, Exception): - raise result 
- idx, output = result - llm_outputs[idx] = output - - except asyncio.TimeoutError: - self._logger.warning(f"⚠ LLM generation timeout after {timeout}s (attempt {attempt+1}/{max_retries})") - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - continue - else: - # On final timeout, return None for all - self.llm_stats['failed_calls'] += len(prompts) - return [None] * len(prompts) - - # Check if all outputs were received - if None in llm_outputs: - missing_count = llm_outputs.count(None) - self._logger.warning(f"⚠ {missing_count}/{len(prompts)} LLM outputs missing (attempt {attempt+1}/{max_retries})") - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - continue - - # Success - elapsed = time.time() - start_time - self.llm_stats['total_calls'] += len(prompts) - self.llm_stats['successful_calls'] += len([o for o in llm_outputs if o is not None]) - self.llm_stats['failed_calls'] += len([o for o in llm_outputs if o is None]) - self.llm_stats['total_latency'] += elapsed - - self._logger.debug(f"✓ LLM generation completed in {elapsed:.2f}s ({len(prompts)} prompts)") - - # [DEBUG] Log detailed LLM outputs if debug mode is enabled - if self.debug_mode: - for i, (prompt, output) in enumerate(zip(prompts, llm_outputs)): - if output is not None: - output_text = output.outputs[0].text if output.outputs else "[No output]" - self._logger.info(f"[DEBUG] Env {i} - Prompt: {prompt[:100]}... 
-> LLM Output: {output_text[:100]}...") - else: - self._logger.warning(f"[DEBUG] Env {i} - LLM output is None") - - return llm_outputs - - except Exception as e: - import traceback - error_msg = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__ - error_trace = traceback.format_exc() - - # [FIX] Always log the full traceback on first attempt or in debug mode - if attempt == 0 or self.debug_mode: - self._logger.error(f"✗ LLM generation error (attempt {attempt+1}/{max_retries}): {error_msg}") - self._logger.error(f"Full traceback:\n{error_trace}") - else: - self._logger.error(f"✗ LLM generation error (attempt {attempt+1}/{max_retries}): {error_msg}") + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_raw_obs=pad_raw_obs_lst, next_segment_history_obs=pad_history_obs_lst, + next_segment_llm_prior_per_tok=pad_llm_prior_per_tok_lst, + next_segment_cot_prefix=pad_cot_prefix_lst, # CoT reuse + next_segment_llm_action=pad_llm_action_lst + ) - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - await asyncio.sleep(0.5) # Brief pause before retry - else: - # Final failure - self._logger.error(f"✗ LLM generation failed after {max_retries} attempts. Last error: {error_msg}") - self._logger.error(f"Final traceback:\n{error_trace}") - self.llm_stats['failed_calls'] += len(prompts) - return [None] * len(prompts) + last_game_segments[i].game_segment_to_array() - return [None] * len(prompts) + # Add the completed game segment to the pool. + self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - async def collect( + # Reset placeholders for the next collection cycle. + last_game_segments[i] = None + last_game_priorities[i] = None + + def collect( self, num_segments: Optional[int] = None, train_iter: int = 0, @@ -344,9 +193,8 @@ async def collect( Main changes from parent: 1. Extract text observations from environment - 2. 
Async call to LLM to get action priors - 3. Pass LLM priors to policy forward pass - 4. Update history buffers after each step + 2. Pass LLM priors to policy forward pass + 3. Update history buffers after each step Args: num_segments: Number of segments to collect @@ -372,29 +220,26 @@ async def collect( temperature = policy_kwargs.get('temperature', 1.0) epsilon = policy_kwargs.get('epsilon', 0.0) - # ================================================================== - # Initialization - # ================================================================== collected_episode = 0 collected_step = 0 + llm_prior_entropy = [[] for _ in range(self._env_num)] env_nums = self._env_num init_obs = self._env.ready_obs - # Wait for all environments to be ready retry_waiting_time = 0.05 while len(init_obs.keys()) != env_nums: self._logger.info(f'Waiting for all environments to reset. Ready: {list(init_obs.keys())}') time.sleep(retry_waiting_time) init_obs = self._env.ready_obs - # Initialize state tracking for env_id in range(env_nums): if env_id in init_obs: self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) self.timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - # Initialize game segments + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] game_segments = [ GameSegment( self._env.action_space, @@ -404,7 +249,6 @@ async def collect( ) for _ in range(env_nums) ] - # Initialize observation stacks observation_window_stack = [ deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums) @@ -415,39 +259,32 @@ async def collect( for _ in range(self.policy_config.model.frame_stack_num) ] observation_window_stack[env_id].extend(initial_frames) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], 
init_raw_obs=extract_raw_obs_text(init_obs[env_id]), + init_history_obs=list(self.history_buffers[env_id])) - # Priority calculation lists search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] - # Logging variables eps_steps_lst = np.zeros(env_nums) visit_entropies_lst = np.zeros(env_nums) if collect_with_pure_policy: temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] - # ================================================================== - # Main Collection Loop - # ================================================================== while True: with self._timer: - # Get ready environments obs = self._env.ready_obs ready_env_id = set(obs.keys()) if len(ready_env_id) < self._env_num: self._logger.debug(f'Only {len(ready_env_id)}/{self._env_num} envs ready') - # Prepare stacked observations for world model stack_obs_dict = { env_id: game_segments[env_id].get_obs() for env_id in ready_env_id } stack_obs_list = [stack_obs_dict[env_id] for env_id in sorted(list(ready_env_id))] - # Prepare action masks and other info action_mask = [self.action_mask_dict[env_id] for env_id in sorted(list(ready_env_id))] to_play = [self.to_play_dict[env_id] for env_id in sorted(list(ready_env_id))] timestep = [self.timestep_dict[env_id] for env_id in sorted(list(ready_env_id))] @@ -460,62 +297,47 @@ async def collect( ) stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) - # ============================================================== - # [PRIORZERO-NEW] Get LLM Priors - # ============================================================== - if not collect_with_pure_policy: + if collect_with_pure_policy: + continue + else: # Extract text observations and valid actions raw_obs_list = [] histories_list = [] - valid_actions_list = [] # [PRIORZERO] Store valid actions for each env + valid_actions_list = [] for env_id in sorted(list(ready_env_id)): - # Extract raw text raw_obs_text = 
extract_raw_obs_text(obs[env_id]) raw_obs_list.append(raw_obs_text) - # Get history for this environment history = list(self.history_buffers[env_id]) histories_list.append(history) - # [PRIORZERO] Extract valid actions from observation valid_actions = obs[env_id].get('valid_actions', []) valid_actions_list.append(valid_actions) - - # Generate request IDs - request_ids = [ - f"collect_{train_iter}_{i}" - for i in range(len(raw_obs_list)) - ] - - # Async call to LLM - llm_outputs = await self._async_get_llm_prior( - raw_obs_list, - request_ids, - histories_list - ) - - # Add to policy kwargs - policy_kwargs['llm_prior_outputs'] = llm_outputs - policy_kwargs['valid_actions_list'] = valid_actions_list # [PRIORZERO] Pass valid actions - else: - policy_kwargs['llm_prior_outputs'] = None - policy_kwargs['valid_actions_list'] = None - - # ============================================================== - # Policy Forward Pass - # ============================================================== - policy_args = (stack_obs_tensor, action_mask, temperature, to_play, epsilon) + with self.prof.block("collect_step_get_llm_prior", rank=self._rank): + # CoT reuse optimization: request CoT prefixes to store in game segments + llm_prior_per_seq, llm_prior_per_tok, cot_prefixes = self.data_processor.get_llm_prior( + states=raw_obs_list, + valid_actions_list=valid_actions_list, # [PRIORZERO] Pass valid actions + histories=histories_list, + return_cot=True # Request CoT prefixes for reuse in training + ) + assert len(llm_prior_per_seq) == len(ready_env_id) == len(valid_actions_list) + for idx, llm_prior in enumerate(llm_prior_per_seq): + scaled_llm_prior = self.apply_temperature_scaling(llm_prior, return_logprobs=True) + llm_prior_per_seq[idx] = scaled_llm_prior + policy_kwargs_forward = { - 'ready_env_id': sorted(list(ready_env_id)), - 'timestep': timestep, - 'llm_prior_outputs': policy_kwargs.get('llm_prior_outputs'), - 'valid_actions_list': policy_kwargs.get('valid_actions_list') # 
[PRIORZERO] Pass valid actions + 'llm_prior_logprob': llm_prior_per_seq, + 'valid_actions_list': valid_actions_list, } if self.task_id is not None: policy_kwargs_forward['task_id'] = self.task_id - - policy_output = self._policy.forward(*policy_args, **policy_kwargs_forward) + with self.prof.block("collect_step_forward", rank=self._rank): + policy_output = self._policy.forward(data=stack_obs_tensor, action_mask=action_mask, + temperature=temperature, to_play=to_play, epsilon=epsilon, + ready_env_id=sorted(list(ready_env_id)), timestep=timestep, + **policy_kwargs_forward) # Extract outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -534,22 +356,11 @@ async def collect( env_id: actions_with_env_id.pop(env_id) for env_id in ready_env_id } - - # ============================================================== - # Step Environments - # ============================================================== - timesteps = self._env.step(actions) - - # [DEBUG] Log actions taken if debug mode is enabled - if self.debug_mode: - for env_id, action in actions.items(): - self._logger.info(f"[DEBUG] Env {env_id} - Action taken: {action}") + with self.prof.block("collect_step", rank=self._rank): + timesteps = self._env.step(actions) interaction_duration = self._timer.value / len(timesteps) - # ================================================================== - # Process Environment Responses - # ================================================================== for env_id, episode_timestep in timesteps.items(): with self._timer: # Handle abnormal timesteps @@ -566,51 +377,31 @@ async def collect( episode_timestep.done, episode_timestep.info ) - - # [DEBUG] Log observation and reward if debug mode is enabled - if self.debug_mode: - raw_obs_preview = extract_raw_obs_text(obs_new)[:150] - self._logger.info(f"[DEBUG] Env {env_id} - Obs: {raw_obs_preview}... 
| Reward: {reward} | Done: {done}") - - # Store search statistics - if collect_with_pure_policy: - game_segments[env_id].store_search_stats(temp_visit_list, 0) - else: - game_segments[env_id].store_search_stats( - distributions_dict_with_env_id[env_id], - value_dict_with_env_id[env_id] - ) - - # Append transition to game segment - # [PRIORZERO-FIX] Extract and pass raw_obs_text to GameSegment - raw_obs_text_for_segment = extract_raw_obs_text(obs_new) - + game_segments[env_id].store_search_stats( + distributions_dict_with_env_id[env_id], + value_dict_with_env_id[env_id]) + # =========================================================== + # [PRIORZERO-NEW] Update History Buffer + # =========================================================== + raw_obs_text = extract_raw_obs_text(obs[env_id]) + action = info['action_str'] + self.history_buffers[env_id].append((raw_obs_text, action, float(reward))) + + # Append transition to game segment (including CoT prefix for reuse optimization) game_segments[env_id].append( actions[env_id], to_ndarray(obs_new['observation']), reward, self.action_mask_dict[env_id], self.to_play_dict[env_id], - timestep=to_ndarray(obs_new.get('timestep', -1)), - raw_obs_text=raw_obs_text_for_segment + timestep=to_ndarray(self.timestep_dict[env_id]), + raw_obs_text=extract_raw_obs_text(obs_new), + history_obs=list(self.history_buffers[env_id]), + llm_prior_per_tok=llm_prior_per_tok[env_id], + cot_prefix=cot_prefixes[env_id], + llm_action=action ) - # =========================================================== - # [PRIORZERO-NEW] Update History Buffer - # =========================================================== - raw_obs_text = extract_raw_obs_text(obs[env_id]) - # [PRIORZERO] Use dynamic action mapping if available - dynamic_action_inv_map = policy_output.get(env_id, {}).get('dynamic_action_inv_map', None) - if dynamic_action_inv_map is not None: - action_text = dynamic_action_inv_map.get(actions[env_id], f"action_{actions[env_id]}") - else: - # 
Fallback to static mapping - action_text = getattr(self._policy, 'action_inv_map', {}).get( - actions[env_id], - f"action_{actions[env_id]}" - ) - self.history_buffers[env_id].append((raw_obs_text, action_text, float(reward))) - # Update state self.action_mask_dict[env_id] = to_ndarray(obs_new['action_mask']) self.to_play_dict[env_id] = to_ndarray(obs_new['to_play']) @@ -642,26 +433,17 @@ async def collect( # Save Full Game Segment # =========================================================== if game_segments[env_id].is_full(): - if self.last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, - self.last_game_segments, - self.last_game_priorities, - game_segments, - self.dones - ) + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, + game_segments, self.dones) # Calculate priorities - priorities = self._compute_priorities( - env_id, - pred_values_lst, - search_values_lst - ) + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) pred_values_lst[env_id], search_values_lst[env_id] = [], [] # Save segment - self.last_game_segments[env_id] = game_segments[env_id] - self.last_game_priorities[env_id] = priorities + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities # Create new segment game_segments[env_id] = GameSegment( @@ -670,9 +452,15 @@ async def collect( config=self.policy_config, task_id=self.task_id ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_raw_obs=extract_raw_obs_text(obs_new), init_history_obs=list(self.history_buffers[env_id])) self._env_info[env_id]['step'] += 1 + if llm_prior_per_seq[env_id] is not None: + llm_prior_tensor = torch.tensor([logit for k, logit in llm_prior_per_seq[env_id].items()]) + llm_prior_prob = torch.softmax(llm_prior_tensor, dim=-1) + 
llm_prior_entropy[env_id].append(-torch.sum(llm_prior_prob * torch.log(llm_prior_prob + 1e-9), dim=-1)) + else: + llm_prior_entropy[env_id].append(0.0) collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration @@ -683,13 +471,21 @@ async def collect( if episode_timestep.done: self._logger.info(f'======== Env {env_id} episode finished! ========') self._total_episode_count += 1 - # Logging info_log = { - 'reward': episode_timestep.info['eval_episode_return'], + 'reward': episode_timestep.info['score'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], - } + 'llm_prior_entropy': sum(llm_prior_entropy[env_id])/len(llm_prior_entropy[env_id])} + + self._logger.info( + f"[Episode Complete] Env={env_id} | " + f"Reward={info_log['reward']:.2f} | " + f"Steps={info_log['step']} | " + f"Time={info_log['time']:.2f}s | " + f"LLM_Entropy={info_log['llm_prior_entropy']:.3f}" + ) + if not collect_with_pure_policy: info_log['visit_entropy'] = ( visit_entropies_lst[env_id] / eps_steps_lst[env_id] @@ -698,23 +494,11 @@ async def collect( collected_episode += 1 self._episode_info.append(info_log) - # Save remaining segments - if self.last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, - self.last_game_segments, - self.last_game_priorities, - game_segments, - self.dones - ) - - priorities = self._compute_priorities( - env_id, - pred_values_lst, - search_values_lst - ) + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( env_id, last_game_segments, last_game_priorities, game_segments, self.dones) + priorities = self._compute_priorities( env_id, pred_values_lst, search_values_lst) game_segments[env_id].game_segment_to_array() if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append(( @@ -722,7 +506,6 @@ async def collect( priorities, self.dones[env_id] )) - # Reset pred_values_lst[env_id], search_values_lst[env_id] = [], [] 
eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 @@ -732,15 +515,22 @@ async def collect( # Clear history buffer for this environment self.history_buffers[env_id].clear() - # Re-initialize game segment + init_obs = self._env.ready_obs + observation_window_stack[env_id] = deque( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_raw_obs=extract_raw_obs_text(init_obs[env_id]), init_history_obs=list(self.history_buffers[env_id])) + last_game_segments[env_id] = None + last_game_priorities[env_id] = None # ================================================================== # Check if Enough Segments Collected @@ -771,25 +561,25 @@ async def collect( # ================================================================== collected_duration = sum([d['time'] for d in self._episode_info]) + if self._world_size > 1: + # Before allreduce + local_step, local_episode = collected_step, collected_episode + collected_step = allreduce_data(collected_step, 'sum') + collected_episode = allreduce_data(collected_episode, 'sum') + collected_duration = allreduce_data(collected_duration, 'sum') + # After allreduce + self._logger.info( + f"[Rank {self._rank} Aggregation] " + f"Local: steps={local_step}, episodes={local_episode} | " + f"Global: steps={collected_step}, episodes={collected_episode}" + ) + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration self._output_log(train_iter) - # [PRIORZERO-NEW] Log LLM statistics - if self.llm_stats['total_calls'] > 0: - avg_latency = self.llm_stats['total_latency'] / 
self.llm_stats['total_calls'] - success_rate = self.llm_stats['successful_calls'] / self.llm_stats['total_calls'] - - self._logger.info( - f"📊 LLM Prior Statistics:\n" - f" - Total calls: {self.llm_stats['total_calls']}\n" - f" - Success rate: {success_rate*100:.1f}%\n" - f" - Avg latency: {avg_latency:.3f}s\n" - f" - Retry count: {self.llm_stats['retry_count']}" - ) - return return_data def _output_log(self, train_iter: int) -> None: @@ -797,4 +587,102 @@ def _output_log(self, train_iter: int) -> None: [INHERITED] Log collection statistics (inherited from parent). """ - super()._output_log(train_iter) + if self._rank != 0: + return + + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward = [d['reward'] for d in self._episode_info] + episode_llm_prior_entropy = [d['llm_prior_entropy'] for d in self._episode_info] + + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, + 'collect_time': duration, + 'reward_mean': np.mean(episode_reward), + 'reward_std': np.std(episode_reward), + 'reward_max': np.max(episode_reward), + 'reward_min': np.min(episode_reward), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + 'llm_prior_entropy_mean': np.mean(episode_llm_prior_entropy), + 'llm_prior_entropy_max': np.max(episode_llm_prior_entropy), + 'llm_prior_entropy_min': np.min(episode_llm_prior_entropy) + } + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in 
self._episode_info] + info['visit_entropy_mean'] = np.mean(visit_entropy) + if self.policy_config.gumbel_algo: + completed_value = [d['completed_value'] for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + + self._episode_info.clear() + + self._logger.info( + f"\n{'='*80}\n" + f"[Collector Summary] Train Iter: {train_iter}\n" + f"{'-'*80}\n" + f"Episodes: {info['episode_count']} (Total: {info['total_episode_count']})\n" + f"Steps: {info['envstep_count']} (Total: {info['total_envstep_count']})\n" + f"Avg Steps/Ep: {info['avg_envstep_per_episode']:.1f}\n" + f"Throughput: {info['avg_envstep_per_sec']:.2f} steps/s, {info['avg_episode_per_sec']:.3f} eps/s\n" + f"Duration: {info['collect_time']:.2f}s (Total: {info['total_duration']:.2f}s)\n" + f"{'-'*80}\n" + f"Reward: mean={info['reward_mean']:.2f}, std={info['reward_std']:.2f}, " + f"min={info['reward_min']:.2f}, max={info['reward_max']:.2f}\n" + f"LLM Entropy: mean={info['llm_prior_entropy_mean']:.3f}, " + f"min={info['llm_prior_entropy_min']:.3f}, max={info['llm_prior_entropy_max']:.3f}\n" + + (f"Visit Entropy: {info.get('visit_entropy_mean', 0):.3f}\n" if not self.collect_with_pure_policy else "") + + (f"Completed Val: {info.get('completed_value_mean', 0):.3f}\n" if self.policy_config.gumbel_algo else "") + + f"{'='*80}" + ) + + # Log to console + self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) + + # Log to TensorBoard and WandB + for k, v in info.items(): + if self.task_id is None: + tb_prefix_iter = f'{self._instance_name}_iter/' + tb_prefix_step = f'{self._instance_name}_step/' + else: + tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' + tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + + self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) + self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + + def apply_temperature_scaling(self, 
logprobs_dict: dict, return_logprobs: bool = True) -> dict: + """ + 对 Logprobs 字典进行温度缩放,控制分布的平缓程度。 + """ + import math + T = self.llm_prior_temperature + if T <= 1e-8: + max_key = max(logprobs_dict, key=logprobs_dict.get) + return {k: (0.0 if k != max_key else 1.0) for k in logprobs_dict} + + scaled_logits = {k: v / T for k, v in logprobs_dict.items()} + + max_val = max(scaled_logits.values()) + sum_exp = sum(math.exp(v - max_val) for v in scaled_logits.values()) + log_sum_exp = math.log(sum_exp) + max_val + + result = {} + for k, v in scaled_logits.items(): + normalized_logprob = v - log_sum_exp + + if return_logprobs: + result[k] = normalized_logprob + else: + result[k] = math.exp(normalized_logprob) + + return result diff --git a/zoo/jericho/priorzero/priorzero_config.py b/zoo/jericho/priorzero/priorzero_config.py index 1614aaed4..5ec9a1a6e 100644 --- a/zoo/jericho/priorzero/priorzero_config.py +++ b/zoo/jericho/priorzero/priorzero_config.py @@ -1,87 +1,185 @@ -# priorzero_config.py -""" -[PRIORZERO] PriorZero Configuration - -This module provides complete configuration for PriorZero algorithm. - -Key Features: -- Complete UniZero world model configuration -- LLM policy configuration (ORZ-style) -- Action space mapping for text environments -- Flexible switches to enable/disable components - -Author: PriorZero Team -Date: 2025-01-20 -""" - import os -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional, Any from easydict import EasyDict - - -def get_jericho_action_mapping(env_id: str = 'zork1.z5') -> Tuple[Dict[str, int], Dict[int, str]]: - """ - Get action mapping for Jericho environments. - - In Jericho, the action space is typically defined by the game's valid actions. - For simplicity, we'll provide a basic mapping that can be extended. 
- - Args: - env_id: Jericho game ID - - Returns: - action_map: Mapping from action text to action index - action_inv_map: Mapping from action index to action text - """ - # Basic common actions for text adventure games - # These should ideally be loaded from the environment's action space - common_actions = [ - # Movement - "go north", "go south", "go east", "go west", - "go up", "go down", "go northeast", "go northwest", - "go southeast", "go southwest", - # Object interaction - "take all", "drop all", "inventory", "look", - "examine", "open", "close", "unlock", - # Common verbs - "read", "eat", "drink", "wear", "remove", - ] - - # Create mapping - action_map = {action.lower(): idx for idx, action in enumerate(common_actions)} - action_inv_map = {idx: action for action, idx in action_map.items()} - - return action_map, action_inv_map +import torch.distributed as dist +from dataclasses import dataclass, field + +# ============================================================================ +# Model Configuration Presets +# ============================================================================ +MODEL_CONFIGS = { + "qwen2.5-0.5b": { + "model_name_or_path": "/mnt/afs/wanzunian/niuyazhe/xiongjyu/models/Qwen2.5-0.5B-Instruct", + "vllm_tensor_parallel_size": 1, + "gpu_memory_utilization": 0.2, + "description": "Qwen2.5-0.5B-Instruct (smallest, fastest)", + }, + "qwen2.5-1.5b": { + "model_name_or_path": "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct", + "vllm_tensor_parallel_size": 1, + "gpu_memory_utilization": 0.2, + "description": "Qwen2.5-1.5B-Instruct (balanced performance)", + }, + "qwen2.5-3b": { + "model_name_or_path": "/mnt/afs/niuyazhe/workspace/xiongjyu/models/Qwen2.5-3B-Instruct", + "vllm_tensor_parallel_size": 1, + "gpu_memory_utilization": 0.25, + "description": "Qwen2.5-3B-Instruct (better quality)", + }, + "qwen2.5-7b": { + "model_name_or_path": "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct", + 
"vllm_tensor_parallel_size": 2, + "gpu_memory_utilization": 0.35, + "description": "Qwen2.5-7B-Instruct (high quality, needs 2+ GPUs)", + }, + "qwen2.5-14b": { + "model_name_or_path": "/mnt/shared-storage-user/puyuan/model/Qwen2.5-14B-Instruct", + "vllm_tensor_parallel_size": 4, + "gpu_memory_utilization": 0.5, + "description": "Qwen2.5-14B-Instruct (best quality, needs 4+ GPUs)", + }, +} + +def get_available_models(): + """Get list of available model configurations""" + return list(MODEL_CONFIGS.keys()) + +def get_model_config(model_key: str) -> Dict: + """Get model configuration by key""" + if model_key not in MODEL_CONFIGS: + available = ", ".join(get_available_models()) + raise ValueError( + f"Unknown model key: {model_key}\n" + f"Available models: {available}" + ) + return MODEL_CONFIGS[model_key] + +def print_available_models(): + """Print all available model configurations""" + print("\n" + "="*80) + print("Available Model Configurations:") + print("="*80) + for key, config in MODEL_CONFIGS.items(): + print(f"\n {key}:") + print(f" Path: {config['model_name_or_path']}") + print(f" Tensor Parallel Size: {config['vllm_tensor_parallel_size']}") + print(f" GPU Memory Utilization: {config['gpu_memory_utilization']}") + print(f" Description: {config['description']}") + print("="*80 + "\n") + +@dataclass +class PriorZeroLLMConfig: + model_name_or_path: str = "Qwen2.5-3B-Instruct" + local_rank: int = -1 + enable_rft: bool = True + enable_world_model: bool = True + + attn_implementation: str = "flash_attention_2" + history_length: int = 10 + use_cot: bool = True + prompt_max_len: int = 8192 + generate_max_len: int = 512 + bf16: bool = True + + # vLLM engines + enable_vllm: bool = True + enable_prefix_caching: bool = True + use_cuda_ipc: bool = False + vllm_sync_backend: str = "nccl" # vLLM 同步参数使用的后端 + vllm_sync_with_ray: bool = False # 是否使用 ray 来同步 vLLM 参数 + + vllm_tensor_parallel_size: int = 1 # 每个vllm engine使用几张GPU张量并行 (Fixed: 1.5B model should use 1 GPU) + + 
gpu_memory_utilization: float = 0.3 + vllm_enable_sleep: bool = True # 是否可以休眠 + temperature: float = 1.0 + top_p: float = 0.95 + seed: int = 0 + reduction: str = "mean" + llm_prior_temperature: float = 2.0 # LLM prior 分布的温度参数 + eval_dict: Optional[EasyDict] = field(default_factory=lambda: EasyDict({ + "world_model": True, + "world_model_llm_prior": True, + "llm_prior": True, + "eval_freq": int(500), + })) + + # 训练相关参数 + colocate_all_models: bool = True # 是否把所有模型都放在一起训练 + policy_model_num_gpus: int = 1 # 需要训练的 llm 使用几张卡 + reference_model_num_gpus: int = 1 + deepspeed_enable_sleep: bool = True + + zero_stage: int = 2 + gradient_checkpointing: bool = False + max_norm: float = 1.0 # Gradient clipping + ds_tensor_parallel_size: int = 1 + ring_attn_size: int = 1 + + # 需要注意的是,buffer中取一条经验是 10个样本,因为包含10次交互; num_unroll_steps = 10 + train_batch_size: int = 320 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps + micro_train_batch_size: int = 2 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数 + broadcast_every: int = 2 # 每次训练多少次 train_batch_size 才同步 vllm 参数;也就是说 vllm 中的模型 off 多少次参数更新 + + learning_rate: float = 1e-6 + adam_betas: Tuple[float, float] = (0.9, 0.95) + weight_decay: float = 0.01 + lr_scheduler: str = "cosine_with_min_lr" + lr_warmup_ratio: float = 0.03 + max_steps: int = int(1e4) + policy_loss_type: str = "ppo" # 'ppo' / 'gspo' + reward_func: Optional[EasyDict] = field(default_factory=lambda: EasyDict({ + "format_reward": True, + "format_param": EasyDict( + {"format_weight": 0.5, } # fmt_reward 的权重,应该在 [0, 1) 之间,因为advantage的权重是 1 - format_weight + ), + })) + # advantage = target_value - pred_value + advantage_type: str = "advantage_running_norm" # "advantage", "target_reward", "advantage_batch_norm", "advantage_running_norm" + eps_clip_low_high: Tuple[float, float] = (0.2, 0.2) + rft_kl_coef: float = 0.01 + entropy_loss_coef: float = 0.0 + kl_estimator: str = "k3" + + train_llm_after_wm_warm_step: int = int(2e2) + 
llm_save_freq: int = 500 # 每多少步保存一次 llm 模型,一步代表一次参数更新而不是梯度累积 + save_path: str = "" # 该参数将被 exp_name 目录覆盖 + + value_norm_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({ + 'enable_stability_optimizer': True, + 'value_norm_init_momentum': 0.9, # Fast adaptation in early training + 'value_norm_final_momentum': 0.99, # Slow, stable updates in later training + 'value_norm_warmup_steps': 100, # Steps to transition from init to final momentum + 'value_norm_clip_percentile': 0.95, # Clip outliers beyond this percentile + 'value_norm_clip_method': "soft", + "value_norm_history_size": 1000, + })) def get_priorzero_config( - env_id: str = 'zork1.z5', + env_id: str = 'detective.z5', seed: int = 0, exp_name: str = None, - enable_llm: bool = True, - enable_rft: bool = True, - debug_mode: bool = False, + use_cot: bool = False, + model_key: Optional[str] = "qwen2.5-3b", + multi_gpu: bool = False ) -> Tuple[EasyDict, EasyDict]: """ - Generate complete PriorZero configuration. + Generate complete PriorZero configuration with automatic model configuration. Args: env_id: Jericho game ID seed: Random seed exp_name: Experiment name (auto-generated if None) - enable_llm: Whether to enable LLM policy (if False, degrades to pure UniZero) - enable_rft: Whether to enable RFT training (if False, only use SFT) - debug_mode: Whether to enable detailed debug logging (obs, action, LLM output, etc.) + use_cot: Whether to use Chain-of-Thought reasoning + model_key: Model configuration key (e.g., 'qwen2.5-0.5b', 'qwen2.5-1.5b', 'qwen2.5-7b') + If None, uses default 'qwen2.5-1.5b' Returns: main_config: Main configuration dictionary create_config: Creation configuration for DI-engine components + llm_config: LLM configuration with auto-configured model parameters """ - - # ============================================================================== - # 1. 
Basic Settings - # ============================================================================== - # Action space and max steps per environment (from jericho_unizero_config.py) env_configurations = { 'detective.z5': (12, 100), 'omniquest.z5': (25, 100), @@ -89,404 +187,164 @@ def get_priorzero_config( 'zork1.z5': (55, 500), } action_space_size, max_steps = env_configurations.get(env_id, (20, 100)) - - # World model encoder (for processing text observations) - wm_encoder_option = 'legacy' # Options: 'legacy', 'clip', 'custom' - wm_model_name = 'BAAI/bge-base-en-v1.5' # Sentence transformer for text encoding - - # LLM policy model - # llm_model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Smaller model for faster iteration - llm_model_name = "Qwen/Qwen2.5-0.5B-Instruct" # Smaller model for faster iteration - - # Get action mappings - action_map, action_inv_map = get_jericho_action_mapping(env_id) - - # Convert action_inv_map to use string keys for EasyDict compatibility - action_inv_map_str = {str(k): v for k, v in action_inv_map.items()} - - # ============================================================================== - # 2. 
Environment Configuration - # ============================================================================== + wm_encoder_option = 'legacy' + # wm_model_name = 'BAAI/bge-base-en-v1.5' + wm_model_name = '/mnt/afs/niuyazhe/workspace/xiongjyu/models/bge-base-en-v1.5' + + collector_env_num = 1 + evaluator_env_num = 2 + n_episode = collector_env_num + + num_unroll_steps = 10 + infer_context_length = 4 + game_segment_length = 50 + num_layers = 2 + embed_dim = 768 + replay_ratio = 0.1 + batch_size = 64 + collect_num_simulations=25 + eval_num_simulations=25 + replay_buffer_size = int(1e5) + env_config = dict( - # Stop conditions stop_value=int(1e6), max_steps=max_steps, - - # Observation and action space - observation_shape=512, # BGE embedding dimension - action_space_size=action_space_size, - - # [FIX] Jericho environment expects these at top level + observation_shape=512, env_id=env_id, - game_path=f"/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + # game_path=f"/mnt/afs/wanzunian/niuyazhe/xiongjyu/jericho/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + game_path=f"/mnt/afs/niuyazhe/workspace/xiongjyu/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + # game_path=f"/mnt/shared-storage-user/puyuan/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, tokenizer_path=wm_model_name, - env_type="jericho", max_action_num=action_space_size, max_seq_len=512, - save_replay=False, - save_replay_path="", - collect_policy_mode="default", - - # Parallelization - collector_env_num=4, - evaluator_env_num=2, - n_evaluator_episode=2, - - # Environment manager + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, manager=dict( shared_memory=False, - reset_timeout=60, # Increased timeout for text env initialization ), + use_cache=True, + 
cache_size=100000, ) - - # ============================================================================== - # 3. UniZero World Model Configuration - # ============================================================================== - world_model_config = dict( - # [CRITICAL] DI-engine requires 'type' field to identify model class - type='UniZeroModel', - - # [FIX] EasyDict.pop() doesn't handle default values properly, must include import_names - import_names=[], # Empty list since UniZeroModel is already registered - - # Model type - model_type='mlp', # For vector observations (text embeddings) - continuous_action_space=False, - - # Observation and action - observation_shape=512, - action_space_size=action_space_size, - - # [FIX] Encoder settings must be at top level for UniZeroModel.__init__ - encoder_option=wm_encoder_option, - encoder_url=wm_model_name, - - # World model architecture - world_model_cfg=dict( - # Obs type - obs_type="text", # Important: text-based observations - - # Environment settings - env_num=max(4, 2), # max(collector_env_num, evaluator_env_num), will be updated in quick_test - action_space_size=action_space_size, - - # Transformer settings - # num_layers=4, # Reduced for faster training - num_layers=2, # Reduced for faster training # TODO - num_heads=8, - embed_dim=512, - - # Context and unroll - # Note: Each timestep contains 2 tokens: observation and action - num_unroll_steps=10, # Number of steps to unroll in training - infer_context_length=4, # Inference context length - tokens_per_block=2, # obs + action - max_blocks=10, # num_unroll_steps (default) - max_tokens=2 * 10, # 2 * num_unroll_steps - context_length=2 * 4, # 2 * infer_context_length - - # Regularization - embed_pdrop=0.1, - resid_pdrop=0.1, - attn_pdrop=0.1, - - # Loss weights - latent_recon_loss_weight=0.0, # Latent reconstruction loss - perceptual_loss_weight=0.0, - policy_entropy_weight=0.0, # Entropy regularization - - # Normalization - final_norm_option_in_head="LayerNorm", 
- final_norm_option_in_encoder="LayerNorm", - predict_latent_loss_type='mse', # or 'group_kl' with SimNorm - - # Device - device="cuda", - - # Advanced settings - gru_gating=False, - attention='causal', - support_size=101, # For distributional RL - - # Analysis flags - analysis_sim_norm=False, - analysis_dormant_ratio_weight_rank=False, - # use_priority=False, - use_priority=True, - - # Position encoding - rotary_emb=False, # Whether to use RoPE - rope_theta=10000, - max_seq_len=8192, - - # LoRA (optional, for world model) - lora_r=0, # Set > 0 to enable LoRA - - # Other - decode_loss_mode=None, # 'after_backbone', 'before_backbone', or None - gamma=1.0, # Discount factor - dormant_threshold=0.025, - - task_embed_option=None, - use_task_embed=False, - use_normal_head=True, - use_softmoe_head=False, - use_moe_head=False, - num_experts_in_moe_head=4, - moe_in_transformer=False, - multiplication_moe_in_transformer=False, - n_shared_experts=1, - num_experts_per_tok=1, - num_experts_of_moe_in_transformer=8, - # game_segment_length=200, - game_segment_length=50, - ), - - # Distributional RL - categorical_distribution=True, - reward_support_range=(-50., 51., 1.), # (min, max, step) for reward support - value_support_range=(-50., 51., 1.), # (min, max, step) for value support - - # Self-supervised learning - self_supervised_learning_loss=True, - - # Model architecture details - frame_stack_num=1, - bias=True, - res_connection_in_dynamics=True, - norm_type='LN', # LayerNorm for text - ) - - # ============================================================================== - # 4. 
LLM Policy Configuration (ORZ-style) - # ============================================================================== - llm_policy_config = dict( - # Model path - pretrain_llm_path=llm_model_name, - - # LoRA for parameter-efficient fine-tuning - use_lora=False, # Set to True to enable LoRA - lora_r=8, - lora_alpha=16, - lora_dropout=0.05, - - # Training - llm_learning_rate=1e-6, - llm_weight_decay=0.01, - llm_loss_weight=0.5, # Weight of SFT loss in total loss - rft_loss_weight=0.3, # Weight of RFT loss in total loss - - # [PRIORZERO-OOM-FIX] Gradient accumulation for memory efficiency - # Process LLM training in smaller micro-batches to avoid OOM - llm_micro_batch_size=4, # Small batch size per forward pass (reduce if still OOM) - llm_gradient_accumulation_steps=8, # Accumulate gradients over 8 steps (effective batch = 4*8=32) - # Note: Effective batch size = llm_micro_batch_size * llm_gradient_accumulation_steps - - # Generation - prompt_max_len=2048, - generate_max_len=256, # Max tokens for LLM output - - # Prompting strategy - history_length=5, # Number of recent (obs, action, reward) tuples to include - use_cot=True, # Whether to use Chain-of-Thought prompting - - # Training strategy - sft_target='mcts_policy', # 'mcts_policy' or 'oracle_policy' - enable_rft=enable_rft, # Whether to enable RFT with env rewards - # enable_rft=False, # Whether to enable RFT with env rewards # TODO - - # vLLM settings - vllm_tensor_parallel_size=1, - gpu_memory_utilization=0.3, # Adjust based on your GPU memory - ) - - # ============================================================================== - # 5. Policy Configuration (Combines World Model + LLM) - # ============================================================================== policy_config = dict( + type='priorzero', + multi_gpu=multi_gpu, + use_wandb=False, learn=dict( learner=dict( hook=dict( - save_ckpt_after_iter=1000000, # To save memory, set a large value. 
If intermediate checkpoints are needed, reduce this value. + save_ckpt_after_iter=1000000, ), ), ), - type='priorzero', - - # Environment settings (must match env config) - collector_env_num=env_config['collector_env_num'], - evaluator_env_num=env_config['evaluator_env_num'], - - # Model config (world model) - model=world_model_config, - - # [PRIORZERO-NEW] LLM policy config - llm_policy_cfg=llm_policy_config, - - # [PRIORZERO-NEW] Action mappings (use original dict, not EasyDict) - # These will be set directly on policy instance, not through EasyDict - _action_map=action_map, # Prefix with _ to avoid EasyDict conversion - _action_inv_map=action_inv_map, - - # ============================================================================== - # [ASYNC-NEW] Async Training Configuration - # ============================================================================== - # off_policy_degree controls the degree of asynchrony between collect and train: - # - 0: Fully synchronous (serial) mode - collect -> train -> eval - # - 1-10: Low async - train can lag behind collect by a few batches - # - 10-50: Medium async - train can lag more, higher throughput - # - >50: High async - maximum throughput, highest off-policy bias - # - # Special value -1: Auto-tune based on buffer size and batch size - off_policy_degree=0, # Default to synchronous mode for stability - # off_policy_degree=5, - - # Whether to enable async evaluation (runs eval in background) - enable_async_eval=False, - - # MCTS settings - num_simulations=25, - collect_num_simulations=25, - eval_num_simulations=25, - - # MCTS exploration - root_dirichlet_alpha=0.3, - root_noise_weight=0.25, - - # MCTS variants (set one to True to use that variant) - sampled_algo=False, # Sampled MuZero - gumbel_algo=False, # Gumbel MuZero - mcts_ctree=True, # Use C++ MCTS (faster) - - # Training settings - batch_size=32, - learning_rate=3e-4, # World model learning rate + model=dict( + observation_shape=512, + 
action_space_size=action_space_size, + encoder_option=wm_encoder_option, + encoder_url=wm_model_name, + model_type="mlp", + continuous_action_space=False, + norm_type="LN", + world_model_cfg=dict( + norm_type="LN", + final_norm_option_in_head="LayerNorm", + final_norm_option_in_encoder="LayerNorm", + predict_latent_loss_type='mse', + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + embed_dim=embed_dim, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + decode_loss_mode=None, + latent_recon_loss_weight=0, + + task_embed_option=None, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + game_segment_length=game_segment_length, + ) + ), + update_per_collect=None, + num_segments=collector_env_num, + action_type="varied_action_space", + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=3e-4, weight_decay=1e-4, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=0, + replay_buffer_size=replay_buffer_size, + eval_freq=int(3e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=1 / 1000000, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + device='cuda', + + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + game_segment_length=game_segment_length, + off_policy_degree=0, + enable_async_eval=False, + optim_type='AdamW', grad_clip_value=10.0, - - # Loss components - value_loss_weight=1.0, + value_loss_weight=0.25, policy_loss_weight=1.0, reward_loss_weight=1.0, - # Adaptive entropy weight (for exploration) - 
use_adaptive_entropy_weight=True, + use_adaptive_entropy_weight=False, adaptive_entropy_alpha_lr=1e-4, - - # Encoder gradient clipping with annealing - use_encoder_clip_annealing=True, + use_encoder_clip_annealing=False, encoder_clip_anneal_type='cosine', encoder_clip_start_value=30.0, encoder_clip_end_value=10.0, encoder_clip_anneal_steps=100000, - - # Training schedule - num_unroll_steps=10, - td_steps=5, - train_start_after_envsteps=0, - # train_start_after_envsteps=1000, - update_per_collect=None, # Will be set automatically - replay_ratio=0.25, - - # Replay buffer - # replay_buffer_size=int(1e4), - replay_buffer_size=int(1e5), - use_priority=True, # Prioritized experience replay + use_priority=False, # Prioritized experience replay priority_prob_alpha=0.6, priority_prob_beta=0.4, - - # Evaluation - eval_freq=500, - - # Game segments - # game_segment_length=200, - game_segment_length=50, - num_segments=env_config['collector_env_num'], # Must equal collector_env_num - - # Misc - ignore_done=False, - collect_with_pure_policy=False, - monitor_extra_statistics=True, - - # Device - cuda=True, - device='cuda', - multi_gpu=False, - - # Environment type - env_type='not_board_games', - action_type='varied_action_space', # Jericho has varied action space per state - battle_mode='play_with_bot_mode', - - # Data processing - transform2string=False, - gray_scale=False, - use_augmentation=False, - - # Advanced - use_rnd_model=False, # Random Network Distillation for exploration - analysis_sim_norm=False, - sample_type='transition', - - # ============================================================================== - # [ALIGN WITH UNIZERO] Reanalyze Configuration (atari_unizero_segment_config.py line 201-206) - # ============================================================================== - # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, - # 2 means reanalyze once every two epochs, 1/50 means reanalyze once every 50 epochs. 
- buffer_reanalyze_freq=1/5000000000, # Effectively disabled for Jericho (set very low) - # Each reanalyze process will reanalyze sequences - # ( transitions per sequence) - reanalyze_batch_size=160, - # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, - # 0.5 means samples from the first half of the buffer. - reanalyze_partition=0.75, - # Reanalyze ratio (used in some algorithms, kept for compatibility) - reanalyze_ratio=0.0, ) - # ============================================================================== - # 6. Replay Buffer Configuration - # ============================================================================== - replay_buffer_config = dict( - type='game', - replay_buffer_size=policy_config['replay_buffer_size'], - batch_size=policy_config['batch_size'], - ) + llm_config = PriorZeroLLMConfig(use_cot=use_cot) # 需要修改 llm 相关的参数,修改以上类即可 - # ============================================================================== - # 6.5 Remove problematic nested dicts before EasyDict conversion - # ============================================================================== - # Store action mappings separately to avoid EasyDict issues with integer keys - _temp_action_map = action_map - _temp_action_inv_map = action_inv_map + # Apply model configuration + model_config = get_model_config(model_key) + llm_config.model_name_or_path = model_config["model_name_or_path"] + llm_config.vllm_tensor_parallel_size = model_config["vllm_tensor_parallel_size"] + llm_config.gpu_memory_utilization = model_config["gpu_memory_utilization"] - # ============================================================================== - # 7. 
Main Configuration Assembly - # ============================================================================== + if exp_name is None: + env_name = env_id.replace(".z5", "") + exp_name = f"priorzero_{env_name}_{model_key}_{llm_config.policy_loss_type}_WM_{llm_config.enable_world_model}_useCot_{llm_config.use_cot}_seed{seed}" + priorzero_config = dict( env=env_config, policy=policy_config, - replay_buffer=replay_buffer_config, - - # Experiment settings - exp_name=exp_name or f"priorzero_{env_id}_seed{seed}", - seed=seed, - - # Debug settings - debug_mode=debug_mode, + exp_name=exp_name, + seed=seed ) - - # ============================================================================== - # 8. Create Configuration (for DI-engine component creation) - # ============================================================================== create_config = dict( env=dict( type="jericho", import_names=["zoo.jericho.envs.jericho_env"], ), env_manager=dict( - type="base" # [FIX] Use 'base' for jericho to avoid daemon process issues + type="base" ), policy=dict( type="priorzero", @@ -505,184 +363,49 @@ def get_priorzero_config( import_names=['lzero.mcts.buffer.game_buffer_muzero'], ), ) - - # ============================================================================== - # 9. 
Convert to EasyDict for convenient access - # ============================================================================== - # IMPORTANT: Remove _action_map and _action_inv_map from policy_config before EasyDict - # to avoid integer key issues - policy_config_copy = {k: v for k, v in policy_config.items() if not k.startswith('_')} - priorzero_config['policy'] = policy_config_copy - main_config = EasyDict(priorzero_config) create_config = EasyDict(create_config) - # Set experiment path - main_config.exp_name = f"data_priorzero/{main_config.exp_name}" - - # [IMPORTANT] Set action mappings as regular attributes (not through EasyDict) - # Use object.__setattr__ to bypass EasyDict's __setattr__ which tries to convert dicts - object.__setattr__(main_config.policy, 'action_map', _temp_action_map) - object.__setattr__(main_config.policy, 'action_inv_map', _temp_action_inv_map) + print(f"[Config] Model configuration applied:") + print(f" - Model: {model_key}") + print(f" - Path: {llm_config.model_name_or_path}") + print(f" - Tensor Parallel Size: {llm_config.vllm_tensor_parallel_size}") + print(f" - GPU Memory Utilization: {llm_config.gpu_memory_utilization}") - return main_config, create_config + return main_config, create_config, llm_config -def get_priorzero_config_for_quick_test(env_id: str = 'zork1.z5', seed: int = 0, debug_mode: bool = False): - """ - Get a lightweight configuration for quick testing (reduced resources). 
- - This is useful for: - - Debugging - - CI/CD pipelines - - Local development without powerful GPUs - - IMPORTANT: All sequence-length related parameters must be consistent: - - num_unroll_steps: Number of timesteps in training unroll - - max_blocks: Should equal num_unroll_steps - - max_tokens: Should equal num_unroll_steps * tokens_per_block (= num_unroll_steps * 2) - - infer_context_length: Context length for inference - - context_length: Should equal infer_context_length * tokens_per_block (= infer_context_length * 2) - """ - main_config, create_config = get_priorzero_config(env_id, seed, debug_mode=debug_mode) - - # ============================================================================== - # [CRITICAL FIX] Define num_unroll_steps FIRST to ensure consistency - # ============================================================================== - quick_test_num_unroll_steps = 10 # Core parameter that determines sequence length - quick_test_infer_context_length = 4 # Inference context length - tokens_per_block = 2 # obs + action (fixed in UniZero architecture) - - # Reduce computational requirements - main_config.env.collector_env_num = 2 - main_config.env.evaluator_env_num = 1 - main_config.env.n_evaluator_episode = 1 - - # ============================================================================== - # Policy-level configurations - # ============================================================================== - main_config.policy.num_simulations = 5 - # main_config.policy.batch_size = 20 - main_config.policy.batch_size = 2 - main_config.policy.game_segment_length = 20 # Can be larger than num_unroll_steps - main_config.policy.num_segments = 2 # Must equal collector_env_num - main_config.policy.replay_buffer_size = 1000 - - # [CRITICAL] Set policy-level num_unroll_steps to match world model - main_config.policy.num_unroll_steps = quick_test_num_unroll_steps - - # ============================================================================== - # World 
model configurations - ALL must be consistent with num_unroll_steps - # ============================================================================== - main_config.policy.model.world_model_cfg.num_layers = 1 - main_config.policy.model.world_model_cfg.num_heads = 2 - - # Update env_num to match the reduced collector/evaluator counts - main_config.policy.model.world_model_cfg.env_num = max( - main_config.env.collector_env_num, - main_config.env.evaluator_env_num - ) - - # [CRITICAL] Sequence length parameters - must all be consistent - main_config.policy.model.world_model_cfg.num_unroll_steps = quick_test_num_unroll_steps - main_config.policy.model.world_model_cfg.max_blocks = quick_test_num_unroll_steps - main_config.policy.model.world_model_cfg.max_tokens = quick_test_num_unroll_steps * tokens_per_block # 3 * 2 = 6 - - main_config.policy.model.world_model_cfg.infer_context_length = quick_test_infer_context_length - main_config.policy.model.world_model_cfg.context_length = quick_test_infer_context_length * tokens_per_block # 2 * 2 = 4 - - # Verify tokens_per_block is set correctly (should already be 2 from base config) - main_config.policy.model.world_model_cfg.tokens_per_block = tokens_per_block - - # ============================================================================== - # LLM policy configurations - # ============================================================================== - main_config.policy.llm_policy_cfg.prompt_max_len = 1024 - main_config.policy.llm_policy_cfg.generate_max_len = 128 - main_config.policy.llm_policy_cfg.history_length = 3 - # [PRIORZERO-OOM-FIX] Reduce micro-batch size for quick test to avoid OOM - main_config.policy.llm_policy_cfg.llm_micro_batch_size = 2 - main_config.policy.llm_policy_cfg.llm_gradient_accumulation_steps = 4 - - main_config.exp_name = f"{main_config.exp_name}_debug" - - return main_config, create_config - - -# ============================================================================== -# Preset 
Configurations for Different Scenarios -# ============================================================================== - -def get_config_pure_unizero(env_id: str = 'zork1.z5', seed: int = 0): - """Get config for pure UniZero (without LLM).""" - main_config, create_config = get_priorzero_config( - env_id=env_id, - seed=seed, - enable_llm=False, - ) - main_config.exp_name = f"pure_unizero_{env_id}_seed{seed}" - main_config.policy.llm_policy_cfg.llm_loss_weight = 0.0 - main_config.policy.llm_policy_cfg.rft_loss_weight = 0.0 - return main_config, create_config - +def get_priorzero_debug_config( + env_id: str = 'detective.z5', + seed: int = 0, + exp_name: str = None, + use_cot: bool = False, + model_key: Optional[str] = "qwen2.5-3b", +) -> EasyDict: -def get_config_llm_only_sft(env_id: str = 'zork1.z5', seed: int = 0): - """Get config for LLM with only SFT (no RFT).""" - main_config, create_config = get_priorzero_config( - env_id=env_id, - seed=seed, - enable_rft=False, + main_config, create_config, llm_config = get_priorzero_config( + env_id=env_id, seed=seed, exp_name=exp_name, use_cot=use_cot, model_key=model_key ) - main_config.exp_name = f"priorzero_sft_only_{env_id}_seed{seed}" - return main_config, create_config - - -def get_config_with_lora(env_id: str = 'zork1.z5', seed: int = 0): - """Get config with LoRA enabled for LLM (memory efficient).""" - main_config, create_config = get_priorzero_config(env_id=env_id, seed=seed) - main_config.policy.llm_policy_cfg.use_lora = True - main_config.exp_name = f"priorzero_lora_{env_id}_seed{seed}" - return main_config, create_config - - -# ============================================================================== -# Example Usage -# ============================================================================== - -if __name__ == "__main__": - # Test configuration generation - print("="*80) - print("Testing PriorZero Configuration Generation") - print("="*80) - - # 1. Standard config - print("\n1. 
Standard PriorZero Config:") - main_cfg, create_cfg = get_priorzero_config(env_id='zork1.z5', seed=0) - print(f" Exp name: {main_cfg.exp_name}") - print(f" Action space size: {main_cfg.policy.model.action_space_size}") - print(f" LLM model: {main_cfg.policy.llm_policy_cfg.pretrain_llm_path}") - print(f" World model layers: {main_cfg.policy.model.world_model_cfg.num_layers}") - print(f" Num action mappings: {len(main_cfg.policy.action_map)}") - - # 2. Quick test config - print("\n2. Quick Test Config:") - test_cfg, _ = get_priorzero_config_for_quick_test() - print(f" Batch size: {test_cfg.policy.batch_size}") - print(f" Num simulations: {test_cfg.policy.num_simulations}") - print(f" Collector envs: {test_cfg.env.collector_env_num}") - - # 3. Pure UniZero config - print("\n3. Pure UniZero Config:") - unizero_cfg, _ = get_config_pure_unizero() - print(f" LLM loss weight: {unizero_cfg.policy.llm_policy_cfg.llm_loss_weight}") - print(f" RFT enabled: {unizero_cfg.policy.llm_policy_cfg.enable_rft}") - - # 4. Config with LoRA - print("\n4. 
Config with LoRA:") - lora_cfg, _ = get_config_with_lora() - print(f" Use LoRA: {lora_cfg.policy.llm_policy_cfg.use_lora}") - print(f" LoRA rank: {lora_cfg.policy.llm_policy_cfg.lora_r}") - - print("\n" + "="*80) - print("✓ All configurations generated successfully!") - print("="*80) + max_steps = 20 + + batch_size = 8 + collect_num_simulations=2 + eval_num_simulations=2 + num_layers=1 + game_segment_length = 50 + + llm_config.train_batch_size = 40 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps + llm_config.micro_train_batch_size = 8 + llm_config.train_llm_after_wm_warm_step = 0 + + create_config.max_steps = max_steps + + main_config.policy.model.world_model_cfg.num_layers = num_layers + main_config.policy.model.world_model_cfg.game_segment_length = game_segment_length + main_config.policy.batch_size = batch_size + main_config.policy.collect_num_simulations = collect_num_simulations + main_config.policy.eval_num_simulations = eval_num_simulations + main_config.policy.update_per_collect = 2 + main_config.policy.game_segment_length = game_segment_length + + return main_config, create_config, llm_config diff --git a/zoo/jericho/priorzero/priorzero_datafactory.py b/zoo/jericho/priorzero/priorzero_datafactory.py new file mode 100644 index 000000000..09365e01d --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_datafactory.py @@ -0,0 +1,704 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import re +import torch +import torch.distributed as dist +from vllm import SamplingParams +from ding.utils import build_logger +import random +import math + +_FMT_RE = re.compile( + r'^\s*Reasoning:\s*(?P[\s\S]*?)\nAction:\s*(?P[^\n\r]+)\s*$', + flags=re.IGNORECASE +) +def _format_reward(text: str) -> int: + """ + Return 1 if the output strictly matches: + Reasoning: + Action: + Otherwise 0. 
+ """ + if not isinstance(text, str): + return 0 + + t = text.replace("\r\n", "\n").replace("\r", "\n").strip() + + m = _FMT_RE.match(t) + if m is None: + return 0 + + if len(re.findall(r'Reasoning:', t, flags=re.IGNORECASE)) != 1: + return 0 + if len(re.findall(r'Action:', t, flags=re.IGNORECASE)) != 1: + return 0 + + # Action 必须非空(regex 已经用 + 保证非空,这里再保险) + if m.group("action").strip() == "": + return 0 + + return 1 + +class DataProcessor: + """ + - build_llm_prompt / build_chat_context + - priorzero_batch -> samples + - (use_cot) 批量生成 prefix_cot + - vLLM 计算 action prior score(prompt_logprobs) + - samples -> Dataset/Dataloader(collate_fn 做 pack) + """ + + def __init__(self, rank, world_size, vllm_engine, strategy, model_path, exp_name=None, instance_name="vllm_output"): + self.vllm_engine = vllm_engine + self.strategy = strategy + self.args = getattr(strategy, "args", None) + + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, padding_side="left" + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.use_cot = self.args.use_cot + self.prompt_max_len = self.args.prompt_max_len + self.generate_max_len = self.args.generate_max_len + self.temperature = self.args.temperature + self.top_p = self.args.top_p + self.vllm_enable_sleep = self.args.vllm_enable_sleep + self.reduction = self.args.reduction + self.rank = rank + self.world_size = world_size + self.output_step = 0 + self.llm_prior_with_cot = False + + from collections import deque + self.episode_output = [] + + # Running statistics for advantage normalization + self.value_running_mean = 0.0 + self.value_running_std = 1.0 + self.value_count = 0 + self.running_momentum = 0.99 # EMA momentum for running statistics + + if self.rank == 0: + self._logger, _ = build_logger( + path=f'./{exp_name}/log/{instance_name}', name=instance_name, need_tb=False + ) + + if 
self.args.value_norm_cfg.enable_stability_optimizer: + from models.stability_optimizer import AdaptiveValueNormalizer + self.value_normalizer = AdaptiveValueNormalizer( + init_momentum=self.args.value_norm_cfg.value_norm_init_momentum, + final_momentum=self.args.value_norm_cfg.value_norm_final_momentum, + warmup_steps=self.args.value_norm_cfg.value_norm_warmup_steps, + clip_method=self.args.value_norm_cfg.value_norm_clip_method, + clip_percentile=self.args.value_norm_cfg.value_norm_clip_percentile, + min_std=1e-6, + history_size=self.args.value_norm_cfg.value_norm_history_size, + ) + else: + self.value_normalizer = None + + def get_system_prompt(self): + """ + 系统提示词:纯文本指令,定义角色、目标和严格的输出协议。 + """ + parts = [ + "You are an expert player in a text-based adventure game. Your goal is to maximize the score by choosing the optimal next action.", + "Please analyze the game history and current observation to decide the single best next action.", + "OUTPUT FORMAT:", + ] + + if self.use_cot: + parts.append( + "You MUST produce exactly TWO parts in the following order:\n" + "1. Reasoning: Analyze the current situation, available actions, constraints, and uncertainties. Do NOT reveal the final choice here.\n" + "2. 
Action: The final chosen action.\n" + "Strict Format Example:\n" + "Reasoning: \n" + "Action: " + ) + else: + parts.append( + "Output exactly one line starting with 'Action:'.\n" + "Example:\n" + "Action: " + ) + return "\n".join(parts) + + def get_user_prompt(self, history: Optional[List[Tuple[str, str, float]]] = None, current_obs: Optional[str] = None): + """ + 用户提示词:注入历史和当前状态,并触发输出。 + """ + prompt_parts = [] + + if history and len(history) > 0: + prompt_parts.append("=== GAME HISTORY ===") + for i, (obs, action, reward) in enumerate(history, start=1): + prompt_parts.append(f"Step {i}:") + prompt_parts.append(f"Observation: {obs.strip()}") + prompt_parts.append(f"Action: {action.strip()}") + prompt_parts.append(f"Reward: {reward}") + prompt_parts.append("") # 空行分隔 + + prompt_parts.append("=== CURRENT OBSERVATION ===") + prompt_parts.append(current_obs.strip()) + + prompt_parts.append("\n=== INSTRUCTION ===") + if self.use_cot: + prompt_parts.append( + "Please analyze the situation and provide your response in the following format:\n" + "Reasoning: \n" + "Action: " + ) + else: + prompt_parts.append( + "Decide on the best next move and output it in the following format:\n" + "Action: " + ) + return "\n".join(prompt_parts) + + def build_chat_context(self, user_prompt: str) -> str: + return self.tokenizer.apply_chat_template( + [ + {"role": "system", "content": self.get_system_prompt()}, + {"role": "user", "content": user_prompt} + ], + tokenize=False, + add_generation_prompt=True, + ) + + def build_llm_samples(self, + raw_obs_list: List[List[str]], + history_obs_list: List[List[List[Tuple[str, str, float]]]], + llm_prior_per_tok_list: Optional[List[List[Any]]] = None, + pred_values: Optional[torch.Tensor] = None, # [B, T-1] + target_values: Optional[torch.Tensor] = None, # [B, T-1] + cot_prefix_list: Optional[List[List[str]]] = None, # CoT reuse optimization + llm_action_list: Optional[List[List[str]]] = None, + ) -> List[Dict[str, Any]]: + """ + Build training 
samples from collected data. + + Args: + raw_obs_list: Raw observations + history_obs_list: History observations + llm_prior_per_tok_list: LLM prior per token from collect phase + target_values: Target values for advantage calculation + cot_prefix_list: CoT prefixes from collect phase (CoT reuse optimization) + + Returns: + List of sample dictionaries + """ + samples: List[Dict[str, Any]] = [] + B = len(raw_obs_list) + if B == 0: + return samples + T = len(raw_obs_list[0]) + + for b in range(B): + for t in range(T - 1): + current_obs = raw_obs_list[b][t] + current_hist = history_obs_list[b][t] + + instruction = self.get_user_prompt( + history=current_hist, + current_obs=current_obs, + ) + prompt = self.build_chat_context(instruction) + + true_action = llm_action_list[b][t+1] + old_logprob = llm_prior_per_tok_list[b][t+1]['old_action_logprob'][true_action] + full_ids = llm_prior_per_tok_list[b][t+1]['full_ids'][true_action] + label_ids = llm_prior_per_tok_list[b][t+1]['label_ids'][true_action] + + target_value = None + if target_values is not None: + target_value = float(target_values[b][t].item()) + + pred_value = None + if pred_values is not None: + pred_value = float(pred_values[b][t].item()) + + # CoT reuse optimization: get CoT prefix from stored data + prefix_cot = None + if self.use_cot and cot_prefix_list is not None: + prefix_cot = cot_prefix_list[b][t+1] + + samples.append( + { + "instruction": instruction, + "prompt": prompt, + "target": true_action, + "pred_value": pred_value, + "target_value": target_value, + "old_logprob": old_logprob, # Reinforce++ ratio 需要 + "prefix_cot": prefix_cot, # CoT reuse optimization + "full_ids": full_ids, + "label_ids": label_ids, + } + ) + return samples + + def make_llm_train_samples(self, priorzero_batch, ddp: bool = False) -> List[Dict[str, Any]]: + """ + Convert PriorZero batch to LLM training samples. 
+ + Args: + priorzero_batch: Tuple of (raw_obs_list, history_obs_list, llm_prior_per_tok_list, target_value, pred_value, cot_prefix_list) + CoT prefix list is added for CoT reuse optimization. + + Returns: + Tuple of (input_ids, attention_mask, action_mask, advantages, old_logprob) + """ + raw_obs_list, history_obs_list, llm_prior_per_tok_list, target_value, pred_value, cot_prefix_list, llm_action_list = priorzero_batch + + assert len(raw_obs_list) == len(history_obs_list) == len(llm_prior_per_tok_list) == len(target_value) == len(pred_value) == len(cot_prefix_list) == len(llm_action_list), \ + f"Batch size mismatch: raw_obs={len(raw_obs_list)}, history_obs={len(history_obs_list)}, llm_prior_per_tok={len(llm_prior_per_tok_list)}, target_value={len(target_value)}, cot_prefix={len(cot_prefix_list)}, llm_action={len(llm_action_list)}" + + # Build samples with CoT prefixes + samples = self.build_llm_samples( + raw_obs_list, history_obs_list, llm_prior_per_tok_list, pred_value, target_value, cot_prefix_list, llm_action_list + ) + random.shuffle(samples) + + if ddp: + print(f"[Rank {self.rank}] process {len(samples)} samples collected by Rank {self.rank}") + real_samples = samples + else: + per_rank = len(samples) // self.world_size + start = self.rank * per_rank + end = (self.rank + 1) * per_rank if self.rank != self.world_size - 1 else len(samples) + print(f"[Rank {self.rank}] process {start}: {end} samples. 
Total {len(samples)} samples collected by Rank 0.") + real_samples = samples[start:end] + + prompts_only = [s["prompt"] for s in real_samples] + if self.use_cot: + targets_only = [s["prefix_cot"] + " " + s["target"] + self.tokenizer.eos_token for s in real_samples] + if self.args.reward_func.format_reward: + fmt_rewards = torch.tensor([_format_reward(t) for t in targets_only]) + else: + fmt_rewards = None + else: + targets_only = [s["target"] + self.tokenizer.eos_token for s in real_samples] + fmt_rewards = None + + full_ids_list = [s['full_ids'] for s in real_samples] + tgt_ids_list = [s['label_ids'] for s in real_samples] + + inputs = self.tokenizer.pad({"input_ids": full_ids_list}, padding=True, return_tensors="pt") + labels = torch.full_like(inputs.input_ids, -100) + for i, tgt_ids in enumerate(tgt_ids_list): + tgt_len = len(tgt_ids) + labels[i, -tgt_len:] = inputs.input_ids[i, -tgt_len:] + action_mask_full = (labels != -100).long() + max_tgt_len = max(len(t) for t in tgt_ids_list) + action_mask = action_mask_full[:, -max_tgt_len:] + log_status_tmp = {} + log_status = [] + + if fmt_rewards is not None: + fmt_weight = self.args.reward_func.format_param.format_weight + assert 0.0 <= fmt_weight < 1.0, f"format_weight should be in [0, 1), but got {fmt_weight}" + log_status_tmp['fmt_rewards'] = fmt_rewards.tolist() + + # t 时刻的 target_value = td_step 步真实 r 的折扣和 + boostrap( t + td_step) 的 v + target_value = torch.tensor([s["target_value"] for s in real_samples], dtype=torch.float32) + # t 时刻的 pred_value = boostrap( t ) 的 v + pred_value = torch.tensor([s["pred_value"] for s in real_samples], dtype=torch.float32) + advantage = target_value - pred_value + + if self.args.advantage_type == "advantage": + advantage = advantage + log_status_tmp["value_advantage"] = advantage.tolist() + if fmt_rewards is not None: + advantage = (1 - fmt_weight) * advantage + fmt_weight * fmt_rewards + log_status_tmp["final_advantage"] = advantage.tolist() + + + elif self.args.advantage_type 
== "advantage_batch_norm": + # Legacy implementation: batch normalization (not recommended) + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + log_status_tmp["value_advantage"] = advantage.tolist() + + if fmt_rewards is not None: + advantage = (1 - fmt_weight) * advantage + fmt_weight * fmt_rewards + log_status_tmp["final_advantage"] = advantage.tolist() + + elif self.args.advantage_type == "advantage_running_norm": + if self.value_normalizer is not None: + raw_mean = advantage.mean().item() + raw_std = advantage.std().item() + raw_min = advantage.min().item() + raw_max = advantage.max().item() + batch_size = advantage.numel() + + advantage, norm_stats = self.value_normalizer.normalize( + advantage, + clip_values=True, + return_stats=True + ) + + norm_min = advantage.min().item() + norm_max = advantage.max().item() + norm_mean = advantage.mean().item() + norm_std = advantage.std().item() + + if self.rank == 0 and self.value_normalizer.update_count % 10 == 0: + print( + f"[Value Norm] step={self.value_normalizer.update_count} | " + f"batch_size={batch_size} | " + f"running: mean={norm_stats['running_mean']:.3f}, std={norm_stats['running_std']:.3f} | " + f"batch: mean={norm_stats['batch_mean']:.3f}, std={norm_stats['batch_std']:.3f} | " + f"raw: min={raw_min:.3f}, max={raw_max:.3f} | " + f"norm: min={norm_min:.3f}, max={norm_max:.3f} | " + f"clipped={norm_stats['clipped_count']}/{norm_stats['total_count']} | " + f"momentum={norm_stats['momentum']:.3f}" + ) + else: + batch_mean = advantage.mean().item() + batch_std = advantage.std().item() + batch_min = advantage.min().item() + batch_max = advantage.max().item() + batch_size = advantage.numel() + + if self.value_count == 0: + self.value_running_mean = batch_mean + self.value_running_std = max(batch_std, 1e-8) # Avoid zero std + else: + self.value_running_mean = ( + self.running_momentum * self.value_running_mean + + (1 - self.running_momentum) * batch_mean + ) + self.value_running_std = ( + 
self.running_momentum * self.value_running_std + + (1 - self.running_momentum) * max(batch_std, 1e-8) + ) + + self.value_count += 1 + advantage = (advantage - self.value_running_mean) / (self.value_running_std + 1e-8) + + norm_min = advantage.min().item() + norm_max = advantage.max().item() + norm_mean = advantage.mean().item() + norm_std = advantage.std().item() + + if self.rank == 0 and self.value_count % 10 == 0: + print( + f"[Advantage Running Norm] step={self.value_count} | " + f"batch_size={batch_size} | " + f"running: mean={self.value_running_mean:.3f}, std={self.value_running_std:.3f} | " + f"batch: mean={batch_mean:.3f}, std={batch_std:.3f} | " + f"raw: min={batch_min:.3f}, max={batch_max:.3f} | " + f"norm: min={norm_min:.3f}, max={norm_max:.3f}" + ) + + + log_status_tmp["value_advantage"] = advantage.tolist() + if fmt_rewards is not None: + advantage = (1 - fmt_weight) * advantage + fmt_weight * fmt_rewards + log_status_tmp["final_advantage"] = advantage.tolist() + else: + raise ValueError(f"Unknown advantage_type: {self.args.advantage_type}") + + log_status = [ + {k: log_status_tmp[k][i] for k in log_status_tmp.keys()} for i in range(len(log_status_tmp['value_advantage'])) + ] + + old_seq_max_len = max([len(s['old_logprob']) for s in real_samples]) + old_logprob = torch.zeros(len(real_samples), old_seq_max_len, dtype=torch.float32) + for idx in range(len(real_samples)): + logprob_token_list = real_samples[idx]['old_logprob'] + old_logprob[idx, -len(logprob_token_list):] = torch.tensor(logprob_token_list, dtype=torch.float32) + + return inputs.input_ids, inputs.attention_mask, action_mask, advantage, old_logprob, log_status + + @torch.no_grad() + def _build_cot_prefix_texts(self, all_user_prompts: List[str]) -> List[str]: + """ + 生成CoT推理前缀。 + 优化: 使用较短的max_tokens(128)和stop条件以减少不必要的生成。 + 从最后一次出现的 "Action:" 截断出 prefix(包含 Action: 和其后的空格位置)。 + 返回 prefix_cot_list,与 all_user_prompts 等长。 + """ + cot_sampling_params = SamplingParams( + temperature=1.0, + 
top_p=1.0, + max_tokens=self.generate_max_len, + stop=["\n\n"], + # stop=["Action:", "\n\n"] + include_stop_str_in_output=True, + logprobs=None, + prompt_logprobs=None, + ) + + all_context_texts = [self.build_chat_context(p) for p in all_user_prompts] + context_token_ids = self.tokenizer( + all_context_texts, + add_special_tokens=False, + max_length=self.prompt_max_len, + padding=False, + truncation=True, + )["input_ids"] + + self.vllm_engine.add_requests(sampling_params=cot_sampling_params, prompt_token_ids=context_token_ids) + cot_outputs = self.vllm_engine.get_responses() + + prefix_cot_list, full_output = [], [] + reasoning_pattern = re.compile(r"Reasoning\s*:", re.IGNORECASE) + action_pattern = re.compile(r"Action\s*:", re.IGNORECASE) + + for output in cot_outputs: + gen_text = output.outputs[0].text + full_output.append(gen_text) + # TODO 这里是否要清洗数据?清洗过后,计算prior先验的时候比较正常,但是format_reward几乎没用 + # if not reasoning_pattern.search(gen_text): + # prefix_cot_list.append("Action:") + # continue + action_match = action_pattern.search(gen_text) + if action_match: + end_index = action_match.end() + prefix_piece = gen_text[:end_index].strip() + prefix_cot_list.append(prefix_piece) + continue + # else: + # prefix_piece = gen_text.strip() + "\nAction:" + # prefix_cot_list.append(prefix_piece) + prefix_cot_list.append(gen_text.strip()) + + return prefix_cot_list, full_output + + @torch.no_grad() + def get_llm_prior( + self, + states: List[str], + valid_actions_list: List[List[str]], + histories: Optional[List[List[Tuple[str, str, float]]]] = None, + return_cot: bool = False, # CoT reuse optimization: return CoT prefixes + ) -> List[Any]: + """ + Get LLM prior scores for actions. 
+ + Args: + states: List of current state observations + valid_actions_list: List of valid actions for each state + histories: List of history observations + return_cot: If True, return CoT prefixes for reuse (optimization) + + Returns: + If return_cot=False: (llm_prior_per_seq, llm_prior_per_tok) + If return_cot=True: (llm_prior_per_seq, llm_prior_per_tok, prefix_cots) + """ + prompt_list = [] + assert len(states) == len(histories) == len(valid_actions_list) + for state, history in zip(states, histories): + prompt = self.get_user_prompt(current_obs=state, history=history) + prompt_list.append(prompt) + + if self.use_cot: + prefix_cots, full_output = self._build_cot_prefix_texts(prompt_list) + else: + prefix_cots = [None] * len(prompt_list) + full_output = None + + all_prompts = [] + all_labels = [] + all_prefix_cots = [] + all_env_indices = [] + + for env_idx, (prompt, actions, prefix) in enumerate(zip(prompt_list, valid_actions_list, prefix_cots)): + actions2 = actions if "go" in actions else (actions + ["go"]) # 确保环境使用的动作都在valid actions里有对应的logprob + for action in actions2: + all_prompts.append(prompt) + all_labels.append(action) + all_prefix_cots.append(prefix) + all_env_indices.append(env_idx) + assert len(all_prompts) == len(all_labels) == len(all_prefix_cots) == len(all_env_indices) + + scores, old_action_logprob, full_ids, label_ids = self._score_labels_with_prompt_logprobs(all_prompts, all_labels, all_prefix_cots) + assert len(all_prompts) == len(scores) == len(old_action_logprob) == len(full_ids) == len(label_ids) + + llm_prior_per_seq, llm_prior_per_tok = [],[], + cur_env_idx = 0 + seq_dict = {} + tok_dict = {'old_action_logprob': {}, 'full_ids': {}, 'label_ids': {}} + + for idx, (env_idx, prompt, label, prefix_cot) in enumerate(zip(all_env_indices, all_prompts, all_labels, all_prefix_cots)): + if env_idx != cur_env_idx: + llm_prior_per_seq.append(seq_dict) + llm_prior_per_tok.append(tok_dict) + seq_dict = {} + tok_dict = {'old_action_logprob': {}, 
'full_ids': {}, 'label_ids': {}} + cur_env_idx = env_idx + + seq_dict[label] = scores[idx] + tok_dict['old_action_logprob'][label] = old_action_logprob[idx] + tok_dict['full_ids'][label] = full_ids[idx] + tok_dict['label_ids'][label] = label_ids[idx] + tok_dict['prompt'] = prompt + tok_dict['prefix_cot'] = prefix_cot + tok_dict['current_obs'] = states[env_idx] + tok_dict['history'] = histories[env_idx] + + if len(seq_dict) > 0: + llm_prior_per_seq.append(seq_dict) + llm_prior_per_tok.append(tok_dict) + + if self.use_cot: + self.episode_output.append({ + "Instruction": prompt_list[0], + "Response": full_output[0], + "llm_prior_per_seq": llm_prior_per_seq[0] + }) + # CoT reuse optimization: return CoT prefixes if requested + if return_cot: + return llm_prior_per_seq, llm_prior_per_tok, prefix_cots + else: + return llm_prior_per_seq, llm_prior_per_tok + + @torch.no_grad() + def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels: List[str], all_prefix_cots: List[str]) -> List[float]: + assert len(all_prompts) == len(all_labels) == len(all_prefix_cots) + sampling_params = SamplingParams( + temperature=self.temperature, + top_p=self.top_p, + max_tokens=1, + include_stop_str_in_output=True, + logprobs=None, + prompt_logprobs=1, + ) + + all_context_texts = [self.build_chat_context(p) for p in all_prompts] + context_ids = self.tokenizer(all_context_texts, add_special_tokens=False, max_length=self.prompt_max_len - self.generate_max_len - 20, padding=False, truncation=True)["input_ids"] + + if self.use_cot: + label_texts = [pc + " " + l + self.tokenizer.eos_token for pc, l in zip(all_prefix_cots, all_labels)] + label_texts_no_cots = [" " + l + self.tokenizer.eos_token for l in all_labels] + else: + label_texts = [l + self.tokenizer.eos_token for l in all_labels] + label_texts_no_cots = label_texts + + label_ids = self.tokenizer(label_texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"] + label_ids_no_cots = 
self.tokenizer(label_texts_no_cots, add_special_tokens=False, padding=False, truncation=False)["input_ids"] + + for idx, (l_ids, l_ids_not_cot) in enumerate(zip(label_ids, label_ids_no_cots)): + len_not_cot = len(l_ids_not_cot) + if l_ids[-len_not_cot:] != l_ids_not_cot: + raise ValueError(f"Label IDs mismatch: with CoT {l_ids[-len_not_cot:]}, without CoT {l_ids_not_cot}, label_text: {label_texts[idx]}") + + full_ids = [c + l for c, l in zip(context_ids, label_ids)] + p_lens = [len(x) for x in context_ids] + l_lens = [len(x) for x in label_ids] + l_no_cots_lens = [len(x) for x in label_ids_no_cots] + + self.vllm_engine.add_requests(sampling_params=sampling_params, prompt_token_ids=full_ids) + outs = self.vllm_engine.get_responses() + + scores = [] + old_action_logprob = [] + nan_found = False + for i, (out, ids, p_len, l_len, l_no_cots_len) in enumerate(zip(outs, full_ids, p_lens, l_lens, l_no_cots_lens)): + prompt_logprobs = getattr(out, "prompt_logprobs", None) + token_lps = [] + + for j in range(1, len(ids)): + tok_id = ids[j] + lp_dict = prompt_logprobs[j] + + assert tok_id in lp_dict + token_lps.append(lp_dict[tok_id].logprob) + + if not token_lps: + scores.append(float("-inf")) + old_action_logprob.append([]) + else: + assert l_no_cots_len <= l_len + if self.llm_prior_with_cot: + target_lps = token_lps[-l_len:] + else: + target_lps = token_lps[-l_no_cots_len:] + denom = len(target_lps) + + score = sum(target_lps) if self.reduction == "sum" else sum(target_lps) / denom + scores.append(score) + + if (not nan_found) and math.isnan(score): + vllm_returned_nan = any(math.isnan(x) for x in target_lps) + token_level_debug = [] + for t_id, t_lp in zip(ids[1:], token_lps): + token_level_debug.append(f"TokenID: {t_id} -> LogProb: {t_lp} {'(NaN HERE!)' if math.isnan(t_lp) else ''}") + + nan_found = True + nan_debug_dump = ( + f"\n{'='*20} [NaN DEBUG REPORT] {'='*20}\n" + f"Sample Index (i): {i}\n" + f"Reason: {'vLLM returned NaN logprob' if vllm_returned_nan else 'Math 
error during sum/div'}\n\n" + f"--- Text Info ---\n" + f"Prompt: ...{repr(all_prompts[i])}\n" + f"Label Action: {repr(all_labels[i])}\n" + f"Prefix CoT: {repr(all_prefix_cots[i])}\n\n" + f"--- Numerical Info (Copy this to reproduce) ---\n" + f"Full Input Token IDs (full_ids[{i}]): {ids}\n" + f"Context Length (p_len): {p_len}\n" + f"Label Length (l_len): {l_len}\n" + f"Target Length (l_no_cots_len): {l_no_cots_len}\n\n" + f"--- Critical Calculation Data ---\n" + f"Head 10 Token IDs: {ids[1:11]}\n" + f"LogProbs List: {token_lps[:10]}\n" + f"Detailed Mapping:\n" + "\n".join(token_level_debug[:10]) + "\n\n" + + f"Tail Token IDs: {ids[-l_len - 10: -l_len]}\n" + f"LogProbs List: {token_lps[-l_len - 10: -l_len]}\n" + f"Detailed Mapping:\n" + "\n".join(token_level_debug[-l_len - 10: -l_len]) + "\n\n" + + f"Target Token IDs: {ids[-l_no_cots_len:]}\n" + f"LogProbs List: {target_lps}\n" + f"Detailed Mapping:\n" + "\n".join(token_level_debug[-l_no_cots_len:]) + "\n" + f"{'='*60}\n" + ) + old_action_logprob.append(token_lps[-l_len:]) + + if self.rank == 0: + if nan_found: + self._logger.info(nan_debug_dump) + + return scores, old_action_logprob, full_ids, label_ids + + @torch.no_grad() + def get_llm_output_log(self, wm_train_iter: int = 0, llm_train_iter: int = 0): + if self.rank != 0: + return + + self._logger.info( + f"\n{'='*80}\n" + f"[LLM Output Log] WM Iter: {wm_train_iter} | LLM Iter: {llm_train_iter}\n" + f"{'='*80}" + ) + + for i, tmp_dict in enumerate(self.episode_output[:15]): + instruction = tmp_dict["Instruction"] + response = tmp_dict["Response"] + llm_prior = tmp_dict["llm_prior_per_seq"] + + self._logger.info( + f"\n{'-'*80}\n" + f"[Step {i}]\n" + f"{'-'*80}\n" + f"Instruction:\n{instruction}\n\n" + f"Response:\n{response}\n\n" + f"Action Probabilities:" + ) + + action_probs = {a: math.exp(float(lp)) for a, lp in llm_prior.items() if lp is not None and math.isfinite(float(lp))} + all_prob = sum(action_probs.values()) + + for action, prob in 
sorted(action_probs.items(), key=lambda x: x[1], reverse=True): + self._logger.info(f" {action:30s} | unnorm={prob:.6f} | norm={(prob / all_prob):.6f}") + self._logger.info(f" {'':30s} | unnorm={1-all_prob:.6f}") + self.episode_output = [] + + + \ No newline at end of file diff --git a/zoo/jericho/priorzero/priorzero_entry.py b/zoo/jericho/priorzero/priorzero_entry.py deleted file mode 100644 index 65337f6f7..000000000 --- a/zoo/jericho/priorzero/priorzero_entry.py +++ /dev/null @@ -1,581 +0,0 @@ -# priorzero_entry.py -""" -[PRIORZERO] Main Training Entry Point - -This module provides the main async training loop for PriorZero. - -Key Features: -- Async training with vLLM integration -- Checkpoint management and recovery -- Comprehensive logging (TensorBoard + file logs) -- Graceful error handling - -Author: PriorZero Team -Date: 2025-01-20 -""" - -import asyncio -import os -import sys -from functools import partial -from pathlib import Path -from typing import Tuple, Optional -# from lzero.entry.utils import log_buffer_memory_usage -# from lzero.policy import visit_count_temperature -# from ding.rl_utils import get_epsilon_greedy_fn - -# ============================================================================== -# [CRITICAL] Ensure local LightZero is used for PriorZero-specific adaptations -# ============================================================================== -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - - -import ray -import torch -import wandb -from ding.config import compile_config -from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.worker import create_buffer, BaseLearner -from tensorboardX import SummaryWriter -from loguru import logger -from vllm import AsyncLLMEngine -from vllm.engine.arg_utils import AsyncEngineArgs - -# Import PriorZero components -from priorzero_config import 
get_priorzero_config, get_priorzero_config_for_quick_test -from priorzero_collector import PriorZeroCollector -from priorzero_evaluator import PriorZeroEvaluator -# Import policy to ensure registration happens -import priorzero_policy # noqa: F401 - - -async def train_priorzero( - cfg: dict, - create_cfg: dict, - seed: int = 0, - max_train_iter: int = int(1e6), - max_env_step: Optional[int] = int(1e10), - enable_save: bool = True, -): - """ - [PRIORZERO-MODIFIED] - Main async training function for PriorZero. - - Args: - cfg: Main configuration dictionary - create_cfg: Creation configuration for DI-engine components - seed: Random seed - max_train_iter: Maximum training iterations - enable_save: Whether to save checkpoints - """ - # ================================================================== - # 1. Compile Configuration - # ================================================================== - cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) - - # ================================================================== - # 2. Initialize Ray (for distributed vLLM) - # ================================================================== - # Note: vLLM will initialize Ray internally if needed. - # We skip manual Ray initialization to avoid conflicts with existing clusters. - if ray.is_initialized(): - logger.info(f"✓ Ray already initialized (connected to existing cluster)") - else: - logger.info(f"✓ Ray not initialized - vLLM will handle initialization if needed") - - # ================================================================== - # 3. 
Create vLLM Engine - # ================================================================== - logger.info("Creating vLLM engine...") - - # [ROBUST FIX] Handle shared GPU environment - # Issue: vLLM V1 engine fails when other processes release GPU memory during init - # Solution: Use alternative initialization method that bypasses V1 checks - import os - - # Note: In vLLM>=0.3.0, worker_use_ray is replaced by distributed_executor_backend - # For single GPU: use "mp" (multiprocessing) - # For multi-GPU: use "ray" if available - tensor_parallel = cfg.policy.llm_policy_cfg.vllm_tensor_parallel_size - distributed_backend = "ray" if tensor_parallel > 1 and ray.is_initialized() else None - - # [ROBUST FIX] Lower GPU memory utilization in shared environment - # This leaves more headroom for memory fluctuations - gpu_mem_util = cfg.policy.llm_policy_cfg.gpu_memory_utilization - if gpu_mem_util > 0.85: - gpu_mem_util = 0.75 # More conservative in shared environment - logger.info(f"✓ Adjusted GPU memory utilization to {gpu_mem_util} for stability") - - # [ROBUST FIX] Use alternative initialization to avoid V1 engine issues - # Set env var BEFORE importing to ensure it takes effect - use_v1_env = os.environ.get('VLLM_USE_V1', None) - if use_v1_env is None: - # Only set if not already set by user - os.environ['VLLM_USE_V1'] = '0' - logger.info("✓ Using vLLM V0 engine for stability in shared GPU environment") - - try: - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util, - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - # [ROBUST FIX] Disable prefix caching in shared environment to reduce memory complexity - enable_prefix_caching=False, - # [ROBUST FIX] Disable enforce_eager to avoid memory profiling issues - enforce_eager=False, - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created (backend: 
{distributed_backend or 'default'})") - except (ValueError, RuntimeError) as e: - if "VLLM_USE_V1" in str(e) or "memory profiling" in str(e): - # Fallback: Try without V1 env var - logger.warning(f"⚠️ Initial vLLM initialization failed: {e}") - logger.info("Retrying with alternative configuration...") - if 'VLLM_USE_V1' in os.environ: - del os.environ['VLLM_USE_V1'] - - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util * 0.9, # Even more conservative - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=True, # Force eager mode as fallback - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created with fallback configuration") - else: - raise - - # ================================================================== - # 4. Create Environments - # ================================================================== - logger.info("Creating environments...") - logger.info(f"[DEBUG] Config values: collector_env_num={cfg.env.collector_env_num}, " - f"evaluator_env_num={cfg.env.evaluator_env_num}, " - f"n_evaluator_episode={cfg.env.n_evaluator_episode}") - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - logger.info(f"[DEBUG] get_vec_env_setting returned: " - f"collector envs={len(collector_env_cfg)}, " - f"evaluator envs={len(evaluator_env_cfg)}") - collector_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in collector_env_cfg] - ) - evaluator_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in evaluator_env_cfg] - ) - - # Seed environments - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=True) - logger.info(f"✓ Environments created and seeded (seed={seed})") - logger.info(f"[DEBUG] Actual env counts: 
collector={collector_env.env_num}, " - f"evaluator={evaluator_env.env_num}") - - # ================================================================== - # 5. Create Policy, Buffer, and Components - # ================================================================== - logger.info("Creating policy, buffer, and components...") - - # Create policy (align with UniZero) - policy = create_policy( - cfg.policy, - enable_field=['learn', 'collect', 'eval'] - ) - logger.info("✓ Policy created") - - # Create TensorBoard logger (align with UniZero) - os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) - tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None - logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") - - # Create learner (align with UniZero - this sets up policy._logger) - learner = BaseLearner( - cfg.policy.learn.learner, - policy.learn_mode, - tb_logger, - exp_name=cfg.exp_name - ) - logger.info("✓ BaseLearner created") - - # [PRIORZERO-MODIFIED] Create PriorZero-specific replay buffer - # This buffer returns game_segments for LLM training (SFT/RFT) - from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized - replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) - logger.info("✓ PriorZero replay buffer created (with game_segments support)") - - # Create collector - collector = PriorZeroCollector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, - policy_config=cfg.policy, - debug_mode=cfg.get('debug_mode', False), - ) - logger.info("✓ Collector created") - - # Create evaluator - evaluator = PriorZeroEvaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, - policy_config=cfg.policy, - ) - logger.info("✓ Evaluator 
created") - - # Initialize WandB if enabled (PriorZero enhancement) - if cfg.policy.get('use_wandb', True): - if get_rank() == 0: - wandb.init( - project=cfg.policy.get('wandb_project', 'priorzero'), - name=cfg.exp_name, - config=cfg, - tags=['priorzero', 'unizero', 'llm-policy'], - ) - logger.info("✓ WandB initialized") - # Set train iter and env step for policy wandb logging - policy.set_train_iter_env_step(learner.train_iter, collector.envstep) - - # Call learner's before_run hook (align with UniZero) - learner.call_hook('before_run') - - # ================================================================== - # 6. Initialize Async Training Coordinator - # ================================================================== - from async_training_coordinator import AsyncTrainingCoordinator - - coordinator = AsyncTrainingCoordinator( - off_policy_degree=cfg.policy.off_policy_degree, - enable_async_eval=cfg.policy.enable_async_eval, - buffer_size=cfg.policy.replay_buffer_size, - batch_size=cfg.policy.batch_size, - ) - - # ================================================================== - # 7. 
Main Training Loop - # ================================================================== - logger.info("="*80) - logger.info("Starting PriorZero Training") - logger.info("="*80) - logger.info(f"Experiment: {cfg.exp_name}") - logger.info(f"Max iterations: {max_train_iter}") - logger.info(f"Batch size: {cfg.policy.batch_size}") - logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") - logger.info(f"World model layers: {cfg.policy.model.world_model_cfg.num_layers}") - logger.info(f"Off-policy degree: {cfg.policy.off_policy_degree} ({'SYNC' if cfg.policy.off_policy_degree == 0 else 'ASYNC'})") - logger.info(f"Async eval: {cfg.policy.enable_async_eval}") - logger.info("="*80) - - # [ALIGN WITH UNIZERO] Initialize reanalyze-related counters (train_unizero_segment.py line 119-121) - buffer_reanalyze_count = 0 - train_epoch = 0 - reanalyze_batch_size = cfg.policy.reanalyze_batch_size - batch_size = cfg.policy.batch_size - best_eval_reward = -float('inf') - policy_config = cfg.policy - - # Async control variables - collect_task = None - train_task = None - pending_new_data = None # Store collected data waiting to be added to buffer - - try: - while True: - # ================================================================== - # Determine if we're in synchronous or asynchronous mode - # ================================================================== - is_sync_mode = coordinator.is_synchronous - - # ================================================================== - # Evaluation (align with train_unizero_segment.py line 158-162) - # ================================================================== - if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): - # if learner.train_iter == 0 r evaluator.should_eval(learner.train_iter): - - logger.info(f"\n[Iter {learner.train_iter}] Evaluating...") - - # Define async eval function - async def eval_fn(): - return evaluator.eval( - save_ckpt_fn=learner.save_checkpoint if enable_save else None, 
- train_iter=learner.train_iter, - envstep=collector.envstep - ) - - # Run eval through coordinator (handles sync/async based on config) - eval_result = await coordinator.run_eval(eval_fn) - - # If sync eval, process result immediately - if not cfg.policy.enable_async_eval and eval_result is not None: - stop, eval_reward_dict = eval_result - mean_reward = eval_reward_dict.get('reward_mean', 0) - logger.info(f" ✓ Evaluation done: reward_mean={mean_reward:.2f}") - - if mean_reward > best_eval_reward: - best_eval_reward = mean_reward - - if stop: - logger.info(f" 🎉 Training converged! (reward >= {cfg.env.stop_value})") - break - else: - logger.info(f" ✓ Async evaluation started in background") - - # ================================================================== - # Collect Data (align with train_unizero_segment.py line 165) - # ================================================================== - collect_kwargs = { - 'temperature': 0.25, - 'epsilon': 0.0 - } - - if is_sync_mode: - # ============================================================ - # SYNCHRONOUS MODE: Original serial execution - # ============================================================ - logger.info(f"\n[Iter {learner.train_iter}] Collecting data...") - - new_data = await collector.collect( - train_iter=learner.train_iter, - policy_kwargs=collect_kwargs - ) - - # Update replay buffer - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=1) - - replay_buffer.push_game_segments(new_data) - replay_buffer.remove_oldest_data_to_fit() - buffer_size = replay_buffer.get_num_of_transitions() if hasattr(replay_buffer, 'get_num_of_transitions') else 0 - logger.info(f" ✓ Data collected, buffer size: {buffer_size} transitions") - - else: - # ============================================================ - # ASYNCHRONOUS MODE: Collect can overlap with train - # ============================================================ - # 
Start or check collect task - if collect_task is None or collect_task.done(): - if coordinator.can_collect(): - logger.info(f"\n[Iter {learner.train_iter}] Starting async collect...") - - # Define async collect function - async def collect_fn(): - return await collector.collect( - train_iter=learner.train_iter, - policy_kwargs=collect_kwargs - ) - - # Start collect task through coordinator - collect_task = asyncio.create_task(coordinator.run_collect(collect_fn)) - else: - logger.debug(f"Collect blocked (lag={coordinator.collect_train_lag}/{coordinator.off_policy_degree})") - - # Check if collect completed - if collect_task is not None and collect_task.done(): - new_data = await collect_task - collect_task = None - - # Store for buffer update - pending_new_data = new_data - logger.info(f" ✓ Async collect completed, data pending buffer update") - - # Update buffer if we have pending data - if pending_new_data is not None: - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, pending_new_data, world_size=1) - - replay_buffer.push_game_segments(pending_new_data) - replay_buffer.remove_oldest_data_to_fit() - buffer_size = replay_buffer.get_num_of_transitions() if hasattr(replay_buffer, 'get_num_of_transitions') else 0 - logger.info(f" ✓ Buffer updated, size: {buffer_size} transitions") - - pending_new_data = None - else: - # No new data yet, use previous update_per_collect or default - update_per_collect = cfg.policy.get('update_per_collect', 10) - - # ============================================================ - # Periodically reanalyze buffer (align with train_unizero_segment.py line 175-186) - # ============================================================ - if cfg.policy.buffer_reanalyze_freq >= 1: - # Reanalyze buffer times in one train_epoch - reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq - else: - # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch - if 
train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): - logger.info(f"[Reanalyze] Starting buffer reanalysis...") - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) - buffer_reanalyze_count += 1 - logger.info(f" ✓ Buffer reanalyze count: {buffer_reanalyze_count}") - - # ============================================================ - # Training (align with train_unizero_segment.py line 189-221) - # ============================================================ - if collector.envstep > cfg.policy.train_start_after_envsteps: - # Check if there is sufficient data for training - if cfg.policy.sample_type == 'episode': - data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size - else: - data_sufficient = replay_buffer.get_num_of_transitions() > batch_size - - if not data_sufficient: - logger.warning( - f' ⚠ Data in replay_buffer is not sufficient: ' - f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' 
- ) - continue - - logger.info(f"[Iter {learner.train_iter}] Training...") - - # Define training function - async def train_one_batch(): - # Reanalyze buffer during training (align with train_unizero_segment.py line 202-210) - # Note: This is simplified - full reanalyze logic should be per-batch - - # Sample batch - train_data = replay_buffer.sample(batch_size, policy) - train_data.insert(2, learner.train_iter) - - # Train - log_vars = learner.train(train_data, collector.envstep) - - # Update priority if enabled - if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) - - return log_vars - - if is_sync_mode: - # Synchronous: train all batches sequentially - for i in range(update_per_collect): - await train_one_batch() - else: - # Asynchronous: train batches while allowing collect to proceed - # We still train sequentially per batch, but collect can run in parallel - if coordinator.can_train(): - # Train one batch through coordinator - await coordinator.run_train(train_one_batch) - else: - logger.debug(f"Train waiting for collect...") - - # Increment epoch counter (align with train_unizero_segment.py line 222) - train_epoch += 1 - - # [FIX] Clear KV cache BEFORE collection to prevent index overflow during MCTS - policy.recompute_pos_emb_diff_and_clear_cache() - - # ============================================================ - # Check stopping criteria (align with train_unizero_segment.py line 226-227) - # ============================================================ - if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: - logger.info("Stopping condition met, training ends!") - break - - # In async mode, yield to event loop - if not is_sync_mode: - await asyncio.sleep(0.001) - - except KeyboardInterrupt: - logger.warning("\n⚠ Training interrupted by user (Ctrl+C)") - - except Exception as e: - logger.error(f"\n✗ Training error: {e}") - import traceback - traceback.print_exc() - - finally: - 
# ============================================================ - # Cleanup (align with train_unizero_segment.py line 229) - # ============================================================ - learner.call_hook('after_run') - - # Wait for any pending async eval - if cfg.policy.enable_async_eval: - logger.info("Waiting for async eval to complete...") - await coordinator.wait_for_eval() - - # Print async training statistics - async_stats = coordinator.get_statistics() - logger.info("\n" + "="*80) - logger.info("Async Training Statistics:") - logger.info(f" Mode: {async_stats['mode'].upper()}") - logger.info(f" Collect iterations: {async_stats['collect_count']}") - logger.info(f" Train iterations: {async_stats['train_count']}") - logger.info(f" Final lag: {async_stats['collect_train_lag']}") - if 'collect_avg_time' in async_stats: - logger.info(f" Avg collect time: {async_stats['collect_avg_time']:.2f}s") - if 'train_avg_time' in async_stats: - logger.info(f" Avg train time: {async_stats['train_avg_time']:.2f}s") - if 'eval_avg_time' in async_stats: - logger.info(f" Avg eval time: {async_stats['eval_avg_time']:.2f}s") - logger.info("="*80) - - logger.info("\nCleaning up...") - collector_env.close() - evaluator_env.close() - tb_logger.close() - - logger.info("="*80) - logger.info("Training Complete!") - logger.info(f"Total iterations: {learner.train_iter}") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info("="*80) - - return policy - - -def main(): - """ - Main entry point with argument parsing. 
- """ - import argparse - - parser = argparse.ArgumentParser(description='PriorZero Training') - parser.add_argument('--env_id', type=str, default='zork1.z5', help='Jericho game ID') - parser.add_argument('--seed', type=int, default=0, help='Random seed') - parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') - parser.add_argument('--quick_test', action='store_true', help='Use quick test config') - parser.add_argument('--no_save', action='store_true', help='Disable checkpoint saving') - parser.add_argument('--debug', action='store_true', help='Enable detailed debug logging (obs, action, LLM output)') - - args = parser.parse_args() - - # args.quick_test = True # ONLY FOR DEBUG - - # Get configuration - if args.quick_test: - logger.info("Using quick test configuration") - main_cfg, create_cfg = get_priorzero_config_for_quick_test(args.env_id, args.seed, debug_mode=args.debug) - else: - main_cfg, create_cfg = get_priorzero_config(args.env_id, args.seed, debug_mode=args.debug) - - # Run training - asyncio.run(train_priorzero( - main_cfg, - create_cfg, - seed=args.seed, - max_train_iter=args.max_iter, - enable_save=not args.no_save - )) - - -if __name__ == "__main__": - import os - # Disable tokenizer parallelism to prevent multi-process conflicts - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - main() diff --git a/zoo/jericho/priorzero/priorzero_entry_sync.py b/zoo/jericho/priorzero/priorzero_entry_sync.py new file mode 100644 index 000000000..77a2c155e --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_entry_sync.py @@ -0,0 +1,356 @@ +import sys +import os +from pathlib import Path + +# ============================================================================== +# 假设当前脚本在 .../zoo/jericho/priorzero/ 目录下 +current_file_path = Path(__file__).resolve() +# 回退 4 层找到 LightZero 根目录 (priorzero -> jericho -> zoo -> LightZero) +project_root = current_file_path.parents[3] + +if str(project_root) not in sys.path: + print(f"[SYSTEM] 
Inserting project root to sys.path: {project_root}") + sys.path.insert(0, str(project_root)) +# ============================================================================== + + +import asyncio +import os +import sys +from functools import partial +from pathlib import Path +from typing import Tuple, Optional + +import torch +import torch.distributed as dist +import wandb + +from ding.config import compile_config, save_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import create_buffer, BaseLearner +from tensorboardX import SummaryWriter +from loguru import logger +import deepspeed + +from priorzero_config import ( + get_priorzero_config, + get_priorzero_debug_config, + get_available_models, +) +from priorzero_collector import PriorZeroCollector +from priorzero_evaluator import PriorZeroEvaluator +from priorzero_policy import * +from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized +from utils import dump_dataclass_cfg_py + +from lzero.entry.utils import calculate_update_per_collect + +def prepare_unizero(rank, cfg, create_cfg, llm_cfg, seed): + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + + policy = create_policy( cfg.policy, enable_field=['learn', 'collect', 'eval'], exp_name=cfg.exp_name) + logger.info(f"[Rank {rank}] Policy created") + + os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + 
logger.info(f"[Rank {rank}] TensorBoard logger: ./{cfg.exp_name}/log/") + + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + logger.info(f"[Rank {rank}] BaseLearner created") + + + replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) + logger.info(f"[Rank {rank}] PriorZero replay buffer created (with game_segments support)") + + # Create collector + collector = PriorZeroCollector( + env=collector_env, + policy=policy.collect_mode, + llm_config=llm_cfg, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy, + ) + logger.info(f"[Rank {rank}] Collector created") + + # Create evaluator + evaluator = PriorZeroEvaluator( + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy, + llm_config=llm_cfg, + ) + logger.info(f"[Rank {rank}] Evaluator created") + learner.call_hook('before_run') + + return cfg, replay_buffer, tb_logger, policy, collector, evaluator, learner + +def bcast_obj(world_size, obj, rank, src=0): + if world_size <= 1: + return obj + lst = [obj] if rank == src else [None] + dist.broadcast_object_list(lst, src=src) + return lst[0] + +def train_priorzero( + cfg: dict, + create_cfg: dict, + llm_cfg, + seed: int = 0, + max_train_iter: int = int(1e6), + max_env_step: Optional[int] = int(1e10), + enable_profile: bool = False +): + rank = int(os.environ.get("RANK", "0")) + print(f"rank={rank}") + if rank == 0: + cfg, replay_buffer, tb_logger, policy, collector, evaluator, learner = prepare_unizero( + rank=rank, + cfg=cfg, + create_cfg=create_cfg, + llm_cfg=llm_cfg, + seed=seed) + batch_size = cfg.policy.batch_size + logger.info(f"[Rank {rank}] World Model components initialized") + dump_dataclass_cfg_py(llm_cfg, path=f"{cfg.exp_name}/llm_cfg.py") + llm_cfg.save_path = f'./{cfg.exp_name}/llm_ckpt/' + + from utils import Profiler 
+ prof = Profiler(log_interval=10, stats_file=f'./{cfg.exp_name}/log/profiler.txt', enable_profile=enable_profile) + + from strategy.deepspeed import get_strategy, torch_dist_barrier_and_cuda_sync + strategy = get_strategy(llm_cfg) + strategy.print(llm_cfg) + + strategy.setup_distributed() # torchrun 下:绑定 local_rank + init_distributed + world_size = getattr(strategy, "world_size", 1) + + logger.info(f"[Rank {rank}] Initializing LLM Actor...") + set_pkg_seed(seed + rank, use_cuda=True) + + from models.actor import PolicyModel, ReferenceModel + if llm_cfg.rft_kl_coef > 0: + ref_model = ReferenceModel( + strategy=strategy, + pretrain=llm_cfg.model_name_or_path + ) + else: + ref_model = None + + from vllm_utils.vllm_engine import create_vllm_engine + vllm_engine = create_vllm_engine( + tensor_parallel_size=llm_cfg.vllm_tensor_parallel_size, + pretrain=llm_cfg.model_name_or_path, + enable_prefix_caching=llm_cfg.enable_prefix_caching, + max_model_len=llm_cfg.prompt_max_len + llm_cfg.generate_max_len, + gpu_memory_utilization=llm_cfg.gpu_memory_utilization, + vllm_enable_sleep=llm_cfg.vllm_enable_sleep, + ) + + print(f'[Rank {rank}] Vllm engine successfully created!') + + from priorzero_datafactory import DataProcessor + data_processor = DataProcessor(rank=rank, + world_size=world_size, + vllm_engine=vllm_engine, + strategy=strategy, + model_path=llm_cfg.model_name_or_path, + exp_name=cfg.exp_name if rank == 0 else None, + ) + if rank == 0: + collector.data_processor = data_processor + collector.prof = prof + evaluator.data_processor = data_processor + + policy_model = PolicyModel( + strategy=strategy, + pretrain=llm_cfg.model_name_or_path, + vllm_engine=vllm_engine, + max_steps=llm_cfg.max_steps + ) + from priorzero_trainer import PriorZeroLLMTrainer + trainer = PriorZeroLLMTrainer( + cfg=llm_cfg, + pretrain=llm_cfg.model_name_or_path, + strategy= strategy, + vllm_engine = vllm_engine, + policy_model=policy_model, + reference_model=ref_model, + exp_name=cfg.exp_name if 
rank == 0 else None, + tb_logger=tb_logger if rank == 0 else None, + llm_save_freq=llm_cfg.llm_save_freq + ) + + torch_dist_barrier_and_cuda_sync() + + while True: + cmd = "noop" + priorzero_batch = None + if rank == 0: + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + logger.info(f"\n[Rank {rank}: Iter {learner.train_iter}] Evaluating...") + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.wake_up() + evaluator.eval(train_iter=learner.train_iter, envstep=collector.envstep) + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.sleep() + + if cmd != "stop": + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.wake_up() + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'temperature': 0.25, 'epsilon': 0.0}) + data_processor.get_llm_output_log(wm_train_iter=learner.train_iter, llm_train_iter=policy_model.train_iter) + + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.sleep() + + update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=1) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + num_of_transitions = replay_buffer.get_num_of_transitions() + new_num_of_transitions = replay_buffer.get_num_of_transitions() - replay_buffer.last_pos_in_transition + logger.info(f"[Rank {rank}] Data collected, num_of_transitions: {num_of_transitions} transitions\tnew_num_of_transitions: {new_num_of_transitions}") + + if not (num_of_transitions > batch_size): + logger.warning( + f' ⚠ Data in replay_buffer is not sufficient: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' 
+ ) + cmd = "noop" + cmd = bcast_obj(world_size, cmd, rank, src=0) + continue + + logger.info(f"[Rank {rank}: World Model] [Iter {learner.train_iter}] Training for {update_per_collect} updates......") + + for i in range(update_per_collect): + with prof.block("train_world_model", rank=0): + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(learner.train_iter) + + log_vars = learner.train(train_data, collector.envstep) + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + policy.recompute_pos_emb_diff_and_clear_cache() + + # 计算需要收集多少样本才能满足 llm 的训练 + # 一次参数更新是train_batch_size,off次数为broadcast_every,1是因为只有一个rank收集数据 + # 此外, 需要的 transitions是样本数 / unroll_steps,即轨迹数 + llm_need_sample_cnt = llm_cfg.train_batch_size * llm_cfg.broadcast_every // 1 + llm_need_transition_cnt = (llm_need_sample_cnt + cfg.policy.num_unroll_steps - 1) // cfg.policy.num_unroll_steps + + if learner.train_iter >= llm_cfg.train_llm_after_wm_warm_step and new_num_of_transitions >= llm_need_transition_cnt: + with prof.block("fetch_latest_batch", rank=0): + print(f"[Rank 0] world_model: train_iter ={learner.train_iter} \t replay_buffer.fetch_latest_batch begin \t llm_need_transition_cnt={llm_need_transition_cnt}") + priorzero_batch = replay_buffer.fetch_latest_batch(batch_size=llm_need_transition_cnt, policy=policy) + print(f"[Rank 0] fetch_latest_batch returned: type={type(priorzero_batch)}, len={len(priorzero_batch)}") + cmd = "llm" + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + cmd = "stop" + + cmd = bcast_obj(world_size, cmd, rank, src=0) + if cmd == "stop": + break + elif cmd == "llm": + with prof.block("train_llm", rank=rank): + logger.info(f"[Rank {rank}] Waiting for broadcast of train_samples from Rank 0...") + priorzero_batch = bcast_obj(world_size, priorzero_batch, rank, src=0) + logger.info(f"[Rank {rank}] Received broadcast. 
train_samples count: {len(priorzero_batch[0]) if priorzero_batch and len(priorzero_batch) > 0 else 'UNKNOWN'}. Starting LLM training...") + train_samples = data_processor.make_llm_train_samples(priorzero_batch) + trainer.train_batch(train_samples, collect_env_steps=collector.envstep) + torch_dist_barrier_and_cuda_sync() + + +def main(): + """ + Main entry point with argument parsing. + """ + import argparse + + parser = argparse.ArgumentParser( + description='PriorZero Training with Auto Model Configuration', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use default model (qwen2.5-1.5b) + torchrun --nproc_per_node 2 priorzero_entry_sync.py + + # Use specific model + torchrun --nproc_per_node 2 priorzero_entry_sync.py --model qwen2.5-0.5b + torchrun --nproc_per_node 2 priorzero_entry_sync.py --model qwen2.5-7b + + # List all available models + python priorzero_entry_sync.py --list-models + + # Different environment + torchrun --nproc_per_node 2 priorzero_entry_sync.py --env_id zork1.z5 --model qwen2.5-1.5b + """ + ) + parser.add_argument('--env_id', type=str, default='detective.z5', help='Jericho game ID') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') + parser.add_argument('--quick_test', action='store_true', default=False, help='Use quick test config') + # Model selection + parser.add_argument('--model', type=str, default="qwen2.5-3b", choices=get_available_models()) + parser.add_argument('--enable_profile', action='store_true', default=False) + parser.add_argument('--use_cot', action='store_true', default=True) + args = parser.parse_args() + + model_key = args.model if args.model else "qwen2.5-1.5b" + print(f"\n{'='*80}") + print(f"PriorZero Training Configuration") + print(f"{'='*80}") + print(f"Environment: {args.env_id}") + print(f"Model: {model_key}") + print(f"Seed: {args.seed}") + print(f"Quick Test: 
{args.quick_test}") + print(f"use cot: {args.use_cot}") + print(f"enable_profile: {args.enable_profile}") + print(f"{'='*80}\n") + + if args.quick_test: + logger.info("Using quick test configuration") + main_cfg, create_cfg, llm_cfg = get_priorzero_debug_config( + args.env_id, args.seed, use_cot=args.use_cot, + exp_name=f'data_priorzero/priorzero_debug_{args.env_id}', + model_key=model_key, + ) + else: + main_cfg, create_cfg, llm_cfg = get_priorzero_config( + args.env_id, args.seed, use_cot=args.use_cot, + model_key=model_key, + ) + + train_priorzero( + main_cfg, + create_cfg, + llm_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + enable_profile=args.enable_profile, # 是否要对各个耗时部分进行 profile + ) + + +if __name__ == "__main__": + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main() diff --git a/zoo/jericho/priorzero/priorzero_entry_sync_ddp.py b/zoo/jericho/priorzero/priorzero_entry_sync_ddp.py new file mode 100644 index 000000000..c1617e9d7 --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_entry_sync_ddp.py @@ -0,0 +1,375 @@ +import sys +import os +from pathlib import Path + +# ============================================================================== +# 假设当前脚本在 .../zoo/jericho/priorzero/ 目录下 +current_file_path = Path(__file__).resolve() +# 回退 4 层找到 LightZero 根目录 (priorzero -> jericho -> zoo -> LightZero) +project_root = current_file_path.parents[3] + +if str(project_root) not in sys.path: + print(f"[SYSTEM] Inserting project root to sys.path: {project_root}") + sys.path.insert(0, str(project_root)) +# ============================================================================== + + +import asyncio +import os +import sys +from functools import partial +from pathlib import Path +from typing import Tuple, Optional + +import torch +import torch.distributed as dist +import wandb + +from ding.config import compile_config, save_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.utils 
import set_pkg_seed, get_rank, get_world_size +from ding.worker import create_buffer, BaseLearner +from tensorboardX import SummaryWriter +from loguru import logger +import deepspeed + +from priorzero_config import ( + get_priorzero_config, + get_priorzero_debug_config, + get_available_models, +) +from priorzero_collector import PriorZeroCollector +from priorzero_evaluator import PriorZeroEvaluator +from priorzero_policy import * +from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized +from utils import dump_dataclass_cfg_py + +from lzero.entry.utils import calculate_update_per_collect + +def prepare_unizero(rank, cfg, create_cfg, llm_cfg, seed): + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + + policy = create_policy( cfg.policy, enable_field=['learn', 'collect', 'eval'], exp_name=cfg.exp_name) + logger.info(f"[Rank {rank}] Policy created") + + os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + logger.info(f"[Rank {rank}] TensorBoard logger: ./{cfg.exp_name}/log/") + + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + logger.info(f"[Rank {rank}] BaseLearner created") + + + replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) + logger.info(f"[Rank {rank}] PriorZero replay buffer created (with game_segments support)") + + # Create collector + collector = PriorZeroCollector( + env=collector_env, + policy=policy.collect_mode, + llm_config=llm_cfg, + tb_logger=tb_logger, + 
exp_name=cfg.exp_name,
+        policy_config=cfg.policy,
+    )
+    logger.info(f"[Rank {rank}] Collector created")
+
+    # Create evaluator
+    evaluator = PriorZeroEvaluator(
+        n_evaluator_episode=cfg.env.n_evaluator_episode,
+        stop_value=cfg.env.stop_value,
+        env=evaluator_env,
+        policy=policy.eval_mode,
+        tb_logger=tb_logger,
+        exp_name=cfg.exp_name,
+        policy_config=cfg.policy,
+        llm_config=llm_cfg,
+    )
+    logger.info(f"[Rank {rank}] Evaluator created")
+    learner.call_hook('before_run')
+
+    return cfg, replay_buffer, tb_logger, policy, collector, evaluator, learner
+
+def all_gather_cmd(world_size, obj) -> list:
+    if world_size <= 1:
+        return [obj]
+    lst = [None] * dist.get_world_size()
+    dist.all_gather_object(lst, obj)
+    return lst
+
+def train_priorzero(
+    cfg: dict,
+    create_cfg: dict,
+    llm_cfg,
+    seed: int = 0,
+    max_train_iter: int = int(1e6),
+    max_env_step: Optional[int] = int(1e10),
+    enable_profile: bool = False
+):
+    rank = int(os.environ.get("RANK", "0"))
+    print(f"DEBUG: Is dist initialized at start? 
{dist.is_initialized()}") + if dist.is_initialized(): + print(f"DEBUG: Backend is {dist.get_backend()}") + from strategy.deepspeed import get_strategy, torch_dist_barrier_and_cuda_sync + strategy = get_strategy(llm_cfg) + strategy.print(llm_cfg) + + strategy.setup_distributed() # torchrun 下:绑定 local_rank + init_distributed + world_size = getattr(strategy, "world_size", 1) + + + cfg, replay_buffer, tb_logger, policy, collector, evaluator, learner = prepare_unizero( + rank=rank, + cfg=cfg, + create_cfg=create_cfg, + llm_cfg=llm_cfg, + seed=seed) + batch_size = cfg.policy.batch_size + logger.info(f"[Rank {rank}] World Model components initialized") + if rank == 0: + dump_dataclass_cfg_py(llm_cfg, path=f"{cfg.exp_name}/llm_cfg.py") + llm_cfg.save_path = f'./{cfg.exp_name}/llm_ckpt/' + + from utils import Profiler + prof = Profiler(log_interval=10, stats_file=f'./{cfg.exp_name}/log/profiler.txt', enable_profile=enable_profile) + + + logger.info(f"[Rank {rank}] Initializing LLM Actor...") + set_pkg_seed(seed + rank, use_cuda=True) + + from models.actor import PolicyModel, ReferenceModel + if llm_cfg.rft_kl_coef > 0: + ref_model = ReferenceModel( + strategy=strategy, + pretrain=llm_cfg.model_name_or_path + ) + else: + ref_model = None + + from vllm_utils.vllm_engine import create_vllm_engine + vllm_engine = create_vllm_engine( + tensor_parallel_size=llm_cfg.vllm_tensor_parallel_size, + pretrain=llm_cfg.model_name_or_path, + enable_prefix_caching=llm_cfg.enable_prefix_caching, + max_model_len=llm_cfg.prompt_max_len + llm_cfg.generate_max_len, + gpu_memory_utilization=llm_cfg.gpu_memory_utilization, + vllm_enable_sleep=llm_cfg.vllm_enable_sleep, + ) + + print(f'[Rank {rank}] Vllm engine successfully created!') + + from priorzero_datafactory import DataProcessor + data_processor = DataProcessor(rank=rank, + world_size=world_size, + vllm_engine=vllm_engine, + strategy=strategy, + model_path=llm_cfg.model_name_or_path, + exp_name=cfg.exp_name if rank == 0 else None, + ) + # 
在collector中初始化data_processor 和prof对象 + collector.data_processor = data_processor + collector.prof = prof + evaluator.data_processor = data_processor + + policy_model = PolicyModel( + strategy=strategy, + pretrain=llm_cfg.model_name_or_path, + vllm_engine=vllm_engine, + max_steps=llm_cfg.max_steps + ) + from priorzero_trainer import PriorZeroLLMTrainer + trainer = PriorZeroLLMTrainer( + cfg=llm_cfg, + pretrain=llm_cfg.model_name_or_path, + strategy= strategy, + vllm_engine = vllm_engine, + policy_model=policy_model, + reference_model=ref_model, + exp_name=cfg.exp_name if rank == 0 else None, + tb_logger=tb_logger if rank == 0 else None, + llm_save_freq=llm_cfg.llm_save_freq + ) + + torch_dist_barrier_and_cuda_sync() + + while True: + cmd = 0 # 0 表示当前循环contiune, 1 表示继续,2 表示break + priorzero_batch = None + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + logger.info(f"\n[Rank {rank}: Iter {learner.train_iter}] Evaluating...") + + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.wake_up() + evaluator.eval(train_iter=learner.train_iter, envstep=collector.envstep) + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.sleep() + + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.wake_up() + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'temperature': 0.25, 'epsilon': 0.0}) + data_processor.get_llm_output_log(wm_train_iter=learner.train_iter, llm_train_iter=policy_model.train_iter) + + if llm_cfg.vllm_enable_sleep and vllm_engine is not None: + vllm_engine.sleep() + + update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=world_size) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + num_of_transitions = replay_buffer.get_num_of_transitions() + new_num_of_transitions = replay_buffer.get_num_of_transitions() - replay_buffer.last_pos_in_transition + logger.info( + f"[Data Collection] Rank 
{rank} | " + f"Total transitions: {num_of_transitions} | " + f"New transitions: {new_num_of_transitions}" + ) + if not (num_of_transitions > batch_size): + logger.warning( + f' ⚠ Data in replay_buffer is not sufficient: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' + ) + cmd = 0 + else: + cmd = 1 + + if min(all_gather_cmd(world_size=world_size, obj=cmd)) == 0: + continue + + logger.info( + f"[World Model Training] Rank {rank} | Iter {learner.train_iter} | " + f"Updates: {update_per_collect}" + ) + + for i in range(update_per_collect): + with prof.block("train_world_model", rank=rank): + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(learner.train_iter) + + log_vars = learner.train(train_data, collector.envstep) + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + policy.recompute_pos_emb_diff_and_clear_cache() + + # 计算需要收集多少样本才能满足 llm 的训练 + # 一次参数更新是train_batch_size,off次数为broadcast_every,每个rank单独收集数据,所以需要除 + # 此外, 需要的 transitions是样本数 / unroll_steps,即轨迹数 + llm_need_sample_cnt = llm_cfg.train_batch_size * llm_cfg.broadcast_every // world_size + llm_need_transition_cnt = (llm_need_sample_cnt + cfg.policy.num_unroll_steps - 1) // cfg.policy.num_unroll_steps + + if learner.train_iter >= llm_cfg.train_llm_after_wm_warm_step and new_num_of_transitions >= llm_need_transition_cnt: + cmd = 1 + else: + cmd = 0 + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + cmd = 2 + + all_cmd = all_gather_cmd(world_size=world_size, obj=cmd) + if max(all_cmd) == 2: + break + elif min(all_cmd) == 1: + with prof.block("fetch_latest_batch", rank=rank): + print(f"[Batch Fetch] Rank {rank}] | WM Iter: {learner.train_iter} | Required transitions: {llm_need_transition_cnt}") + priorzero_batch = replay_buffer.fetch_latest_batch(batch_size=llm_need_transition_cnt, policy=policy) + print(f"[Batch Fetch] Rank {rank}] completed.") + + with 
prof.block("train_llm", rank=rank): + sample_count = len(priorzero_batch[0]) if priorzero_batch and len(priorzero_batch) > 0 else 0 + logger.info(f"[LLM Training] Rank {rank} | Samples: {sample_count}") + + train_samples = data_processor.make_llm_train_samples(priorzero_batch, ddp=True) + trainer.train_batch(train_samples, collect_env_steps=collector.envstep) + + torch_dist_barrier_and_cuda_sync() + else: + continue + +def main(): + """ + Main entry point with argument parsing. + """ + import argparse + + parser = argparse.ArgumentParser( + description='PriorZero Training with Auto Model Configuration', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use default model (qwen2.5-1.5b) + torchrun --nproc_per_node 2 priorzero_entry_sync.py + + # Use specific model + torchrun --nproc_per_node 2 priorzero_entry_sync.py --model qwen2.5-0.5b + torchrun --nproc_per_node 2 priorzero_entry_sync.py --model qwen2.5-7b + + # List all available models + python priorzero_entry_sync.py --list-models + + # Different environment + torchrun --nproc_per_node 2 priorzero_entry_sync.py --env_id zork1.z5 --model qwen2.5-1.5b + """ + ) + parser.add_argument('--env_id', type=str, default='detective.z5', help='Jericho game ID') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') + parser.add_argument('--quick_test', action='store_true', default=False, help='Use quick test config') + # Model selection + parser.add_argument('--model', type=str, default="qwen2.5-3b", choices=get_available_models()) + parser.add_argument('--enable_profile', action='store_true', default=False) + parser.add_argument('--use_cot', action='store_true', default=True) + args = parser.parse_args() + + model_key = args.model if args.model else "qwen2.5-1.5b" + print(f"\n{'='*80}") + print(f"PriorZero Training Configuration") + print(f"{'='*80}") + print(f"Environment: 
{args.env_id}") + print(f"Model: {model_key}") + print(f"Seed: {args.seed}") + print(f"Quick Test: {args.quick_test}") + print(f"use cot: {args.use_cot}") + print(f"enable_profile: {args.enable_profile}") + print(f"{'='*80}\n") + + # use_cot = True + if args.quick_test: + logger.info("Using quick test configuration") + main_cfg, create_cfg, llm_cfg = get_priorzero_debug_config( + args.env_id, args.seed, use_cot=args.use_cot, + exp_name=f'data_priorzero/priorzero_debug_{args.env_id}', + model_key=model_key, + ) + else: + main_cfg, create_cfg, llm_cfg = get_priorzero_config( + args.env_id, args.seed, use_cot=args.use_cot, + model_key=model_key, + multi_gpu=True + ) + + train_priorzero( + main_cfg, + create_cfg, + llm_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + enable_profile=args.enable_profile, # 是否要对各个耗时部分进行 profile + ) + + +if __name__ == "__main__": + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main() diff --git a/zoo/jericho/priorzero/priorzero_evaluator.py b/zoo/jericho/priorzero/priorzero_evaluator.py index c8a25d0f9..9def4da1e 100644 --- a/zoo/jericho/priorzero/priorzero_evaluator.py +++ b/zoo/jericho/priorzero/priorzero_evaluator.py @@ -1,55 +1,407 @@ -# priorzero_evaluator.py -""" -[PRIORZERO] PriorZero Evaluator +import copy +import time +from collections import namedtuple +from typing import Optional, Callable, Tuple, Dict, Any -Simple evaluator that inherits from MuZeroEvaluator. -Since the policy already integrates LLM priors in its _forward_collect method, -the evaluator can use the parent implementation directly. 
+from collections import deque, defaultdict +import numpy as np +import torch +import wandb +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray, to_item, to_tensor +from ding.utils import build_logger, EasyTimer +from ding.utils import get_world_size, get_rank, broadcast_object_list +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from easydict import EasyDict -Author: PriorZero Team -Date: 2025-01-20 -""" - -from typing import Optional - -from ding.worker.collector.base_serial_evaluator import SERIAL_EVALUATOR_REGISTRY +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation +import threading from lzero.worker.muzero_evaluator import MuZeroEvaluator as OriginalEvaluator -from vllm import AsyncLLMEngine - -@SERIAL_EVALUATOR_REGISTRY.register('priorzero', force_overwrite=True) class PriorZeroEvaluator(OriginalEvaluator): """ - [PRIORZERO-MODIFIED] - Evaluator for PriorZero. - - Since the PriorZero policy already integrates LLM priors in its - _forward_collect method, this evaluator simply inherits all - functionality from MuZeroEvaluator. - - The vLLM engine is passed for potential future enhancements - (e.g., comparative evaluation with/without LLM priors). 
+ PriorZero evaluator with three selectable eval modes: + 1) world_model: default UniZero eval + 2) world_model_llm_prior: inject llm_prior to MCTS root policy logits + 3) llm_prior_only: ignore world model and greedily pick best llm_prior action """ - def __init__( - self, - vllm_engine: Optional[AsyncLLMEngine] = None, - **kwargs - ): + def __init__(self, llm_config: Dict, data_processor = None, **kwargs) -> None: + super().__init__(**kwargs) + self.llm_cfg = llm_config + self.data_processor = data_processor + + + self.eval_mode = llm_config.eval_dict + self.eval_freq = self.eval_mode.eval_freq + self.llm_prior_temperature = llm_config.llm_prior_temperature + self.history_buffers = defaultdict( + lambda: deque(maxlen=self.llm_cfg.history_length) + ) + self._logger.info("✓ PriorZeroEvaluator initialized with vLLM engine") + self._logger.info(f" - History length: {self.llm_cfg.history_length}") + + def should_eval(self, train_iter: int) -> bool: + """ + Overview: + Determine whether it's time to run an evaluation based on the training iteration. + Arguments: + - train_iter (:obj:`int`): The current training iteration. + Returns: + - (:obj:`bool`): True if evaluation should be run, otherwise False. """ - Initialize PriorZeroEvaluator. 
+ if train_iter == self._last_eval_iter: + return False + if (train_iter - self._last_eval_iter) < self.eval_freq and train_iter != 0: + return False + self._last_eval_iter = train_iter + return True + + def eval(self, train_iter: int = -1, envstep: int = -1) -> Tuple[bool, Dict[str, Any]]: + modes = [] + if self.eval_mode.world_model: + world_model_info = super().eval() + modes.append(("WM", world_model_info)) + if self.eval_mode.world_model_llm_prior: + world_model_llm_prior_info = self.eval_with_llm_prior() + modes.append(("WM_LLMPrior", world_model_llm_prior_info)) + if self.eval_mode.llm_prior: + llm_prior_info = self.eval_only_llm_prior() + modes.append(("LLMPrior", llm_prior_info)) + + for tag, info in modes: + metrics_str = " | ".join([f"{k}: {info.get(k, 0):.2f}" for k in ['avg_envstep_per_episode', 'reward_mean', 'reward_max', 'reward_min']]) + self._logger.info(f"[RANK {self._rank}] {tag} >> {metrics_str}") + + keys = ['avg_envstep_per_episode', 'reward_mean', 'reward_std', 'reward_max', 'reward_min'] + for k in keys: + if self.eval_mode.world_model: + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}_WM', world_model_info[k], train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}_WM', world_model_info[k], envstep) + if self.eval_mode.world_model_llm_prior: + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}_WM_LLMPrior', world_model_llm_prior_info[k], train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}_WM_LLMPrior', world_model_llm_prior_info[k], envstep) + if self.eval_mode.llm_prior: + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}_LLMPrior', llm_prior_info[k], train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}_LLMPrior', llm_prior_info[k], envstep) + + + def eval_with_llm_prior(self) -> Dict[str, Any]: + n_episode = self._default_n_episode + assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." 
+ envstep_count = 0 + eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) + env_nums = self._env.env_num + + self._env.reset() + self._policy.reset(task_id=self.task_id) + + init_obs = self._env.ready_obs + + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + self._logger.info(f"Waiting for all environments to reset. Current ready envs: {list(init_obs.keys())}") + time.sleep(retry_waiting_time) + init_obs = self._env.ready_obs + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + + timestep_dict = {} + for i in range(env_nums): + if 'timestep' not in init_obs[i]: + print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") + timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) + + dones = np.array([False for _ in range(env_nums)]) + + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config, + task_id=self.task_id + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + eps_steps_lst = np.zeros(env_nums) + with self._timer: + while not eval_monitor.is_finished(): + # Check if a timeout has occurred. + if self.stop_event.is_set(): + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + + # Get observations from ready environments. + obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + # Prepare stacked observations and other inputs for the policy. 
+ stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + timestep = [timestep_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================ + # 添加 LLM_PRIOR + raw_obs_list = [] + histories_list = [] + valid_actions_list = [] + for env_id in sorted(list(ready_env_id)): + raw_obs_text = obs[env_id]['raw_obs_text'] + raw_obs_list.append(raw_obs_text) + + history = list(self.history_buffers[env_id]) + histories_list.append(history) + + valid_actions = obs[env_id].get('valid_actions', []) + valid_actions_list.append(valid_actions) + + llm_prior_per_seq, _, _ = self.data_processor.get_llm_prior( + states=raw_obs_list, + valid_actions_list=valid_actions_list, # [PRIORZERO] Pass valid actions + histories=histories_list, + return_cot=True # Request CoT prefixes for reuse in training + ) + for env_id, llm_prior in enumerate(llm_prior_per_seq): + scaled_llm_prior = self.apply_temperature_scaling(llm_prior, return_logprobs=True) + llm_prior_per_seq[env_id] = scaled_llm_prior + + policy_kwargs_forward = { + 'llm_prior_logprob': llm_prior_per_seq, + 'valid_actions_list': valid_actions_list, + } + # ============================================ + if self.task_id is not None: + policy_kwargs_forward['task_id'] = self.task_id + # ============================================================== + # Policy Forward Pass + # ============================================================== + policy_output = self._policy.forward(data=stack_obs, action_mask=action_mask, + to_play=to_play, ready_env_id=ready_env_id, + timestep=timestep, **policy_kwargs_forward) + # Unpack 
policy outputs. + actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + + value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} + timestep_dict_with_env_id = {k: v.get('timestep', -1) for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} + + # Remap outputs from policy's internal IDs to environment IDs. + actions, distributions_dict, value_dict, pred_value_dict, timestep_dict, visit_entropy_dict = {}, {}, {}, {}, {}, {} + + for index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_with_env_id.pop(env_id) + distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - Args: - vllm_engine: vLLM async engine (optional, for future use) - **kwargs: Arguments for parent MuZeroEvaluator + + value_dict[env_id] = value_dict_with_env_id.pop(env_id) + pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) + timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) + visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) + + # ============================================================== + # Environment Interaction + # ============================================================== + timesteps = self._env.step(actions) + timesteps = to_tensor(timesteps, dtype=torch.float32) + for env_id, episode_timestep in timesteps.items(): + obs_new, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + + action = info['action_str'] + self.history_buffers[env_id].append((obs[env_id]['raw_obs_text'], action, float(reward))) + + eps_steps_lst[env_id] += 1 + # This reset logic is specific to UniZero-like models. 
+ if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'priorzero']: + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + + game_segments[env_id].append( + actions[env_id], to_ndarray(obs_new['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id] + ) + + # IMPORTANT: The action_mask and to_play from the new observation correspond to the *next* state. + action_mask_dict[env_id] = to_ndarray(obs_new['action_mask']) + to_play_dict[env_id] = to_ndarray(obs_new['to_play']) + timestep_dict[env_id] = to_ndarray(obs_new.get('timestep', -1)) + + dones[env_id] = done + if episode_timestep.done: + self._policy.reset([env_id]) + reward = episode_timestep.info['score'] + saved_info = {'eval_episode_return': episode_timestep.info['score']} + if 'episode_info' in episode_timestep.info: + saved_info.update(episode_timestep.info['episode_info']) + eval_monitor.update_info(env_id, saved_info) + eval_monitor.update_reward(env_id, reward) + self._logger.info( + f"[EVALUATOR] env {env_id} finished episode, final reward: {eval_monitor.get_latest_reward(env_id)}, " + f"current episode count: {eval_monitor.get_current_episode()}" + ) + + # If there are more episodes to run than available environments, reset and reuse this one. + if n_episode > self._env_num: + init_obs = self._env.ready_obs + # Wait for the environment to be ready again. + while len(init_obs.keys()) != self._env_num: + self._logger.info(f"Waiting for env {env_id} to reset. Current ready envs: {list(init_obs.keys())}") + time.sleep(retry_waiting_time) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + # Re-initialize state for the new episode. 
+ action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) + + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config, + task_id=self.task_id + ) + game_segments[env_id].reset( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)] + ) + + eps_steps_lst[env_id] = 0 + # NOTE: Reset the policy state for this env_id. `reset_init_data` defaults to True. + self._policy.reset([env_id]) + ready_env_id.remove(env_id) + + envstep_count += 1 + + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + info = { + 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + } + return info + + def eval_only_llm_prior(self) -> Dict[str, Any]: + n_episode = self._default_n_episode + assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." 
+ envstep_count = 0 + env_nums = self._env.env_num + + self._env.reset() + + dones = np.array([False for _ in range(env_nums)]) + ready_env_id = [i for i in range(env_nums)] + episode_return = [] + while True: + if all(dones): + break + + obs = self._env.ready_obs + # ============================================ + # 添加 LLM_PRIOR + raw_obs_list = [] + histories_list = [] + valid_actions_list = [] + for env_id in sorted(list(ready_env_id)): + raw_obs_text = obs[env_id]['raw_obs_text'] + raw_obs_list.append(raw_obs_text) + + history = list(self.history_buffers[env_id]) + histories_list.append(history) + + valid_actions = obs[env_id].get('valid_actions', []) + valid_actions_list.append(valid_actions) + + llm_prior_per_seq, _, _ = self.data_processor.get_llm_prior( + states=raw_obs_list, + valid_actions_list=valid_actions_list, # [PRIORZERO] Pass valid actions + histories=histories_list, + return_cot=True # Request CoT prefixes for reuse in training + ) + actions = {env_id: None for env_id in sorted(list(ready_env_id))} + + for env_id, llm_prior, valid_actions in zip(sorted(list(ready_env_id)), llm_prior_per_seq, valid_actions_list): + if len(llm_prior) == 1: # 只有go,即valid_action_len=0 + assert len(valid_actions) == 0 + actions[env_id] = 0 + if 'go' in llm_prior and 'go' not in valid_actions: + llm_prior.pop('go') + action_str_select, max_logprob = "", float(-1e9) + for action_str, logprob in llm_prior.items(): + if logprob > max_logprob: + action_str_select = action_str + max_logprob = logprob + actions[env_id] = valid_actions.index(action_str_select) + + # ============================================ + + timesteps = self._env.step(actions) + timesteps = to_tensor(timesteps, dtype=torch.float32) + for env_id, episode_timestep in timesteps.items(): + obs_new, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + + action = info['action_str'] + self.history_buffers[env_id].append((obs[env_id]['raw_obs_text'], 
action, float(reward))) + + dones[env_id] = done + if episode_timestep.done: + ready_env_id.remove(env_id) + episode_return.append(info['score']) + + envstep_count += 1 + info = { + 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + } + return info + + def apply_temperature_scaling(self, logprobs_dict: dict, return_logprobs: bool = True) -> dict: """ - super().__init__(**kwargs) - self.vllm_engine = vllm_engine + 对 Logprobs 字典进行温度缩放,控制分布的平缓程度。 + """ + import math + T = self.llm_prior_temperature + if T <= 1e-8: + max_key = max(logprobs_dict, key=logprobs_dict.get) + return {k: (0.0 if k != max_key else 1.0) for k in logprobs_dict} + + scaled_logits = {k: v / T for k, v in logprobs_dict.items()} + + max_val = max(scaled_logits.values()) + sum_exp = sum(math.exp(v - max_val) for v in scaled_logits.values()) + log_sum_exp = math.log(sum_exp) + max_val - if vllm_engine is not None: - self._logger.info("✓ PriorZeroEvaluator initialized with vLLM engine") - else: - self._logger.info("✓ PriorZeroEvaluator initialized (no vLLM engine)") + result = {} + for k, v in scaled_logits.items(): + normalized_logprob = v - log_sum_exp + + if return_logprobs: + result[k] = normalized_logprob + else: + result[k] = math.exp(normalized_logprob) - # All other methods are inherited from MuZeroEvaluator - # The policy's _forward_collect already handles LLM prior integration + return result \ No newline at end of file diff --git a/zoo/jericho/priorzero/priorzero_orz_complete.py b/zoo/jericho/priorzero/priorzero_orz_complete.py deleted file mode 100644 index f0daf5958..000000000 --- a/zoo/jericho/priorzero/priorzero_orz_complete.py +++ /dev/null @@ -1,965 +0,0 @@ -""" -PriorZero-ORZ Complete Integration -完整可执行版本 with ORZ RayPPOTrainer - -This version includes: -1. Fixed vLLM None handling -2. 
Fixed asyncio scope issue -3. Complete ORZ RayPPOTrainer integration -4. Robust error handling - -Usage: - DEBUG_MODE=True python -m zoo.jericho.priorzero.priorzero_orz_complete - -Author: PriorZero Team -Date: 2025-10-21 -""" - -import asyncio -import os -import sys -import re -from pathlib import Path -from functools import partial -from typing import Optional, List, Dict, Any, Callable, Awaitable, Tuple -import time -import json - -# ============================================================================== -# Ensure local LightZero is used -# ============================================================================== -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - -import torch -import numpy as np -from ding.config import compile_config -from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.worker import BaseLearner -from tensorboardX import SummaryWriter -from loguru import logger - -# PriorZero imports -from priorzero_config import get_priorzero_config_for_quick_test, get_priorzero_config -from priorzero_collector import PriorZeroCollector -from priorzero_evaluator import PriorZeroEvaluator -import priorzero_policy # noqa: F401 -from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized - -# vLLM imports (optional) -try: - from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False - logger.warning("vLLM not available - LLM inference will be disabled") - -# Try to import ORZ -ORZ_AVAILABLE = False -ORZ_PATH = Path("/mnt/nfs/zhangjinouwen/puyuan/Open-Reasoner-Zero") - -try: - if ORZ_PATH.exists() and str(ORZ_PATH) not in sys.path: - sys.path.insert(0, str(ORZ_PATH)) - - from orz.ppo import RayPPOTrainer, PromptDataset - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp, BasePPOExpConfig - 
from orz.ppo.utils import get_strategy - from transformers import AutoTokenizer - import ray - ORZ_AVAILABLE = True - logger.info("✅ ORZ available - will use ORZ RayPPOTrainer for LLM training") -except ImportError as e: - logger.warning(f"⚠️ ORZ not available ({e}) - will use PriorZero's built-in LLM training") - - -# ============================================================================== -# Configuration -# ============================================================================== - -DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") == "True" - - -class HybridTrainingConfig: - """ - Hybrid training configuration combining PriorZero and ORZ settings. - """ - def __init__(self): - # Get base PriorZero config - if DEBUG_MODE: - self.priorzero_cfg, self.priorzero_create_cfg = get_priorzero_config_for_quick_test( - env_id='zork1.z5', - seed=0, - debug_mode=True - ) - else: - self.priorzero_cfg, self.priorzero_create_cfg = get_priorzero_config( - env_id='zork1.z5', - seed=0, - enable_llm=True, - enable_rft=True, - debug_mode=False - ) - - # Hybrid-specific settings - self.wm_training_mode = "parallel" - self.wm_train_freq = 1 - self.llm_train_freq = 5 - self.use_orz_trainer = ORZ_AVAILABLE - - # vLLM settings - self.use_vllm = VLLM_AVAILABLE - self.vllm_required = False # Set to True if vLLM is required - - # ORZ-specific settings (only used if ORZ_AVAILABLE) - if ORZ_AVAILABLE: - self.orz_rollout_batch_size = 32 if DEBUG_MODE else 128 - self.orz_train_batch_size = 8 if DEBUG_MODE else 32 - self.orz_actor_lr = 1e-6 - self.orz_critic_lr = 5e-6 - self.orz_num_episodes = 2 if DEBUG_MODE else 10 - - -# ============================================================================== -# ORZ Data Adapter and Dataset -# ============================================================================== - -class GameSegmentToORZAdapter: - """ - Convert PriorZero game_segments to ORZ-compatible format. 
- """ - - @staticmethod - def convert_segments_to_prompts(game_segments: List[Any], tokenizer) -> List[Dict]: - """ - Convert game_segments to ORZ prompt format. - - Args: - game_segments: List of GameSegment from PriorZero - tokenizer: HuggingFace tokenizer - - Returns: - List of ORZ-compatible prompt dictionaries - """ - prompts = [] - - for segment in game_segments: - # Extract raw observations if available - if hasattr(segment, 'raw_obs_segment') and segment.raw_obs_segment: - for i, (obs, action) in enumerate(zip( - segment.raw_obs_segment, - segment.action_segment - )): - # Create ORZ format prompt - prompt_dict = { - "prompt": [{"value": obs}], - "final_answer": action, - "file_name": f"segment_{id(segment)}_step_{i}" - } - prompts.append(prompt_dict) - - return prompts - - @staticmethod - def extract_training_data(game_segments: List[Any]) -> Dict[str, List]: - """ - Extract training data from game_segments for ORZ. - - Returns: - Dictionary containing: - - states: List of state descriptions - - actions: List of actions taken - - rewards: List of rewards received - - mcts_policies: List of MCTS visit distributions - """ - training_data = { - 'states': [], - 'actions': [], - 'rewards': [], - 'mcts_policies': [] - } - - for segment in game_segments: - # Extract raw observations (states) - if hasattr(segment, 'raw_obs_segment'): - training_data['states'].extend(segment.raw_obs_segment) - - # Extract actions - if hasattr(segment, 'action_segment'): - training_data['actions'].extend(segment.action_segment) - - # Extract rewards - if hasattr(segment, 'reward_segment'): - training_data['rewards'].extend(segment.reward_segment) - - # Extract MCTS policies - if hasattr(segment, 'mcts_policy_segment'): - training_data['mcts_policies'].extend(segment.mcts_policy_segment) - - return training_data - - -# Only define dataset classes if ORZ is available -if ORZ_AVAILABLE: - from jinja2 import Template - - class JerichoPromptDataset(PromptDataset): - """ - Custom dataset 
for Jericho text adventure games in ORZ format. - Adapts PriorZero game_segments to ORZ PPO training format. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def process_dialogue(self, dialogue: dict): - """ - Process a single dialogue (observation + action pair) into ORZ format. - - Args: - dialogue: Dict with 'prompt', 'final_answer', 'file_name' - - Returns: - prompt: Formatted prompt string - extra: Dict with answer and metadata - """ - # Template for Jericho text adventure prompts - prompt_template_jinja = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game \ -and needs to decide the next action. The Assistant carefully analyzes the current game state, \ -considers the available actions, and recommends the best action to take. \ -The reasoning process is enclosed within tags, and the recommended action \ -is enclosed within tags. For example: \ - The player is in a dark room and needs light. The lamp is available. \ - take lamp . User: {{prompt}} -Assistant: \ -""" - - prompt_instruction_template_jinja = """\ -Current game state: -{{prompt}} - -What is the best action to take? Put your answer inside tags. 
-""" - - # Validate dialogue format - assert isinstance(dialogue, dict), "dialogue must be a dict" - assert "prompt" in dialogue, "dialogue must contain prompt" - assert "final_answer" in dialogue, "dialogue must contain final_answer" - - # Build prompt - prompt_instruction_template = Template(prompt_instruction_template_jinja) - prompt_instruction = prompt_instruction_template.render( - prompt=dialogue["prompt"][0]["value"] - ) - - prompt_template = Template(prompt_template_jinja) - if self.tokenizer.bos_token_id is None: - bos_token = "" - else: - bos_token = self.tokenizer.decode([self.tokenizer.bos_token_id]) - - prompt = prompt_template.render( - bos_token=bos_token, - prompt=prompt_instruction - ) - - extra = { - "answer": dialogue["final_answer"], - "file_name": dialogue.get("file_name", "unknown") - } - - return prompt, extra - - -# ============================================================================== -# Main Training Function -# ============================================================================== - -async def train_priorzero_orz_complete( - cfg: dict, - create_cfg: dict, - hybrid_cfg: HybridTrainingConfig, - seed: int = 0, - max_train_iter: int = 10000, - max_env_step: Optional[int] = int(1e10), - enable_save: bool = True, -): - """ - Main hybrid training function with complete ORZ integration. - """ - # ================================================================== - # 1. Compile Configuration - # ================================================================== - cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) - - # ================================================================== - # 2. 
Create vLLM Engine (optional) - Based on priorzero_entry.py - # ================================================================== - vllm_engine = None - - if hybrid_cfg.use_vllm and VLLM_AVAILABLE: - logger.info("Creating vLLM engine...") - - # [ROBUST FIX] Handle shared GPU environment - # Solution: Use alternative initialization method with fallback - tensor_parallel = cfg.policy.llm_policy_cfg.vllm_tensor_parallel_size - distributed_backend = "ray" if tensor_parallel > 1 else None - - # [ROBUST FIX] Lower GPU memory utilization in shared environment - gpu_mem_util = cfg.policy.llm_policy_cfg.gpu_memory_utilization - if gpu_mem_util > 0.85: - gpu_mem_util = 0.75 # More conservative - logger.info(f"✓ Adjusted GPU memory utilization to {gpu_mem_util} for stability") - - # [ROBUST FIX] Use vLLM V0 engine for stability (as in priorzero_entry.py) - use_v1_env = os.environ.get('VLLM_USE_V1', None) - if use_v1_env is None: - # Only set if not already set by user - os.environ['VLLM_USE_V1'] = '0' - logger.info("✓ Using vLLM V0 engine for stability") - - # Fix tokenizers parallelism warning - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - - try: - from vllm.engine.arg_utils import AsyncEngineArgs - - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util, - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=False, - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created (backend: {distributed_backend or 'default'})") - - except (ValueError, RuntimeError) as e: - if "VLLM_USE_V1" in str(e) or "memory profiling" in str(e): - # Fallback: Try without V1 env var or with eager mode - logger.warning(f"⚠️ Initial vLLM initialization failed: {e}") - logger.info("Retrying with alternative configuration...") - - if 'VLLM_USE_V1' in os.environ: - del 
os.environ['VLLM_USE_V1'] - - try: - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util * 0.9, # Even more conservative - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=True, # Force eager mode as fallback - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created with fallback configuration") - except Exception as e2: - logger.error(f"❌ Failed to create vLLM engine with fallback: {e2}") - if hybrid_cfg.vllm_required: - raise - logger.warning("Continuing without vLLM (LLM prior will be disabled)") - else: - logger.error(f"❌ Failed to create vLLM engine: {e}") - import traceback - logger.error(f"Full traceback:\n{traceback.format_exc()}") - if hybrid_cfg.vllm_required: - raise - logger.warning("Continuing without vLLM (LLM prior will be disabled)") - else: - logger.info("vLLM disabled or not available - continuing without LLM inference") - - # ================================================================== - # 3. Create Environments - # ================================================================== - logger.info("Creating environments...") - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - - collector_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in collector_env_cfg] - ) - evaluator_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in evaluator_env_cfg] - ) - - # Seed environments - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=True) - logger.info(f"✓ Environments created and seeded (seed={seed})") - - # ================================================================== - # 4. 
Create Policy, Buffer, and Components - # ================================================================== - logger.info("Creating policy, buffer, and components...") - - # Create policy - policy = create_policy( - cfg.policy, - enable_field=['learn', 'collect', 'eval'] - ) - logger.info("✓ Policy created") - - # Create TensorBoard logger - os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) - tb_logger = SummaryWriter( - os.path.join(f'./{cfg.exp_name}/log/', 'serial') - ) if get_rank() == 0 else None - logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") - - # Create learner (for world model training) - learner = BaseLearner( - cfg.policy.learn.learner, - policy.learn_mode, - tb_logger, - exp_name=cfg.exp_name - ) - logger.info("✓ BaseLearner created") - - # Create replay buffer - replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) - logger.info("✓ PriorZero replay buffer created") - - # Create collector - collector = PriorZeroCollector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, # May be None - policy_config=cfg.policy, - debug_mode=cfg.get('debug_mode', False), - ) - logger.info("✓ Collector created") - - # Create evaluator - evaluator = PriorZeroEvaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, # May be None - ) - logger.info("✓ Evaluator created") - - # Call learner's before_run hook - learner.call_hook('before_run') - - # ================================================================== - # 5. 
Initialize ORZ Trainer (if available) - # ================================================================== - orz_trainer = None - orz_adapter = GameSegmentToORZAdapter() - orz_tokenizer = None - orz_strategy = None - - if hybrid_cfg.use_orz_trainer and ORZ_AVAILABLE: - logger.info("="*80) - logger.info("Initializing ORZ RayPPOTrainer for LLM training...") - logger.info("="*80) - - try: - # Initialize Ray if not already running - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - logger.info("✓ Ray initialized") - - # Create ORZ tokenizer - orz_tokenizer = AutoTokenizer.from_pretrained( - cfg.policy.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True - ) - if orz_tokenizer.pad_token is None: - orz_tokenizer.pad_token = orz_tokenizer.eos_token - logger.info("✓ ORZ tokenizer created") - - # Create ORZ strategy (DeepSpeed config) - from orz.ppo.utils import get_strategy - orz_strategy = get_strategy({ - 'zero_stage': 2, - 'bf16': True, - 'gradient_checkpointing': True, - }) - logger.info("✓ ORZ strategy created") - - # Create ORZ configuration (matching ORZ's PPOExpConfig pattern) - from dataclasses import dataclass, field - from omegaconf.listconfig import ListConfig - - @dataclass - class ORZConfig: - """Simplified ORZ config for PriorZero integration""" - # Resource settings (simplified for single-node) - total_num_nodes: int = 1 - ref_num_nodes: int = 1 - ref_num_gpus_per_node: int = 1 - actor_num_nodes: int = 1 - actor_num_gpus_per_node: int = 1 - critic_num_nodes: int = 1 - critic_num_gpus_per_node: int = 1 - colocate_all: bool = True - colocate_critic_reward: bool = True - colocate_actor_ref: bool = True - vllm_num_engines: int = 1 - vllm_tensor_parallel_size: int = 1 - zero_stage: int = 2 - adam_offload: bool = False - - # Model paths - pretrain: str = cfg.policy.llm_policy_cfg.pretrain_llm_path - reward_pretrain: Optional[str] = None - critic_pretrain: Optional[str] = cfg.policy.llm_policy_cfg.pretrain_llm_path - - # Save/log paths - 
save_interval: int = 50 - ckpt_path: str = f'./{cfg.exp_name}/orz_ckpt' - save_path: str = f'./{cfg.exp_name}/orz_save' - tensorboard_log_dir: str = f'./{cfg.exp_name}/orz_log' - - # Training settings - actor_learning_rate: float = hybrid_cfg.orz_actor_lr if hasattr(hybrid_cfg, 'orz_actor_lr') else 1e-6 - critic_learning_rate: float = hybrid_cfg.orz_critic_lr if hasattr(hybrid_cfg, 'orz_critic_lr') else 5e-6 - num_warmup_steps: int = 50 - prompt_max_len: int = 2048 - enable_prefix_caching: bool = False - update_ref_every_epoch: bool = True - advantage_normalize: bool = True - - # Episode settings - num_episodes: int = hybrid_cfg.orz_num_episodes if hasattr(hybrid_cfg, 'orz_num_episodes') else 2 - rollout_batch_size: int = hybrid_cfg.orz_rollout_batch_size if hasattr(hybrid_cfg, 'orz_rollout_batch_size') else 32 - n_samples_per_prompt: int = 8 if DEBUG_MODE else 32 - micro_rollout_batch_size: int = 2 - policy_update_steps: int = 1 - critic_update_steps: int = 1 if DEBUG_MODE else 12 - micro_train_batch_size: int = 1 - micro_forward_batch_size: int = 1 - freezing_actor_steps: int = -1 - - # KL settings - init_kl_coef: float = 0 - kl_loss_coef: float = 0.0 - use_kl_loss: bool = False - use_kl_estimator_k3: bool = True - - # Eval settings - enable_eval: bool = False # Disable ORZ eval (use PriorZero's) - eval_interval: int = 100 - - # Generation settings - packing_max_len: int = 8192 - generate_max_len: int = cfg.policy.llm_policy_cfg.generate_max_len - max_len: int = 4096 - temperature: float = 1.0 - top_p: float = 1.0 - top_k: int = -1 - stop: ListConfig = field(default_factory=lambda: ListConfig([""])) - - # GRPO settings - use_grpo: bool = False - gamma: float = 1.0 - lambd: float = 1.0 - - # vLLM settings - gpu_memory_utilization: float = 0.3 - - # Custom settings for compute_reward_fn - use_compute_reward_fn: bool = True - use_orm_score: bool = False - - orz_cfg = ORZConfig() - - # Create directories for ORZ - os.makedirs(orz_cfg.ckpt_path, exist_ok=True) - 
os.makedirs(orz_cfg.save_path, exist_ok=True) - os.makedirs(orz_cfg.tensorboard_log_dir, exist_ok=True) - - logger.info("✓ ORZ config created") - logger.info(f" - Model: {orz_cfg.pretrain}") - logger.info(f" - Rollout batch: {orz_cfg.rollout_batch_size}") - logger.info(f" - Episodes: {orz_cfg.num_episodes}") - - # Note: Full RayPPOTrainer initialization requires: - # 1. Creating vLLM engines for distributed inference - # 2. Creating initial dataset from game_segments - # 3. Initializing Ray actors (will be done lazily on first training call) - # - # We defer full initialization until we have actual game_segments to train on - logger.info("✓ ORZ trainer components ready") - logger.info(" (Full RayPPOTrainer will be initialized on first training iteration)") - - except Exception as e: - logger.error(f"❌ ORZ trainer initialization failed: {e}") - import traceback - logger.error(traceback.format_exc()) - logger.warning("Falling back to PriorZero's built-in LLM training") - hybrid_cfg.use_orz_trainer = False - - # ================================================================== - # 6. 
Main Training Loop - # ================================================================== - logger.info("="*80) - logger.info("Starting PriorZero-ORZ Complete Training") - logger.info("="*80) - logger.info(f"Experiment: {cfg.exp_name}") - logger.info(f"Max iterations: {max_train_iter}") - logger.info(f"Training mode: {hybrid_cfg.wm_training_mode}") - logger.info(f"Use ORZ trainer: {hybrid_cfg.use_orz_trainer}") - logger.info(f"Use vLLM: {vllm_engine is not None}") - logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") - logger.info(f"World model: UniZero") - logger.info("="*80) - - # Training state - best_eval_reward = -float('inf') - total_game_segments_collected = 0 - - try: - while learner.train_iter < max_train_iter and collector.envstep < max_env_step: - current_iter = learner.train_iter - - # ============================================================== - # Step 1: Evaluation (if needed) - # ============================================================== - if current_iter > 0 and evaluator.should_eval(current_iter): - logger.info(f"\n{'='*60}") - logger.info(f"[Iter {current_iter}] Evaluating...") - logger.info(f"{'='*60}") - - eval_result = await evaluator.eval( - save_ckpt_fn=learner.save_checkpoint if enable_save else None, - train_iter=current_iter, - envstep=collector.envstep - ) - - if eval_result is not None: - stop, eval_reward_dict = eval_result - mean_reward = eval_reward_dict.get('reward_mean', 0) - logger.info(f"✓ Evaluation: reward_mean={mean_reward:.2f}") - - if mean_reward > best_eval_reward: - best_eval_reward = mean_reward - logger.info(f"🎯 New best reward: {best_eval_reward:.2f}") - - if stop: - logger.info(f"🎉 Training converged! 
(reward >= {cfg.env.stop_value})") - break - - # ============================================================== - # Step 2: Collect Data using MCTS - # ============================================================== - logger.info(f"\n[Iter {current_iter}] Collecting data...") - - collect_kwargs = { - 'temperature': 0.25, - 'epsilon': 0.0 - } - - try: - new_data = await collector.collect( - train_iter=current_iter, - policy_kwargs=collect_kwargs - ) - except Exception as e: - logger.error(f"❌ Collection failed: {e}") - logger.warning("Skipping this iteration...") - continue - - # Add to replay buffer - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=1) - - # Update buffer - replay_buffer.push_game_segments(new_data) - logger.info( - f"✓ Collected {len(new_data)} segments " - f"(total: {replay_buffer.get_num_of_game_segments()} segments, " - f"{replay_buffer.get_num_of_transitions()} transitions)" - ) - - total_game_segments_collected += len(new_data) - - # ============================================================== - # Step 3: World Model Training - # ============================================================== - if current_iter % hybrid_cfg.wm_train_freq == 0: - if replay_buffer.get_num_of_transitions() >= cfg.policy.batch_size: - logger.info(f"[Iter {current_iter}] Training world model...") - - # Sample and train - for _ in range(update_per_collect): - train_data = replay_buffer.sample( - cfg.policy.batch_size, - policy - ) - - # Train (includes both WM and LLM in PriorZero) - log_dict = learner.train(train_data, collector.envstep) - - # Log to TensorBoard - if tb_logger and get_rank() == 0: - for k, v in log_dict.items(): - tb_logger.add_scalar(f'train/{k}', v, collector.envstep) - - logger.info( - f"✓ WM training done - " - f"wm_loss: {log_dict.get('wm_total_loss', 0):.4f}, " - f"llm_sft_loss: {log_dict.get('llm_sft_loss', 0):.4f}" - ) - else: - logger.info(f"Skipping 
training - not enough data yet") - - # ============================================================== - # Step 4: LLM Training with ORZ (if enabled) - # ============================================================== - if (hybrid_cfg.use_orz_trainer and orz_trainer is not None and - current_iter % hybrid_cfg.llm_train_freq == 0 and - current_iter > 0): - logger.info(f"[Iter {current_iter}] Training LLM with ORZ...") - - try: - # Extract game_segments from recent collections - training_data = orz_adapter.extract_training_data(new_data) - num_samples = len(training_data['states']) - - if num_samples > 0: - logger.info(f" Extracted {num_samples} training samples for ORZ") - - # Initialize ORZ trainer on first use (lazy initialization) - if orz_trainer is None: - logger.info(" Initializing ORZ RayPPOTrainer...") - - # Convert game_segments to ORZ dataset format - dialogues = orz_adapter.convert_segments_to_prompts( - new_data, - orz_tokenizer - ) - - # Create ORZ dataset - orz_dataset = JerichoPromptDataset( - dialogues, - orz_tokenizer, - orz_cfg.prompt_max_len, - orz_strategy, - pretrain_mode=False, - num_processors=1 - ) - - # Create custom reward trainer - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp - - class JerichoRewardTrainer(RayPPOTrainer): - """Custom reward trainer for Jericho text adventures""" - - async def custom_reward_fn( - self, - prompts: List[str], - outputs: List[Any], - extras: List[dict], - reward_model_fn, - ): - """ - Compute rewards for Jericho actions. - Reward is 1.0 if action matches ground truth, else 0.0 - """ - import torch - scores = [] - responses = [] - - for output, extra in zip(outputs, extras): - response = output["response"] - responses.append(response) - - # Extract action from response - # Look for ... 
tags - import re - pattern = re.compile(r"(.*?)", re.DOTALL) - matches = re.findall(pattern, response) - predicted_action = matches[-1].strip() if matches else "" - - # Ground truth action - true_action = extra["answer"] - - # Simple exact match for now - # TODO: Could use fuzzy matching or LLM-based similarity - score = 1.0 if predicted_action.lower() == true_action.lower() else 0.0 - scores.append(score) - - # Log statistics - avg_score = sum(scores) / len(scores) if scores else 0.0 - logger.info(f" ORZ reward - avg: {avg_score:.3f}, samples: {len(scores)}") - - # Create score tensors (reward only on last token) - output_tokens = self._tokenize(responses, self.cfg.generate_max_len, padding=False)["input_ids"] - score_tensors = [] - for score, output_token in zip(scores, output_tokens): - score_tensor = torch.zeros(len(output_token)) - if len(output_token) > 0: - score_tensor[-1] = score - score_tensors.append(score_tensor) - - # Remove empty responses - res_prompts, res_responses, res_score_tensors = [], [], [] - for prompt, response, score_tensor in zip(prompts, responses, score_tensors): - if len(response) > 0: - res_prompts.append(prompt) - res_responses.append(response) - res_score_tensors.append(score_tensor) - - return res_prompts, res_responses, res_score_tensors - - # Create vLLM engines for ORZ - logger.info(" Creating vLLM inference engines for ORZ...") - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp - - # Use BasePPOExp helper to create engines - class TempExp(BasePPOExp): - def __init__(self): - self.cfg = orz_cfg - self.tokenizer = orz_tokenizer - self.strategy = orz_strategy - - temp_exp = TempExp() - vllm_engines = temp_exp.create_inference_engine() - logger.info(f" ✓ Created {len(vllm_engines)} vLLM engines") - - # Get colocate placement groups if needed - colocate_pg = temp_exp.get_colocate_pg if orz_cfg.colocate_all else None - - # Create ORZ trainer - orz_trainer = JerichoRewardTrainer( - cfg=orz_cfg, - strategy=orz_strategy, - 
tokenizer=orz_tokenizer, - train_dataset=orz_dataset, - eval_dataset=None, # No separate eval for now - vllm_engines=vllm_engines, - colocate_pg=colocate_pg - ) - - logger.info(" ✓ ORZ RayPPOTrainer initialized") - - # Run ORZ training for one episode - logger.info(f" Running ORZ PPO training (episode {current_iter // hybrid_cfg.llm_train_freq})...") - - # Train using ORZ's fit_episode method - # Note: This will do full PPO update with actor/critic training - await orz_trainer.fit_episode() - - logger.info(f" ✓ ORZ training completed for iteration {current_iter}") - - else: - logger.warning(" No training samples extracted from game_segments") - - except Exception as e: - logger.error(f" ✗ ORZ training failed: {e}") - import traceback - logger.error(traceback.format_exc()) - logger.warning(" Continuing with PriorZero LLM training only") - - # ============================================================== - # Step 5: Logging and Checkpointing - # ============================================================== - if current_iter % 10 == 0: - logger.info(f"\n{'='*60}") - logger.info(f"Progress Summary (Iter {current_iter})") - logger.info(f"{'='*60}") - logger.info(f"Env steps: {collector.envstep}") - logger.info(f"Game segments collected: {total_game_segments_collected}") - logger.info(f"Buffer size: {replay_buffer.get_num_of_transitions()} transitions") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info(f"{'='*60}\n") - - # Save checkpoint periodically - if enable_save and current_iter % 100 == 0 and current_iter > 0: - logger.info(f"[Iter {current_iter}] Saving checkpoint...") - learner.save_checkpoint(collector.envstep) - logger.info("✓ Checkpoint saved") - - except KeyboardInterrupt: - logger.info("\n⚠️ Training interrupted by user") - except Exception as e: - logger.error(f"\n❌ Training failed with error: {e}") - import traceback - traceback.print_exc() - raise - finally: - # ============================================================== - # 
Cleanup - # ============================================================== - logger.info("\nCleaning up...") - - # Save final checkpoint - if enable_save: - logger.info("Saving final checkpoint...") - try: - learner.save_checkpoint(collector.envstep) - except Exception as e: - logger.error(f"Failed to save checkpoint: {e}") - - # Close environments - try: - collector_env.close() - evaluator_env.close() - except Exception as e: - logger.error(f"Failed to close environments: {e}") - - # Close loggers - if tb_logger: - try: - tb_logger.close() - except Exception as e: - logger.error(f"Failed to close tensorboard: {e}") - - logger.info("✓ Cleanup complete") - logger.info("="*80) - logger.info("Training finished!") - logger.info(f"Total iterations: {learner.train_iter}") - logger.info(f"Total env steps: {collector.envstep}") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info("="*80) - - -# ============================================================================== -# Entry Point -# ============================================================================== - -async def main(): - """Main entry point.""" - # Create hybrid configuration - hybrid_cfg = HybridTrainingConfig() - - # Run training - await train_priorzero_orz_complete( - cfg=hybrid_cfg.priorzero_cfg, - create_cfg=hybrid_cfg.priorzero_create_cfg, - hybrid_cfg=hybrid_cfg, - seed=0, - max_train_iter=10000 if not DEBUG_MODE else 100, - enable_save=True, - ) - - -if __name__ == "__main__": - logger.info("="*80) - logger.info("PriorZero-ORZ Complete Training Pipeline") - logger.info("="*80) - logger.info(f"Debug mode: {DEBUG_MODE}") - logger.info(f"ORZ available: {ORZ_AVAILABLE}") - logger.info(f"vLLM available: {VLLM_AVAILABLE}") - logger.info("="*80) - - # Run async training - asyncio.run(main()) diff --git a/zoo/jericho/priorzero/priorzero_policy.py b/zoo/jericho/priorzero/priorzero_policy.py index 26e50e060..e0a54e8d6 100644 --- a/zoo/jericho/priorzero/priorzero_policy.py +++ 
b/zoo/jericho/priorzero/priorzero_policy.py @@ -1,944 +1,107 @@ -# priorzero_policy.py -""" -[PRIORZERO] PriorZero Policy Implementation - -This module implements the PriorZero policy that combines: -1. UniZero world model for planning in latent space -2. LLM policy model for providing high-quality action priors - -Key Features: -- Dual-model training: world model + LLM policy -- LLM-guided MCTS: inject LLM priors into MCTS root node -- SFT + RFT: supervised fine-tuning with MCTS policies + reinforcement fine-tuning with environment rewards -- Full alignment with UniZero implementation - -Author: PriorZero Team -Date: 2025-01-20 -""" - +import asyncio import copy +import inspect import re import sys import logging from pathlib import Path from typing import List, Dict, Any, Tuple, Union, Optional -# [CRITICAL] Ensure local LightZero is used -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F from ding.utils import POLICY_REGISTRY from ding.model import model_wrap -from transformers import AutoTokenizer, AutoModelForCausalLM -from peft import get_peft_model, LoraConfig, TaskType +import os # Import from local LightZero from lzero.policy.unizero import UniZeroPolicy as OriginalUniZeroPolicy -from lzero.policy import ( - phi_transform, - InverseScalarTransform, - scalar_transform, # [PRIORZERO] Added for reward/value transformation - DiscreteSupport, # [PRIORZERO] Added for categorical distribution support - to_torch_float_tensor, - mz_network_output_unpack -) +from lzero.policy import phi_transform, InverseScalarTransform, scalar_transform, DiscreteSupport +from lzero.policy import to_torch_float_tensor,mz_network_output_unpack, prepare_obs from lzero.policy.utils import select_action from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.entry.utils import initialize_zeros_batch -# Import UniZeroModel to ensure it's registered 
in MODEL_REGISTRY -import lzero.model.unizero_model # noqa: F401 - - -# ============================================================================== -# Helper Functions for LLM Prior Processing -# ============================================================================== - -def parse_llm_action_ranking( - text: str, - action_map: Dict[str, int], - action_space_size: int, - fallback_to_uniform: bool = True -) -> np.ndarray: - """ - [PRIORZERO-NEW] - Parse LLM generated action ranking text into a policy distribution. - - Args: - text: LLM generated text with ranked actions (e.g., "1. take key\\n2. go north") - action_map: Mapping from action text to action index - action_space_size: Size of the action space - fallback_to_uniform: If True, return uniform distribution when no valid action found - - Returns: - policy: Probability distribution over actions (shape: [action_space_size]) - """ - # Extract ranked actions using regex - # Supports formats: "1. action", "1) action", "1: action" - ranked_actions = re.findall(r'(?:^|\n)\s*\d+[\.\):\s]+(.+?)(?=\n|$)', text, re.MULTILINE) - - policy = np.zeros(action_space_size, dtype=np.float32) - found_count = 0 - - for rank, action_text in enumerate(ranked_actions): - action_text = action_text.strip().lower() - - # Try exact match first - if action_text in action_map: - action_idx = action_map[action_text] - # Assign decreasing weights (higher rank = higher weight) - policy[action_idx] = len(ranked_actions) - rank - found_count += 1 - else: - # Try fuzzy matching (find best substring match) - best_match_score = 0 - best_action_idx = None - for candidate_text, candidate_idx in action_map.items(): - if candidate_text in action_text or action_text in candidate_text: - score = len(set(candidate_text.split()) & set(action_text.split())) - if score > best_match_score: - best_match_score = score - best_action_idx = candidate_idx - - if best_action_idx is not None: - policy[best_action_idx] = len(ranked_actions) - rank - 
found_count += 1 - - # Normalize to probability distribution - if policy.sum() > 0: - policy /= policy.sum() - elif fallback_to_uniform: - # If LLM didn't generate any valid actions, return uniform distribution - policy = np.ones(action_space_size, dtype=np.float32) / action_space_size - - return policy - - -def format_mcts_policy_to_text( - mcts_policy: np.ndarray, - action_inv_map: Dict[int, str], - top_k: int = 5 -) -> str: - """ - [PRIORZERO-NEW] - Convert MCTS policy vector into ranked action text for SFT training. - - Args: - mcts_policy: MCTS visit count distribution (shape: [action_space_size]) - action_inv_map: Mapping from action index to action text - top_k: Number of top actions to include - - Returns: - Formatted text with ranked actions (e.g., "1. take key\\n2. go north\\n...") - """ - # Sort actions by policy probability (descending) - sorted_indices = np.argsort(mcts_policy)[::-1] - - output_lines = [] - rank = 1 - for idx in sorted_indices: - if mcts_policy[idx] > 0 and rank <= top_k: - action_text = action_inv_map.get(idx, f"action_{idx}") - output_lines.append(f"{rank}. {action_text}") - rank += 1 - - return "\n".join(output_lines) if output_lines else "No valid actions found." - - -def build_llm_prompt( - current_obs: str, - history: Optional[List[Tuple[str, str, float]]] = None, - action_descriptions: Optional[Dict[str, str]] = None, - use_cot: bool = True -) -> str: - """ - [PRIORZERO-NEW] - Build a high-quality prompt for LLM to generate action ranking. - - Args: - current_obs: Current observation text - history: List of (observation, action, reward) tuples - action_descriptions: Optional descriptions for each action - use_cot: Whether to encourage chain-of-thought reasoning - - Returns: - Formatted prompt string - """ - prompt_parts = [] - - # System instruction - prompt_parts.append( - "You are an expert player in a text-based adventure game. " - "Your goal is to maximize the score by taking the best actions." 
- ) - - # Add history if available - if history and len(history) > 0: - prompt_parts.append("\n=== Recent History ===") - for i, (obs, action, reward) in enumerate(history[-5:]): # Last 5 steps - prompt_parts.append(f"Step {i+1}:") - prompt_parts.append(f" Observation: {obs[:100]}...") # Truncate long obs - prompt_parts.append(f" Action: {action}") - prompt_parts.append(f" Reward: {reward}") - - # Current observation - prompt_parts.append("\n=== Current Situation ===") - prompt_parts.append(current_obs) - - # Task instruction - if use_cot: - prompt_parts.append( - "\n=== Task ===\n" - "Think step-by-step:\n" - "1. Analyze the current situation and your goal\n" - "2. Consider what actions might help you progress\n" - "3. Rank the best actions in order of priority\n" - "\nProvide your analysis and then list the top 5 actions in this format:\n" - "1. [first action]\n" - "2. [second action]\n" - "..." - ) - else: - prompt_parts.append( - "\n=== Task ===\n" - "List the top 5 best actions in order of priority:\n" - "1. [first action]\n" - "2. [second action]\n" - "..." - ) - - return "\n".join(prompt_parts) - - -# ============================================================================== -# PriorZero Policy Class -# ============================================================================== +import lzero.model.unizero_model @POLICY_REGISTRY.register('priorzero', force_overwrite=True) class PriorZeroPolicy(OriginalUniZeroPolicy): - """ - [PRIORZERO-MODIFIED] - PriorZero policy that combines UniZero world model with LLM policy. 
- - Architecture: - - UniZero World Model: Learns latent dynamics, value, and policy in latent space - - LLM Policy Model: Provides high-quality action priors based on language understanding - - Training: - - World Model: Trained with standard UniZero losses (value, policy, reward, latent) - - LLM: Trained with SFT (using MCTS policies) + RFT (using environment rewards) - - Inference: - - LLM generates action ranking → converted to policy prior - - Policy prior injected into MCTS root node - - MCTS search refines the policy → selects best action - """ - - config = dict( - **OriginalUniZeroPolicy.config, - # LLM-specific config - llm_policy_cfg=dict( - pretrain_llm_path="Qwen/Qwen1.5-1.8B-Chat", - use_lora=False, # Whether to use LoRA for efficient fine-tuning - lora_r=8, - lora_alpha=16, - lora_dropout=0.05, - llm_learning_rate=1e-6, - llm_weight_decay=0.01, - llm_loss_weight=0.5, # Weight of LLM loss in total loss - rft_loss_weight=0.3, # Weight of RFT loss in total loss - prompt_max_len=2048, - generate_max_len=128, - history_length=5, # Number of recent steps to include in prompt - use_cot=True, # Whether to use chain-of-thought prompting - sft_target='mcts_policy', # 'mcts_policy' or 'oracle_policy' - enable_rft=True, # Whether to enable RFT training - ), - ) - - def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None): - # [PRIORZERO-NEW] Initialize LLM-related attributes BEFORE super().__init__ - # because super().__init__ will call _init_learn which needs these attributes - self.llm_policy_model = None - self.llm_tokenizer = None - self._optimizer_llm = None - self._lr_scheduler_llm = None - self.llm_policy_cfg = cfg.llm_policy_cfg # Set from cfg, not self._cfg (not set yet) - - # Action mapping (will be set from config) - self.action_map = None # str -> int - self.action_inv_map = None # int -> str - - # Call parent init (this will trigger _init_learn, _init_collect, _init_eval) + def __init__(self, cfg: Dict, model: 
torch.nn.Module = None, enable_field: List[str] = None, **kwargs): super().__init__(cfg, model, enable_field) def _init_learn(self) -> None: - """ - [PRIORZERO-MODIFIED] - Initialize both UniZero world model and LLM policy model with their optimizers. - Align with UniZero implementation - use logging instead of self._logger. - """ - import logging - - # ====================================================================== - # 1. Initialize UniZero World Model (from parent class) - # ====================================================================== super()._init_learn() logging.info("✓ UniZero World Model and optimizer initialized") - # [PRIORZERO-FIX] Ensure scalar transform handles are initialized - # These are normally initialized in UniZeroPolicy.__init__ but we need to ensure they exist - if not hasattr(self, 'value_support') or self.value_support is None: - self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) - if not hasattr(self, 'reward_support') or self.reward_support is None: - self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) - if not hasattr(self, 'value_inverse_scalar_transform_handle'): - self.value_inverse_scalar_transform_handle = InverseScalarTransform( - self.value_support, self._cfg.model.categorical_distribution - ) - if not hasattr(self, 'reward_inverse_scalar_transform_handle'): - self.reward_inverse_scalar_transform_handle = InverseScalarTransform( - self.reward_support, self._cfg.model.categorical_distribution - ) - logging.info("✓ Scalar transform handles verified/initialized") - - # ====================================================================== - # 2. 
[PRIORZERO-NEW] Initialize LLM Policy Model - # ====================================================================== - logging.info(f"Loading LLM from: {self.llm_policy_cfg.pretrain_llm_path}") - - # Load tokenizer - self.llm_tokenizer = AutoTokenizer.from_pretrained( - self.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True, - padding_side='left' # For batch generation - ) - if self.llm_tokenizer.pad_token is None: - self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token - - # Load LLM - self.llm_policy_model = AutoModelForCausalLM.from_pretrained( - self.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True, - torch_dtype=torch.bfloat16, # Use bfloat16 to save memory - device_map=None, # We'll manually move to device - ) - - # Apply LoRA if enabled - if self.llm_policy_cfg.use_lora: - logging.info("Applying LoRA for parameter-efficient fine-tuning") - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=self.llm_policy_cfg.lora_r, - lora_alpha=self.llm_policy_cfg.lora_alpha, - lora_dropout=self.llm_policy_cfg.lora_dropout, - target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Qwen-specific - ) - self.llm_policy_model = get_peft_model(self.llm_policy_model, lora_config) - self.llm_policy_model.print_trainable_parameters() - - self.llm_policy_model.to(self._cfg.device) - self.llm_policy_model.train() - - # ====================================================================== - # 3. 
[PRIORZERO-NEW] Initialize LLM Optimizer - # ====================================================================== - self._optimizer_llm = torch.optim.AdamW( - self.llm_policy_model.parameters(), - lr=self.llm_policy_cfg.llm_learning_rate, - weight_decay=self.llm_policy_cfg.llm_weight_decay, - betas=(0.9, 0.999), - ) - - # Optional: learning rate scheduler - self._lr_scheduler_llm = torch.optim.lr_scheduler.CosineAnnealingLR( - self._optimizer_llm, - T_max=100000, # Will be set from config - eta_min=self.llm_policy_cfg.llm_learning_rate * 0.1 - ) - - logging.info(f"✓ LLM Policy Model ({self.llm_policy_cfg.pretrain_llm_path}) initialized") - logging.info(f" - LLM learning rate: {self.llm_policy_cfg.llm_learning_rate}") - logging.info(f" - LoRA enabled: {self.llm_policy_cfg.use_lora}") - - # ====================================================================== - # 4. [PRIORZERO-NEW] Load Action Mappings - # ====================================================================== - if hasattr(self._cfg, 'action_map') and self._cfg.action_map is not None: - self.action_map = self._cfg.action_map - self.action_inv_map = {v: k for k, v in self.action_map.items()} - logging.info(f"✓ Action mappings loaded ({len(self.action_map)} actions)") - else: - logging.warning("⚠ Action mappings not found in config. Will use index-based actions.") - # Fallback: create dummy mappings - action_space_size = self._cfg.model.action_space_size - self.action_inv_map = {i: f"action_{i}" for i in range(action_space_size)} - self.action_map = {v: k for k, v in self.action_inv_map.items()} - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - [PRIORZERO-MODIFIED] - Dual-model training: UniZero world model + LLM policy. - - Training process: - 1. Train UniZero world model with standard losses (value, policy, reward, latent) - 2. Train LLM with SFT (supervised by MCTS policies) - 3. Optionally train LLM with RFT (reinforced by environment rewards) - 4. 
Joint optimization with combined loss - - Args: - data: Tuple containing (current_batch, target_batch, train_iter, game_segments) - - Returns: - log_dict: Dictionary of training metrics - """ - import logging - self._learn_model.train() - self.llm_policy_model.train() - - # Unpack data - # NOTE: game_segments is our custom GameSegment with mcts_policy_segment - # [FIX] Handle both 3-element (from buffer) and 4-element (with explicit train_iter) formats - if len(data) == 4: - # Format: [current_batch, target_batch, train_iter, game_segments] - # This is when learner explicitly adds train_iter - current_batch, target_batch, train_iter, game_segments = data - elif len(data) == 3: - # Format: [current_batch, target_batch, game_segments] - # This is the standard format from PriorZeroGameBuffer.sample() - current_batch, target_batch, game_segments = data - train_iter = self._train_iteration # Get from instance variable - import logging - logger = logging.getLogger(__name__) - logger.debug( - f"[PRIORZERO] Using 3-element format. 
game_segments: " - f"{type(game_segments)}, count: {len(game_segments) if game_segments else 0}" - ) - else: - raise ValueError(f"Unexpected data format: expected 3 or 4 elements, got {len(data)}") + self._target_model.train() - # ============================================================================== - # Part 1: UniZero World Model Training (Full Implementation) - # ============================================================================== + current_batch, target_batch, train_iter = data - # Unpack batches - (obs_batch_ori, action_batch, mask_batch, batch_index_tensor, - weights, make_time) = current_batch[:6] + # CoT reuse optimization: unpack cot_prefix_list (12 elements total) + obs_batch_ori, action_batch, target_action_batch, mask_batch, batch_index_tensor, weights, make_time, timestep_batch, raw_obs_list, history_obs_list, llm_prior_per_tok_list, cot_prefix_list, llm_action_list = current_batch target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() + timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze( + -1).long() - # Handle optional timestep - if len(current_batch) > 6: - timestep_batch = current_batch[6] - else: - timestep_batch = None - - # Convert to tensors and move to device data_list = [mask_batch, target_reward, target_value, target_policy, weights] - (mask_batch, target_reward, target_value, - target_policy, weights) = to_torch_float_tensor(data_list, self._cfg.device) + (mask_batch, target_reward, target_value, target_policy, weights) = to_torch_float_tensor(data_list, self._cfg.device) - # Reshape targets batch_size = self._cfg.batch_size target_reward = target_reward.view(batch_size, -1) target_value = target_value.view(batch_size, -1) - # Apply scalar transform (for value and reward) - # [FIX] Use scalar_transform function (not 
self.scalar_transform) - # scalar_transform is a standalone function imported from lzero.policy transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) # Convert to categorical distribution (for distributional RL) - target_reward_categorical = phi_transform( - self.reward_support, transformed_target_reward - ) - target_value_categorical = phi_transform( - self.value_support, transformed_target_value - ) - - # Prepare batch for world model - # NOTE: This follows the exact format required by UniZero world model - # [FIX] Convert obs_batch_ori to tensor if needed - if not isinstance(obs_batch_ori, torch.Tensor): - # [DEBUG] Check obs_batch_ori shape - import logging - logger = logging.getLogger(__name__) - if isinstance(obs_batch_ori, np.ndarray): - logger.info(f"[DEBUG] obs_batch_ori type: numpy, shape: {obs_batch_ori.shape}, dtype: {obs_batch_ori.dtype}") - - # [FIX] Reshape if observations are flattened (2D instead of 3D) - # Expected: [batch_size, num_unroll_steps+1, obs_dim] (buffer includes next_obs) - # Got: [batch_size, (num_unroll_steps+1) * obs_dim] - if len(obs_batch_ori.shape) == 2: - # Infer num_unroll_steps and obs_dim - # For text: obs_dim should be max_seq_len (e.g., 512) - obs_dim = 512 # Standard max_seq_len for BERT - total_size = obs_batch_ori.shape[1] - if total_size % obs_dim == 0: - inferred_steps = total_size // obs_dim - # Simply reshape to [batch_size, inferred_steps, obs_dim] - # The truncation to match action_batch will happen later (like unizero.py line 675) - obs_batch_ori = obs_batch_ori.reshape(batch_size, inferred_steps, obs_dim) - logger.info(f"[RESHAPE] Reshaped obs_batch_ori from (batch_size, {total_size}) to {obs_batch_ori.shape}") - else: - logger.warning(f"[RESHAPE_ERROR] Cannot reshape: total_size ({total_size}) not divisible by obs_dim ({obs_dim})") - - # Check if it's an object array (inhomogeneous shapes) - if obs_batch_ori.dtype == np.object_: - 
logger.warning(f"[SHAPE_ISSUE] obs_batch_ori is object array - inhomogeneous shapes!") - logger.warning(f"[SHAPE_ISSUE] First element shape: {obs_batch_ori[0].shape if len(obs_batch_ori) > 0 else 'N/A'}") - if len(obs_batch_ori) > 1: - logger.warning(f"[SHAPE_ISSUE] Second element shape: {obs_batch_ori[1].shape}") - # Try to handle inhomogeneous array by padding/truncating - # For now, just raise a descriptive error - raise ValueError( - f"obs_batch_ori has inhomogeneous shapes. " - f"First element shape: {obs_batch_ori[0].shape}, " - f"Cannot directly convert to tensor. " - f"This suggests the replay buffer is storing observations with different sequence lengths." - ) - obs_batch_ori = torch.from_numpy(obs_batch_ori).to(self._cfg.device) - - # [FIX] Convert action_batch to tensor and handle shape correctly - if not isinstance(action_batch, torch.Tensor): - action_batch = torch.from_numpy(action_batch).to(self._cfg.device) - - if action_batch.shape[-1] == 1: - actions_processed = action_batch.squeeze(-1).long() - elif len(action_batch.shape) == 1: - actions_processed = action_batch.long() - else: - actions_processed = action_batch.long() - - if timestep_batch is not None: - # Convert timestep_batch to tensor if needed - if not isinstance(timestep_batch, torch.Tensor): - timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device) - - # Handle timestep_batch shape - if timestep_batch.shape[-1] == 1: - timestep_processed = timestep_batch.squeeze(-1).long() - elif len(timestep_batch.shape) == 1: - timestep_processed = timestep_batch.long() - else: - timestep_processed = timestep_batch.long() - - batch_for_gpt = { - 'observations': obs_batch_ori, - 'actions': actions_processed, - 'timestep': timestep_processed, - 'rewards': target_reward_categorical[:, :-1], - 'target_value': target_value_categorical[:, :-1], - 'target_policy': target_policy[:, :-1], - } - else: - batch_for_gpt = { - 'observations': obs_batch_ori, - 'actions': actions_processed, - 'rewards': 
target_reward_categorical[:, :-1], - 'target_value': target_value_categorical[:, :-1], - 'target_policy': target_policy[:, :-1], - } - - # [FIX] Following unizero.py lines 673-675 exactly: - # Convert mask_batch to boolean, then truncate to align with observations/rewards - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. Shape: (B, T) - - # [DEBUG] Log shapes before truncation - logger.info(f"[SHAPE_DEBUG] Before truncation: obs={batch_for_gpt['observations'].shape}, " - f"mask_padding={batch_for_gpt['mask_padding'].shape}, " - f"actions={batch_for_gpt['actions'].shape}") - - # [CRITICAL] Truncate observations to align with rewards/actions - # - observations from buffer include next_obs → shape (B, T+1, obs_dim) - # - mask_padding is already (B, T) from buffer - DO NOT truncate again! - # - After target processing: rewards[:, :-1] → (B, T-1) - # - So only observations need truncation - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # Shape: (B, T-1, obs_dim) - - # [FIX] Check if mask_padding needs truncation based on actual shape - if batch_for_gpt['mask_padding'].shape[1] > batch_for_gpt['observations'].shape[1]: - logger.warning(f"[SHAPE_FIX] Truncating mask_padding from {batch_for_gpt['mask_padding'].shape} to match obs") - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] - - logger.info(f"[SHAPE_DEBUG] After truncation: obs={batch_for_gpt['observations'].shape}, " - f"mask_padding={batch_for_gpt['mask_padding'].shape}") - - # [FIX] Add missing 'ends' field (following unizero.py line 676) - # 'ends' marks terminal states in the trajectory (0 = not terminal) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + batch_for_gpt = { + 'actions': action_batch.squeeze(-1), + 'timestep': timestep_batch.squeeze(-1), + 'rewards': target_reward_categorical[:, :-1], + 
'target_value': target_value_categorical[:, :-1], + 'target_policy': target_policy[:, :-1], + } + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['mask_padding'] = mask_batch == 1.0 + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) - - # [FIX] Add 'scalar_target_value' field for priority calculation (following unizero.py line 681) batch_for_gpt['scalar_target_value'] = target_value - # [FIX] Log shapes for debugging - import logging - logger = logging.getLogger(__name__) - logger.info(f"[BATCH_SHAPES] obs: {batch_for_gpt['observations'].shape}, actions: {batch_for_gpt['actions'].shape}, rewards: {batch_for_gpt['rewards'].shape}, mask_padding: {batch_for_gpt['mask_padding'].shape}") - - # Compute world model loss - wm_losses = self._learn_model.world_model.compute_loss( + wm_losses, pred_values = self._learn_model.world_model.compute_loss( batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, ) - # Weighted world model loss (for prioritized experience replay) wm_total_loss = (weights * wm_losses.loss_total).mean() - - # ============================================================================== - # Part 2: [PRIORZERO-NEW] LLM Policy Training (SFT + RFT) - # ============================================================================== - - llm_sft_loss = torch.tensor(0.0, 
device=self._cfg.device) - llm_rft_loss = torch.tensor(0.0, device=self._cfg.device) - num_sft_samples = 0 - num_rft_samples = 0 - # [FIX] Only perform LLM training if game_segments available - # [DEBUG] Always log game_segments status - logger = logging.getLogger(__name__) - logger.info(f"[LLM Training] game_segments type: {type(game_segments)}, " - f"is None: {game_segments is None}, " - f"len: {len(game_segments) if game_segments is not None else 'N/A'}") - - # [DEBUG] Check first segment's data - if game_segments is not None and len(game_segments) > 0: - seg0 = game_segments[0] - logger.info(f"[LLM Training] First segment stats: " - f"mcts_policies={len(seg0.mcts_policy_segment) if hasattr(seg0, 'mcts_policy_segment') else 0}, " - f"raw_obs={len([x for x in (seg0.raw_obs_segment if hasattr(seg0, 'raw_obs_segment') else []) if x is not None])}/{len(seg0.raw_obs_segment) if hasattr(seg0, 'raw_obs_segment') else 0}, " - f"actions={len(seg0.action_segment) if hasattr(seg0, 'action_segment') else 0}") - - if game_segments is not None and len(game_segments) > 0: - # Collect training data from game segments - sft_prompts = [] - sft_targets = [] - rft_prompts = [] - rft_rewards = [] - - # [DEBUG] Log segment information - logger.info(f"[LLM Training] Processing {len(game_segments)} game segments") - - for seg_idx, segment in enumerate(game_segments): - # [FIX] Use action_segment length, not obs_segment - # obs_segment includes frame_stack + unroll_steps, while - # mcts_policy_segment only has entries for actual actions taken - segment_length = len(segment.action_segment) - - # [FIX] Ensure mcts_policy_segment has the same length - # It might be a list or numpy array depending on whether game_segment_to_array() was called - mcts_policy_length = len(segment.mcts_policy_segment) if hasattr(segment, 'mcts_policy_segment') else 0 - - # [DEBUG] Log segment lengths for debugging - if self._cfg.get('debug_segment_processing', False): - obs_len = len(segment.obs_segment) if 
hasattr(segment, 'obs_segment') else 0 - raw_obs_len = len(segment.raw_obs_segment) if hasattr(segment, 'raw_obs_segment') else 0 - logging.info( - f"[Segment {seg_idx}] action_len={segment_length}, " - f"mcts_policy_len={mcts_policy_length}, obs_len={obs_len}, raw_obs_len={raw_obs_len}" - ) - - # [SAFETY] Use the minimum of the two lengths to avoid IndexError - max_index = min(segment_length, mcts_policy_length) - - if max_index == 0: - if self._cfg.get('debug_segment_processing', False): - logging.warning(f"[Segment {seg_idx}] Empty segment, skipping") - continue # Skip empty segments - - for i in range(max_index): - # [FIX] Safe access to mcts_policy_segment with bounds check - try: - mcts_policy = segment.mcts_policy_segment[i] - except (IndexError, KeyError, TypeError) as e: - # Log detailed error information for debugging - if self._cfg.get('debug_segment_processing', False): - logging.error( - f"[Segment {seg_idx}, Index {i}] Failed to access mcts_policy_segment: {e}\n" - f" segment_length={segment_length}, mcts_policy_length={mcts_policy_length}\n" - f" mcts_policy_segment type: {type(segment.mcts_policy_segment)}" - ) - continue - - # Skip if no MCTS policy available - if mcts_policy is None: - continue - - # [FIX] Use raw_obs_segment for text observations - # PriorZero's GameSegment stores raw text in raw_obs_segment - raw_obs_text = None - if hasattr(segment, 'raw_obs_segment') and i < len(segment.raw_obs_segment): - raw_obs_text = segment.raw_obs_segment[i] - elif i < len(segment.obs_segment): - # Fallback to obs_segment if raw_obs_segment not available - raw_obs_text = str(segment.obs_segment[i]) - - # Skip if raw_obs_text is None - if raw_obs_text is None: - continue - - # Build history context - history = [] - for j in range(max(0, i - self.llm_policy_cfg.history_length), i): - # [FIX] Use raw_obs_segment for history as well - obs_text = None - if hasattr(segment, 'raw_obs_segment') and j < len(segment.raw_obs_segment): - obs_text = 
segment.raw_obs_segment[j] - elif j < len(segment.obs_segment): - obs_text = str(segment.obs_segment[j]) - - if obs_text is not None and j < len(segment.action_segment): - history.append(( - obs_text, - self.action_inv_map.get(segment.action_segment[j], f"action_{segment.action_segment[j]}"), - float(segment.reward_segment[j]) if j < len(segment.reward_segment) else 0.0 - )) - - # Build prompt - instruction = build_llm_prompt( - current_obs=raw_obs_text, - history=history, - use_cot=self.llm_policy_cfg.use_cot - ) - - # Apply chat template - prompt = self.llm_tokenizer.apply_chat_template( - [{"role": "user", "content": instruction}], - tokenize=False, - add_generation_prompt=True - ) - - # ============================================================ - # SFT: Supervised Fine-Tuning with MCTS Policy - # ============================================================ - if self.llm_policy_cfg.sft_target == 'mcts_policy': - # [FIX] Use the mcts_policy we already safely retrieved above - # Don't access segment.mcts_policy_segment[i] again to avoid IndexError - mcts_policy_vec = mcts_policy - - # Convert MCTS policy to ranked action text - target_text = format_mcts_policy_to_text( - mcts_policy_vec, - self.action_inv_map, - top_k=5 - ) - - sft_prompts.append(prompt) - sft_targets.append(target_text) - num_sft_samples += 1 - - # ============================================================ - # RFT: Reinforcement Fine-Tuning with Environment Reward - # ============================================================ - if self.llm_policy_cfg.enable_rft and i < len(segment.reward_segment): - env_reward = float(segment.reward_segment[i]) - - # TODO - # Only use transitions with non-zero reward for RFT - if abs(env_reward) > 1e-9: - rft_prompts.append(prompt) - rft_rewards.append(env_reward) - num_rft_samples += 1 - - # ============================================================ - # Train LLM with SFT (with gradient accumulation for memory efficiency) - # 
============================================================ - # num_sft_samples=0 # TODO - if num_sft_samples > 0: - # [PRIORZERO-OOM-FIX] Use micro-batching with gradient accumulation - micro_batch_size = self.llm_policy_cfg.llm_micro_batch_size - num_micro_batches = (num_sft_samples + micro_batch_size - 1) // micro_batch_size - accumulation_steps = self.llm_policy_cfg.llm_gradient_accumulation_steps - - # Prepare full texts (prompt + target + eos) - full_texts = [ - p + t + self.llm_tokenizer.eos_token - for p, t in zip(sft_prompts, sft_targets) - ] - - # Process in micro-batches - accumulated_sft_loss = 0.0 - for micro_batch_idx in range(num_micro_batches): - start_idx = micro_batch_idx * micro_batch_size - end_idx = min((micro_batch_idx + 1) * micro_batch_size, num_sft_samples) - - # Get micro-batch - micro_batch_texts = full_texts[start_idx:end_idx] - micro_batch_prompts = sft_prompts[start_idx:end_idx] - - # Tokenize micro-batch - inputs = self.llm_tokenizer( - micro_batch_texts, - padding=True, - truncation=True, - max_length=self.llm_policy_cfg.prompt_max_len, - return_tensors="pt" - ).to(self._cfg.device) - - # Create labels (mask prompt tokens to only compute loss on target) - labels = inputs.input_ids.clone() - labels[labels == self.llm_tokenizer.pad_token_id] = -100 - - # Mask prompt tokens - for i, prompt in enumerate(micro_batch_prompts): - prompt_tokens = self.llm_tokenizer.encode(prompt, add_special_tokens=False) - prompt_len = len(prompt_tokens) - labels[i, :prompt_len] = -100 - - # Forward pass - llm_outputs = self.llm_policy_model( - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask, - labels=labels - ) - - # Scale loss by number of accumulation steps (for correct gradient magnitude) - micro_batch_loss = llm_outputs.loss / accumulation_steps - accumulated_sft_loss += micro_batch_loss.item() - - # Backward pass (accumulate gradients) - micro_batch_loss.backward() - - # Free memory - del inputs, labels, llm_outputs - 
torch.cuda.empty_cache() - - # Average loss for logging - llm_sft_loss = torch.tensor(accumulated_sft_loss, device=self._cfg.device) - - # ============================================================ - # Train LLM with RFT (Policy Gradient with gradient accumulation) - # ============================================================ - if num_rft_samples > 0 and self.llm_policy_cfg.enable_rft: - # [PRIORZERO-OOM-FIX] Use micro-batching with gradient accumulation - micro_batch_size = self.llm_policy_cfg.llm_micro_batch_size - num_micro_batches = (num_rft_samples + micro_batch_size - 1) // micro_batch_size - accumulation_steps = self.llm_policy_cfg.llm_gradient_accumulation_steps - - # Process in micro-batches - accumulated_rft_loss = 0.0 - for micro_batch_idx in range(num_micro_batches): - start_idx = micro_batch_idx * micro_batch_size - end_idx = min((micro_batch_idx + 1) * micro_batch_size, num_rft_samples) - - # Get micro-batch - micro_batch_prompts = rft_prompts[start_idx:end_idx] - micro_batch_rewards = rft_rewards[start_idx:end_idx] - - # Tokenize prompts - inputs = self.llm_tokenizer( - micro_batch_prompts, - padding=True, - truncation=True, - max_length=self.llm_policy_cfg.prompt_max_len, - return_tensors="pt" - ).to(self._cfg.device) - - # [FIX] Forward pass WITH gradient tracking (remove no_grad) - outputs = self.llm_policy_model( - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask - ) - - # Compute policy gradient loss (REINFORCE) - # Loss = -reward * log_prob(action) - logits = outputs.logits - log_probs = F.log_softmax(logits, dim=-1) - - # Get log probability of actual tokens - shifted_log_probs = log_probs[:, :-1, :].contiguous() - shifted_labels = inputs.input_ids[:, 1:].contiguous() - - # Gather log probs of actual tokens - token_log_probs = shifted_log_probs.gather( - dim=-1, - index=shifted_labels.unsqueeze(-1) - ).squeeze(-1) - - # Mask padding tokens - mask = (shifted_labels != self.llm_tokenizer.pad_token_id).float() - 
token_log_probs = token_log_probs * mask - - # Sum log probs per sequence - sequence_log_probs = token_log_probs.sum(dim=-1) / (mask.sum(dim=-1) + 1e-8) - - # Compute REINFORCE loss for micro-batch - rewards_tensor = torch.tensor( - micro_batch_rewards, - device=self._cfg.device, - dtype=torch.float32 - ) - - # Normalize rewards within micro-batch (important for stable training) - if len(micro_batch_rewards) > 1: - rewards_tensor = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8) - - micro_batch_rft_loss = -(rewards_tensor * sequence_log_probs).mean() / accumulation_steps - accumulated_rft_loss += micro_batch_rft_loss.item() - - # Backward pass (accumulate gradients) - micro_batch_rft_loss.backward() - - # Free memory - del inputs, outputs, logits, log_probs, rewards_tensor - torch.cuda.empty_cache() - - # Average loss for logging - llm_rft_loss = torch.tensor(accumulated_rft_loss, device=self._cfg.device) - - # ============================================================================== - # Part 3: Joint Optimization - # ============================================================================== - - # [PRIORZERO-OOM-FIX] Note: LLM gradients already accumulated via micro-batching above - # Only need to compute world model gradients here - - # Combine losses (for logging only - LLM loss already backpropagated) - llm_loss = ( - self.llm_policy_cfg.llm_loss_weight * llm_sft_loss + - self.llm_policy_cfg.rft_loss_weight * llm_rft_loss - ) - total_loss = wm_total_loss + llm_loss # For logging - - # Zero world model gradients only (LLM gradients already accumulated) self._optimizer_world_model.zero_grad() - - # Backward pass for world model only wm_total_loss.backward() - - # Gradient clipping for both models wm_grad_norm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value ) - llm_grad_norm = torch.nn.utils.clip_grad_norm_( - self.llm_policy_model.parameters(), - self._cfg.grad_clip_value - ) 
- - # Optimizer step for both models + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) self._optimizer_world_model.step() - self._optimizer_llm.step() # Apply accumulated LLM gradients - - # Zero LLM gradients after step (ready for next iteration) - self._optimizer_llm.zero_grad() - - # Learning rate scheduler step (optional) - if self._lr_scheduler_llm is not None: - self._lr_scheduler_llm.step() - - # Update target model (soft update) self._target_model.update(self._learn_model.state_dict()) - # ============================================================================== - # Part 4: Logging (Aligned with UniZero) - # ============================================================================== - - # Extract intermediate losses from world model (like UniZero) intermediate_losses = wm_losses.intermediate_losses obs_loss = intermediate_losses.get('loss_obs', torch.tensor(0.0)) reward_loss = intermediate_losses.get('loss_rewards', torch.tensor(0.0)) @@ -952,15 +115,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in middle_step_losses = intermediate_losses.get('middle_step_losses', {}) last_step_losses = intermediate_losses.get('last_step_losses', {}) - # Analysis metrics (dormant ratio, weight magnitude, etc.) 
- dormant_ratio_encoder = intermediate_losses.get('dormant_ratio_encoder', 0.0) - dormant_ratio_transformer = intermediate_losses.get('dormant_ratio_transformer', 0.0) - dormant_ratio_head = intermediate_losses.get('dormant_ratio_head', 0.0) - avg_weight_mag_encoder = intermediate_losses.get('avg_weight_mag_encoder', 0.0) - avg_weight_mag_transformer = intermediate_losses.get('avg_weight_mag_transformer', 0.0) - avg_weight_mag_head = intermediate_losses.get('avg_weight_mag_head', 0.0) - e_rank_last_linear = intermediate_losses.get('e_rank_last_linear', 0.0) - e_rank_sim_norm = intermediate_losses.get('e_rank_sim_norm', 0.0) latent_state_l2_norms = intermediate_losses.get('latent_state_l2_norms', torch.tensor(0.0)) latent_action_l2_norms = intermediate_losses.get('latent_action_l2_norms', 0.0) @@ -989,16 +143,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Build comprehensive log dict (aligned with UniZero) log_dict = { # ============ Core Losses ============ - 'weighted_total_loss': wm_total_loss.item(), - 'obs_loss': obs_loss.item() if torch.is_tensor(obs_loss) else obs_loss, - 'reward_loss': reward_loss.item() if torch.is_tensor(reward_loss) else reward_loss, - 'policy_loss': policy_loss.item() if torch.is_tensor(policy_loss) else policy_loss, - 'value_loss': value_loss.item() if torch.is_tensor(value_loss) else value_loss, - 'latent_recon_loss': latent_recon_loss.item() if torch.is_tensor(latent_recon_loss) else latent_recon_loss, - 'perceptual_loss': perceptual_loss.item() if torch.is_tensor(perceptual_loss) else perceptual_loss, - 'orig_policy_loss': orig_policy_loss.item() if torch.is_tensor(orig_policy_loss) else orig_policy_loss, - 'policy_entropy': policy_entropy.item() if torch.is_tensor(policy_entropy) else policy_entropy, - 'target_policy_entropy': average_target_policy_entropy.item(), + 'wm_total_loss': wm_total_loss.item(), + 'wm_obs_loss': obs_loss.item() if torch.is_tensor(obs_loss) else obs_loss, + 
'wm_reward_loss': reward_loss.item() if torch.is_tensor(reward_loss) else reward_loss, + 'wm_policy_loss': policy_loss.item() if torch.is_tensor(policy_loss) else policy_loss, + 'wm_value_loss': value_loss.item() if torch.is_tensor(value_loss) else value_loss, + 'wm_latent_recon_loss': latent_recon_loss.item() if torch.is_tensor(latent_recon_loss) else latent_recon_loss, + 'wm_perceptual_loss': perceptual_loss.item() if torch.is_tensor(perceptual_loss) else perceptual_loss, + 'wm_orig_policy_loss': orig_policy_loss.item() if torch.is_tensor(orig_policy_loss) else orig_policy_loss, + 'wm_policy_entropy': policy_entropy.item() if torch.is_tensor(policy_entropy) else policy_entropy, + 'wm_target_policy_entropy': average_target_policy_entropy.item(), # ============ Step-wise Losses ============ @@ -1018,14 +172,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/last_step_loss_obs': last_step_losses.get('loss_obs', torch.tensor(0.0)).item() if isinstance(last_step_losses.get('loss_obs'), torch.Tensor) else 0.0, # ============ Analysis Metrics ============ - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, - 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, - 'analysis/dormant_ratio_head': dormant_ratio_head, - 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, - 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, - 'analysis/avg_weight_mag_head': avg_weight_mag_head, - 'analysis/e_rank_last_linear': e_rank_last_linear, - 'analysis/e_rank_sim_norm': e_rank_sim_norm, 'analysis/latent_state_l2_norms': latent_state_l2_norms.item() if torch.is_tensor(latent_state_l2_norms) else latent_state_l2_norms, 'analysis/latent_action_l2_norms': latent_action_l2_norms, @@ -1043,66 +189,22 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'temperature_policy': temperature_policy, # ============ Targets ============ - 'target_reward': target_reward.mean().item(), - 
'target_value': target_value.mean().item(), + 'wm_target_reward': target_reward.mean().item(), + 'wm_target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'value_priority': value_priority_np.mean().item(), 'value_priority_orig': value_priority_np, # ============ Gradient Norms ============ - 'total_grad_norm_before_clip_wm': wm_grad_norm.item(), - 'llm_grad_norm': llm_grad_norm.item(), + 'wm_grad_norm': wm_grad_norm.item(), # ============ Learning Rates ============ 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'llm_lr': self._optimizer_llm.param_groups[0]['lr'], - - # ============ [PRIORZERO] LLM-specific Metrics ============ - 'llm_sft_loss': llm_sft_loss.item(), - 'llm_rft_loss': llm_rft_loss.item(), - 'llm_total_loss': llm_loss.item(), - 'num_sft_samples': float(num_sft_samples), - 'num_rft_samples': float(num_rft_samples), - 'total_loss': total_loss.item(), } - # ============================================================================== - # [PRIORZERO-NEW] WandB Logging (if enabled) - # ============================================================================== - if self._cfg.get('use_wandb', False): - try: - import wandb - if wandb.run is not None: - # Log all metrics to WandB with hierarchical naming - wandb.log({ - # World Model Metrics - 'train/wm/total_loss': log_dict['wm_total_loss'], - 'train/wm/value_loss': log_dict['wm_value_loss'], - 'train/wm/policy_loss': log_dict['wm_policy_loss'], - 'train/wm/reward_loss': log_dict['wm_reward_loss'], - 'train/wm/grad_norm': log_dict['wm_grad_norm'], - 'train/wm/learning_rate': log_dict['wm_lr'], - - # LLM Policy Metrics - 'train/llm/sft_loss': log_dict['llm_sft_loss'], - 'train/llm/rft_loss': log_dict['llm_rft_loss'], - 'train/llm/total_loss': log_dict['llm_total_loss'], - 'train/llm/grad_norm': log_dict['llm_grad_norm'], - 'train/llm/learning_rate': 
log_dict['llm_lr'], - 'train/llm/num_sft_samples': float(log_dict['num_sft_samples']), - 'train/llm/num_rft_samples': float(log_dict['num_rft_samples']), - - # Combined Metrics - 'train/total_loss': log_dict['total_loss'], - }, step=self._train_iteration) - except Exception as e: - # Don't fail training if wandb logging fails - import logging - logging.warning(f"WandB logging failed: {e}") - return log_dict - + def _monitor_vars_learn(self) -> List[str]: """ [PRIORZERO-MODIFIED] @@ -1115,60 +217,15 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ - # ============ LLM Loss Metrics ============ - 'llm_sft_loss', # Supervised fine-tuning loss - 'llm_rft_loss', # Reinforcement fine-tuning loss - 'llm_total_loss', # Combined LLM loss - 'llm_grad_norm', # LLM gradient norm - 'llm_lr', # LLM learning rate - - # ============ LLM Training Statistics ============ - 'num_sft_samples', # Number of SFT samples in batch - 'num_rft_samples', # Number of RFT samples in batch - # ============ Combined Metrics ============ - 'total_loss', # Total loss (WM + LLM) 'wm_total_loss', # World model total loss 'wm_grad_norm', # World model gradient norm - 'wm_lr', # World model learning rate - # ============ World Model Component Losses ============ 'wm_value_loss', 'wm_policy_loss', 'wm_reward_loss', 'wm_obs_loss', - 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_transformer', - 'analysis/dormant_ratio_head', - - 'analysis/avg_weight_mag_encoder', - 'analysis/avg_weight_mag_transformer', - 'analysis/avg_weight_mag_head', - 'analysis/e_rank_last_linear', - 'analysis/e_rank_sim_norm', - - 'analysis/latent_state_l2_norms', - 'analysis/l2_norm_before', - 'analysis/l2_norm_after', - 'analysis/grad_norm_before', - 'analysis/grad_norm_after', - - 'analysis/first_step_loss_value', - 'analysis/first_step_loss_policy', - 'analysis/first_step_loss_rewards', - 'analysis/first_step_loss_obs', - - 'analysis/middle_step_loss_value', - 'analysis/middle_step_loss_policy', - 
'analysis/middle_step_loss_rewards', - 'analysis/middle_step_loss_obs', - - 'analysis/last_step_loss_value', - 'analysis/last_step_loss_policy', - 'analysis/last_step_loss_rewards', - 'analysis/last_step_loss_obs', - 'adaptive_alpha', "adaptive_target_entropy_ratio", 'alpha_loss', @@ -1179,62 +236,53 @@ def _monitor_vars_learn(self) -> List[str]: 'collect_mcts_temperature', 'cur_lr_world_model', 'cur_lr_tokenizer', - - 'weighted_total_loss', - 'obs_loss', - 'policy_loss', - 'orig_policy_loss', - 'policy_entropy', - 'latent_recon_loss', - 'target_policy_entropy', - 'reward_loss', - 'value_loss', + + 'wm_orig_policy_loss', + 'wm_policy_entropy', + 'wm_latent_recon_loss', + 'wm_target_policy_entropy', 'consistency_loss', 'value_priority', - 'target_reward', - 'target_value', + 'wm_target_reward', + 'wm_target_value', 'total_grad_norm_before_clip_wm', # tokenizer 'commitment_loss', 'reconstruction_loss', - 'perceptual_loss', - - - "logits_value_mean", - "logits_value_max", - "logits_value_min", - "logits_policy_mean", - "logits_policy_max", - "logits_policy_min", - - "temperature_value", - "temperature_reward", - "temperature_policy", - "current_policy_label_eps", - 'adaptive_alpha', - "adaptive_target_entropy_ratio", + 'wm_perceptual_loss', + + "logits_value_mean", + "logits_value_max", + "logits_value_min", + "logits_policy_mean", + "logits_policy_max", + "logits_policy_min", + + "temperature_value", + "temperature_reward", + "temperature_policy", + "current_policy_label_eps", + 'adaptive_alpha', + "adaptive_target_entropy_ratio", 'alpha_loss', "current_encoder_clip_value", - - # ==================== [新增] 添加范数和中间张量监控变量 ==================== - # 模块总范数 - 'norm/encoder/_total_norm', - 'norm/transformer/_total_norm', - 'norm/head_value/_total_norm', - 'norm/head_reward/_total_norm', - 'norm/head_policy/_total_norm', - # 中间张量 x 的统计信息 - 'norm/x_token/mean', - 'norm/x_token/std', - 'norm/x_token/max', - 'norm/x_token/min', ] - # 注意:我们不把每一层的范数都加到这里,因为数量太多会导致日志混乱。 - # 
在实践中,如果通过总范数发现问题,可以临时在TensorBoard中搜索特定层的范数, - # 或者在本地打印 `norm_log_dict` 来进行详细分析。 - # wandb等工具可以更好地处理大量的动态指标。 # ======================================================================== + def pad_to_fixed_length(self, data, target_len=55, pad_val=-1e9, dtype=torch.float32): + """ + data: List[Sequence[Number]],每个元素长度可以不一样(比如 3 或 4) + 返回: tensor, 形状 [B, target_len],多余部分全是 pad_val + """ + batch_size = len(data) + out = torch.full((batch_size, target_len), pad_val, dtype=dtype) + for i, seq in enumerate(data): + if isinstance(seq, np.ndarray): + seq = seq.tolist() + L = min(len(seq), target_len) + if L > 0: + out[i, :L] = torch.tensor(seq[:L], dtype=dtype) + return out def _forward_collect( self, @@ -1244,223 +292,181 @@ def _forward_collect( to_play: List[int] = None, epsilon: float = 0.0, ready_env_id: List[int] = None, + timestep: List = [0], **kwargs ) -> Dict[int, Dict[str, Any]]: - """ - [PRIORZERO-MODIFIED] - Forward pass for data collection with LLM-guided MCTS. - - Process: - 1. Get LLM prior outputs from kwargs - 2. Parse LLM outputs into policy priors - 3. Run world model initial inference - 4. Inject LLM priors into MCTS root node (replace policy logits) - 5. Run MCTS search with LLM-guided priors - 6. 
Return best action and statistics - - Args: - data: Stacked observations (tensor) - action_mask: Action masks for each environment - temperature: Temperature for action selection - to_play: Player IDs (for multi-agent) - epsilon: Epsilon for epsilon-greedy exploration - ready_env_id: List of ready environment IDs - **kwargs: Additional arguments, including 'llm_prior_outputs' - - Returns: - output_dict: Dictionary mapping env_id to action and search statistics - """ self._collect_model.eval() - # ====================================================================== - # [PRIORZERO-NEW] Get LLM Prior Outputs - # ====================================================================== - llm_prior_outputs = kwargs.pop('llm_prior_outputs', None) - - if llm_prior_outputs is None: - # If no LLM prior available, fall back to standard UniZero behavior + llm_prior_logprob = kwargs.pop('llm_prior_logprob', None) + valid_actions_list = kwargs.get('valid_actions_list', None) + if not any(llm_prior_logprob): logging.debug("No LLM priors provided, using standard UniZero MCTS") return super()._forward_collect( data, action_mask, temperature, to_play, epsilon, - ready_env_id=ready_env_id, **kwargs + ready_env_id=ready_env_id, timestep=timestep ) - - # ====================================================================== - # Parse LLM Outputs into Policy Priors - # ====================================================================== + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + policy_priors = [] - for output in llm_prior_outputs: - # Extract generated text - generated_text = output.outputs[0].text if hasattr(output, 'outputs') else str(output) - - # Parse into policy distribution - prior_policy = parse_llm_action_ranking( - generated_text, - self.action_map, - 
self._cfg.model.action_space_size, - fallback_to_uniform=True - ) - - # Convert to log probabilities (for compatibility with MCTS) - policy_logits = torch.log(torch.from_numpy(prior_policy) + 1e-9) - policy_priors.append(policy_logits) - - policy_priors = torch.stack(policy_priors).to(self._cfg.device) - - # ====================================================================== - # World Model Initial Inference - # ====================================================================== + for env_id in range(active_collect_env_num): + actions = valid_actions_list[env_id] + prior = [] + if len(actions) == 0: + print("When valid actions is None, the action must be 'go'") + prior.append(llm_prior_logprob[env_id]['go']) + else: + for action in actions: + prior.append(llm_prior_logprob[env_id][action]) + policy_priors.append(prior) + policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9) + with torch.no_grad(): - # Run representation network to get latent state - network_output = self._collect_model.initial_inference(data) + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - # Unpack network outputs - latent_state_roots, reward_roots, pred_values, policy_logits_roots = \ - mz_network_output_unpack(network_output) - - # [PRIORZERO-KEY] Replace policy logits with LLM priors network_output.policy_logits = policy_priors - - # Prepare for MCTS if not self._cfg.mcts_ctree: - # Python implementation (not recommended for performance) raise NotImplementedError("Python MCTS not supported for PriorZero") # ====================================================================== # MCTS Search with LLM-Guided Priors # ====================================================================== - # This is the key part where LLM priors guide the search - - # 
[FIX] Align with UniZero: construct legal_actions from action_mask - active_collect_env_num = len(ready_env_id) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] - for j in range(active_collect_env_num)] - - # Get timestep if available - timestep = kwargs.get('timestep', None) - - # [FIX] Align with UniZero: transform values and prepare data pred_values_np = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots_np = latent_state_roots.detach().cpu().numpy() - # reward_roots_np = reward_roots.detach().cpu().numpy() - policy_logits_for_mcts = policy_priors.detach().cpu().numpy().tolist() - - # [FIX] Align with UniZero: Create MCTS roots with legal_actions (not action_space_size) - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - - # [FIX] Align with UniZero: noises based on number of valid actions per environment + policy_logits = policy_priors.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() - for j in range(active_collect_env_num) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) ] + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots_np, to_play, timestep=timestep) - # [FIX] Align with UniZero: prepare roots (note reward_roots_np, not list(pred_values_np)) - roots.prepare( - self._cfg.root_noise_weight, - noises, - reward_roots, - # reward_roots_np, - policy_logits_for_mcts, - to_play if to_play is not None else [-1] * active_collect_env_num, - ) - - # Run MCTS search - MCTSCtree(self._cfg).search( - roots, - self._collect_model, - latent_state_roots_np, - reward_roots, - 
to_play if to_play is not None else [-1] * latent_state_roots_np.shape[0], - ) - - # Extract search results roots_visit_count = roots.get_distributions() roots_values = roots.get_values() - # ====================================================================== - # [PRIORZERO] Get valid_actions_list for dynamic action mapping - # ====================================================================== - valid_actions_list = kwargs.get('valid_actions_list', None) - - # ====================================================================== - # Select Actions and Prepare Output (Aligned with UniZero) - # ====================================================================== - output = {} - + batch_action = [] for i, env_id in enumerate(ready_env_id): - # [FIX] Get visit count distribution (only contains legal actions) distributions = roots_visit_count[i] value = roots_values[i] - # [FIX] Use select_action from UniZero (aligns with UniZero line 1115-1117) - # NOTE: Only legal actions possess visit counts, so action_index_in_legal_action_set - # represents the index within the legal action set, not the entire action set action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, - temperature=temperature if temperature is not None else self._collect_mcts_temperature, + temperature=self._collect_mcts_temperature, deterministic=False ) - # [FIX] Convert action_index_in_legal_action_set to the actual action in full action space - # (aligns with UniZero line 1119) legal_action_indices = np.where(action_mask[i] == 1.0)[0] action = legal_action_indices[action_index_in_legal_action_set] - # [PRIORZERO] Create dynamic action_inv_map for this specific state - # This maps action_index -> action_text using the current state's valid_actions - if valid_actions_list is not None and i < len(valid_actions_list): - dynamic_action_inv_map = { - idx: act_text - for idx, act_text in enumerate(valid_actions_list[i]) - } - else: - # Fallback to static 
mapping if valid_actions not available - dynamic_action_inv_map = self.action_inv_map - output[env_id] = { 'action': int(action), 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, 'searched_value': value, 'predicted_value': pred_values_np[i], - 'dynamic_action_inv_map': dynamic_action_inv_map, # [PRIORZERO] Include dynamic mapping + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], } - + batch_action.append(action) + self.last_batch_obs = data + self.last_batch_action = batch_action return output + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], **kwargs) -> Dict: + self._eval_model.eval() + llm_prior_logprob = kwargs.pop('llm_prior_logprob', None) + valid_actions_list = kwargs.get('valid_actions_list', None) + + if llm_prior_logprob is None or not any(llm_prior_logprob): + logging.debug("No LLM priors provided, using standard UniZero MCTS") + return super()._forward_eval( + data, action_mask, to_play=to_play, ready_env_id=ready_env_id, timestep=timestep + ) + + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + + policy_priors = [] + for env_id in range(active_eval_env_num): + actions = valid_actions_list[env_id] + prior = [] + if len(actions) == 0: + print("When valid actions is None, the action must be 'go'") + prior.append(llm_prior_logprob[env_id]['go']) + else: + for action in actions: + prior.append(llm_prior_logprob[env_id][action]) + policy_priors.append(prior) + policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9) + + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, 
policy_logits = mz_network_output_unpack(network_output) + + network_output.policy_logits = policy_priors - def _state_dict_learn(self) -> Dict[str, Any]: - """ - [PRIORZERO-MODIFIED] - Save state dict for both world model and LLM. - """ - state_dict = super()._state_dict_learn() + # if not in training, obtain the scalars of the value/reward + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_priors.detach().cpu().numpy().tolist() - # Add LLM model and optimizer - state_dict['llm_model'] = self.llm_policy_model.state_dict() - state_dict['optimizer_llm'] = self._optimizer_llm.state_dict() + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) - if self._lr_scheduler_llm is not None: - state_dict['lr_scheduler_llm'] = self._lr_scheduler_llm.state_dict() + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} - return state_dict + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - [PRIORZERO-MODIFIED] - Load state dict for both world model and LLM. 
- """ - super()._load_state_dict_learn(state_dict) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # Load LLM model and optimizer - if 'llm_model' in state_dict: - self.llm_policy_model.load_state_dict(state_dict['llm_model']) - logging.info("✓ LLM model state loaded") + # Predict the next latent state based on the selected action and policy + next_latent_state = next_latent_state_with_env[i][action] - if 'optimizer_llm' in state_dict: - self._optimizer_llm.load_state_dict(state_dict['optimizer_llm']) - logging.info("✓ LLM optimizer state loaded") + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action - if 'lr_scheduler_llm' in state_dict and self._lr_scheduler_llm is not None: - self._lr_scheduler_llm.load_state_dict(state_dict['lr_scheduler_llm']) - logging.info("✓ LLM scheduler state loaded") + return output diff --git a/zoo/jericho/priorzero/priorzero_prompts.py b/zoo/jericho/priorzero/priorzero_prompts.py deleted file mode 100644 index 4ce7ef787..000000000 --- 
a/zoo/jericho/priorzero/priorzero_prompts.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -PriorZero LLM Prompts Module - -This module provides optimized prompt templates for PriorZero's LLM policy, -based on the successful prompt structure from Open-Reasoner-Zero. - -Key Features: -- Structured reasoning with and tags -- Clear role definitions (User/Assistant paradigm) -- Explicit format examples to guide the LLM -- Game-specific context integration - -Author: PriorZero Team -Date: 2025-10-21 -""" - -from jinja2 import Template -from typing import List, Dict, Any, Optional - - -class PriorZeroPromptTemplates: - """ - Centralized prompt templates for PriorZero LLM policy. - - Prompt Structure: - 1. System instruction (role definition) - 2. Format specification ( and tags) - 3. Example format to prime the model - 4. User query with game state - 5. Start reasoning with "" tag - """ - - # ============================================================================== - # MCTS Policy Guidance Prompts - # ============================================================================== - - MCTS_POLICY_TEMPLATE = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game \ -and needs to decide the next action. The Assistant carefully analyzes the current game state, \ -considers the available actions, and recommends the best action to take. \ -The reasoning process is enclosed within tags, and the recommended action \ -is enclosed within tags. For example: \ - The player is in a dark room and needs light. The lamp is available. \ - take lamp . \ - -User: Current game state: -{{game_state}} - -Available actions: -{{valid_actions}} - -Recent history: -{{history}} - -What is the best action to take? 
-Assistant: \ -""" - - # ============================================================================== - # Supervised Fine-Tuning (SFT) Prompts - Learning from MCTS Policy - # ============================================================================== - - SFT_FROM_MCTS_TEMPLATE = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game. \ -The Assistant provides step-by-step reasoning and selects the best action based on MCTS search results. \ -The reasoning is in tags and the action is in tags. \ - -User: Game state: {{game_state}} -Available actions: {{valid_actions}} -MCTS recommended action: {{mcts_action}} -MCTS value estimate: {{mcts_value}} - -Please explain why this is the best action and then select it. -Assistant: \ -""" - - # ============================================================================== - # Reward Fine-Tuning (RFT) Prompts - Learning from Environment Rewards - # ============================================================================== - - RFT_TEMPLATE = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game \ -and wants to maximize the total reward. The Assistant analyzes the game state, considers past rewards, \ -and selects actions that lead to higher rewards. \ -The reasoning is in tags and the action is in tags. \ - -User: Current game state: -{{game_state}} - -Available actions: -{{valid_actions}} - -Recent trajectory: -{{trajectory_with_rewards}} - -Cumulative reward so far: {{cumulative_reward}} - -What action should I take to maximize future rewards? -Assistant: \ -""" - - # ============================================================================== - # Evaluation Prompts - For Testing LLM Policy - # ============================================================================== - - EVAL_TEMPLATE = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game. 
\ -The Assistant thinks carefully about the situation and provides the best action. \ -Format: reasoning action . \ - -User: {{game_state}} -Available actions: {{valid_actions}} -Assistant: \ -""" - - # ============================================================================== - # Few-Shot Learning Prompts - With Example Demonstrations - # ============================================================================== - - FEW_SHOT_TEMPLATE = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game. \ -The Assistant learns from examples and applies similar reasoning to new situations. \ - -Example 1: -User: You are in a dark room. You can't see anything. -Available actions: [go north, take lamp, light lamp] -Assistant: I need light to see. I should take the lamp first, then light it. take lamp - -Example 2: -User: You are holding a lamp. It is dark. -Available actions: [go north, light lamp, drop lamp] -Assistant: I have the lamp but it's not lit. I should light it to see. light lamp - -Now your turn: -User: {{game_state}} -Available actions: {{valid_actions}} -Assistant: \ -""" - - -class PriorZeroPromptBuilder: - """ - Builder class for constructing prompts with specific game context. - """ - - def __init__(self, tokenizer): - """ - Initialize the prompt builder. - - Args: - tokenizer: HuggingFace tokenizer with bos_token - """ - self.tokenizer = tokenizer - self.templates = PriorZeroPromptTemplates() - - def _get_bos_token(self) -> str: - """Get the beginning-of-sequence token.""" - if self.tokenizer.bos_token_id is None: - return "" - return self.tokenizer.decode([self.tokenizer.bos_token_id]) - - def build_mcts_policy_prompt( - self, - game_state: str, - valid_actions: List[str], - history: Optional[List[Dict[str, Any]]] = None, - ) -> str: - """ - Build a prompt for MCTS policy guidance. 
- - Args: - game_state: Current observation text from the game - valid_actions: List of valid action strings - history: Recent trajectory [(obs, action, reward), ...] - - Returns: - Formatted prompt string - """ - # Format valid actions as a numbered list - actions_str = "\n".join([f"{i+1}. {action}" for i, action in enumerate(valid_actions)]) - - # Format history - if history is None or len(history) == 0: - history_str = "This is the beginning of the game." - else: - history_lines = [] - for i, step in enumerate(history[-5:]): # Last 5 steps - obs = step.get('observation', 'N/A') - action = step.get('action', 'N/A') - reward = step.get('reward', 0) - history_lines.append(f"Step {i+1}: Observation: {obs[:100]}... | Action: {action} | Reward: {reward}") - history_str = "\n".join(history_lines) - - # Render template - template = Template(self.templates.MCTS_POLICY_TEMPLATE) - return template.render( - bos_token=self._get_bos_token(), - game_state=game_state, - valid_actions=actions_str, - history=history_str, - ) - - def build_sft_prompt( - self, - game_state: str, - valid_actions: List[str], - mcts_action: str, - mcts_value: float, - ) -> str: - """ - Build a prompt for supervised fine-tuning from MCTS policy. - - Args: - game_state: Current observation text - valid_actions: List of valid action strings - mcts_action: Action recommended by MCTS - mcts_value: Value estimate from MCTS - - Returns: - Formatted prompt string - """ - actions_str = "\n".join([f"{i+1}. {action}" for i, action in enumerate(valid_actions)]) - - template = Template(self.templates.SFT_FROM_MCTS_TEMPLATE) - return template.render( - bos_token=self._get_bos_token(), - game_state=game_state, - valid_actions=actions_str, - mcts_action=mcts_action, - mcts_value=f"{mcts_value:.3f}", - ) - - def build_rft_prompt( - self, - game_state: str, - valid_actions: List[str], - trajectory: List[Dict[str, Any]], - cumulative_reward: float, - ) -> str: - """ - Build a prompt for reward fine-tuning. 
- - Args: - game_state: Current observation text - valid_actions: List of valid action strings - trajectory: Recent trajectory with rewards - cumulative_reward: Total reward accumulated - - Returns: - Formatted prompt string - """ - actions_str = "\n".join([f"{i+1}. {action}" for i, action in enumerate(valid_actions)]) - - # Format trajectory with rewards - traj_lines = [] - for i, step in enumerate(trajectory[-5:]): - action = step.get('action', 'N/A') - reward = step.get('reward', 0) - traj_lines.append(f" Step {i+1}: Action: {action} → Reward: {reward:+.2f}") - trajectory_str = "\n".join(traj_lines) - - template = Template(self.templates.RFT_TEMPLATE) - return template.render( - bos_token=self._get_bos_token(), - game_state=game_state, - valid_actions=actions_str, - trajectory_with_rewards=trajectory_str, - cumulative_reward=f"{cumulative_reward:+.2f}", - ) - - def build_eval_prompt( - self, - game_state: str, - valid_actions: List[str], - ) -> str: - """ - Build a simple prompt for evaluation. - - Args: - game_state: Current observation text - valid_actions: List of valid action strings - - Returns: - Formatted prompt string - """ - actions_str = "\n".join([f"{i+1}. {action}" for i, action in enumerate(valid_actions)]) - - template = Template(self.templates.EVAL_TEMPLATE) - return template.render( - bos_token=self._get_bos_token(), - game_state=game_state, - valid_actions=actions_str, - ) - - -# ============================================================================== -# Utility Functions -# ============================================================================== - -def extract_action_from_llm_output(llm_output: str, valid_actions: List[str]) -> Optional[str]: - """ - Extract the action from LLM output with tags. 
- - Args: - llm_output: Full LLM response including and tags - valid_actions: List of valid action strings to match against - - Returns: - Extracted action string, or None if extraction fails - - Example: - >>> output = "I need light take lamp" - >>> extract_action_from_llm_output(output, ["take lamp", "go north"]) - "take lamp" - """ - import re - - # Pattern to extract content between and - pattern = r"\s*(.*?)\s*" - match = re.search(pattern, llm_output, re.DOTALL | re.IGNORECASE) - - if not match: - return None - - extracted = match.group(1).strip() - - # Try exact match first - if extracted in valid_actions: - return extracted - - # Try case-insensitive match - extracted_lower = extracted.lower() - for action in valid_actions: - if action.lower() == extracted_lower: - return action - - # Try fuzzy match (substring) - for action in valid_actions: - if extracted_lower in action.lower() or action.lower() in extracted_lower: - return action - - return None - - -# ============================================================================== -# Example Usage -# ============================================================================== - -if __name__ == "__main__": - print("="*80) - print("PriorZero Prompt Templates - Example Usage") - print("="*80) - - # Mock tokenizer - class MockTokenizer: - bos_token_id = 1 - def decode(self, ids): - return "" - - tokenizer = MockTokenizer() - builder = PriorZeroPromptBuilder(tokenizer) - - # Example game state - game_state = "You are standing in an open field west of a white house." - valid_actions = ["go north", "go south", "go east", "open mailbox", "take mailbox"] - history = [ - {"observation": "West of House", "action": "look", "reward": 0}, - {"observation": "You see a mailbox", "action": "examine mailbox", "reward": 0}, - ] - - print("\n1. MCTS Policy Prompt:") - print("-"*80) - prompt = builder.build_mcts_policy_prompt(game_state, valid_actions, history) - print(prompt) - - print("\n2. 
SFT Prompt:") - print("-"*80) - sft_prompt = builder.build_sft_prompt(game_state, valid_actions, "open mailbox", 0.75) - print(sft_prompt) - - print("\n3. RFT Prompt:") - print("-"*80) - trajectory = [ - {"action": "go east", "reward": 0}, - {"action": "open mailbox", "reward": 5}, - ] - rft_prompt = builder.build_rft_prompt(game_state, valid_actions, trajectory, 5.0) - print(rft_prompt) - - print("\n4. Action Extraction:") - print("-"*80) - llm_output = "The mailbox might contain something useful. open mailbox" - extracted = extract_action_from_llm_output(llm_output, valid_actions) - print(f"LLM Output: {llm_output}") - print(f"Extracted Action: {extracted}") - - print("\n" + "="*80) - print("✓ All prompt templates demonstrated successfully!") - print("="*80) diff --git a/zoo/jericho/priorzero/priorzero_trainer.py b/zoo/jericho/priorzero/priorzero_trainer.py new file mode 100644 index 000000000..303c9817e --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_trainer.py @@ -0,0 +1,161 @@ +from __future__ import annotations +import os +import copy +import json + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import ray +import numpy as np +from transformers import AutoTokenizer + +import ray +import torch + +import numpy as np + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass + + +def get_tokenizer(pretrain: str) -> AutoTokenizer: + tokenizer 
= AutoTokenizer.from_pretrained( + pretrain, trust_remote_code=True, padding_side="left" + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + +class PriorZeroLLMTrainer: + + def __init__( + self, + cfg, + pretrain: str, + strategy, + vllm_engine, + policy_model, # RayActorGroup(PolicyModelActor) + reference_model=None, # RayActorGroup(ReferenceModelActor) or None + exp_name: str = None, + tb_logger = None, + instance_name: str = "llm_ppo", + llm_save_freq: int = 1000, + ): + self.cfg = cfg + self.pretrain = pretrain + self.strategy = strategy + self.args = getattr(strategy, "args", None) + + self.policy_model = policy_model + self.reference_model = reference_model + self.vllm_engine = vllm_engine + self.global_step = 0 + self.llm_save_freq = llm_save_freq + + self.tokenizer = get_tokenizer(self.pretrain) + + self.init_kl_coef = float(getattr(cfg, "rft_kl_coef", 0.0)) + + self.kl_ctl = FixedKLController(self.init_kl_coef) + self.rank = self.strategy.get_rank() + self.world_size = self.strategy.world_size + + if tb_logger is not None: + from ding.utils import build_logger + self._logger, _ = build_logger( + path=f'./{exp_name}/log/{instance_name}', name=instance_name, need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger = None + self._tb_logger = None + + def train_batch(self, data, collect_env_steps) -> Dict[str, float]: + if data is None: + return {} + input_ids, attention_mask, action_mask, advantage, old_lp, log_status = data + assert len(input_ids) == len(attention_mask) == len(action_mask) == len(advantage) == len(old_lp) == len(log_status) + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "action_mask": action_mask, + "advantages": advantage, + "old_action_logprob": old_lp, + "log_status": log_status, + } + if self.reference_model is not None: + base_action_log_probs = self.reference_model.forward( + sequences = batch['input_ids'], + action_mask = 
batch['action_mask'], + attention_mask=batch['attention_mask'], + ) + batch["ref_action_log_probs"] = base_action_log_probs + else: + batch["ref_action_log_probs"] = None + + if self.strategy.args.deepspeed_enable_sleep: + self.policy_model.reload_states() + + status = self.policy_model.fit(batch, self.kl_ctl) + + if self.vllm_engine is not None: + self._broadcast_to_vllm() + + if self.strategy.args.deepspeed_enable_sleep: + self.policy_model.offload_states() + + if self._tb_logger is not None and self.strategy.is_rank_0(): + for tmp_dict in status: + for k, v in tmp_dict.items(): + if k == 'iter': + continue + self._tb_logger.add_scalar(f"learner_llm_iter/{k}", float(v), int(tmp_dict['iter'])) + self._tb_logger.add_scalar(f"learner_llm_envstep/{k}", float(v), int(collect_env_steps)) + self.global_step = max(self.global_step, int(tmp_dict['iter'])) + + if self.strategy.is_rank_0(): + if self.global_step > 0 and self.global_step % self.llm_save_freq == 0: + self.policy_model.save_model() + + def get_state(self) -> Dict[str, Any]: + kl_val = float(self.kl_ctl.value) if hasattr(self.kl_ctl, "value") else float(self.init_kl_coef) + return {"global_step": self.global_step, "kl_coef": kl_val} + + def _broadcast_to_vllm(self): + if self.strategy.args.vllm_enable_sleep: + self.vllm_engine.wake_up() + + print(f"[Rank {self.rank}]: vllm starting update weights....") + self.policy_model.broadcast_to_vllm() + print(f"[Rank {self.rank}]: vllm has updating done.") + + if self.strategy.args.vllm_enable_sleep: + self.vllm_engine.sleep() \ No newline at end of file diff --git a/zoo/jericho/priorzero/ray_utils/model.py b/zoo/jericho/priorzero/ray_utils/model.py new file mode 100644 index 000000000..6e6d41373 --- /dev/null +++ b/zoo/jericho/priorzero/ray_utils/model.py @@ -0,0 +1,354 @@ +from typing import Dict, List, Optional, Union +import os +from abc import ABC +import math +import socket + +import ray +import torch +import deepspeed +import torch.distributed +from torch.optim 
"""Ray actor wrappers around the PriorZero policy / reference models.

``ReferenceModel`` serves frozen per-token log-probs; ``PolicyModel`` owns
the trainable actor plus its DeepSpeed engine; ``ActorPPOTrainer`` carries
the PPO update machinery and the DeepSpeed -> vLLM weight-sync paths
(stateless NCCL group, Ray collective, or CUDA IPC on a single GPU).
"""
from typing import Dict, List, Optional, Union
import os
from abc import ABC
import math
import socket

import ray
import torch
import deepspeed
import torch.distributed
from torch.optim import Optimizer
from transformers.trainer import get_scheduler

from ..vllm_engine import get_bundle_indices, get_physical_gpu_id
from openrlhf.utils.distributed_util import stateless_init_process_group, torch_dist_barrier_and_cuda_sync
from openrlhf.trainer.ray.launcher import BaseModelActor
from openrlhf.models import Actor, PolicyLoss
from openrlhf.utils.deepspeed import DeepspeedStrategy
from openrlhf.utils import get_tokenizer
from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states


@ray.remote(num_gpus=1)
class ReferenceModel(BaseModelActor):
    """Frozen reference policy hosted on one GPU; used for the KL penalty."""

    def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
        """Load the pretrained model under an eval-only DeepSpeed config."""
        self._setup_distributed(strategy)
        ref_actor = Actor(
            pretrain,
            attn_implementation=strategy.args.attn_implementation,
            bf16=strategy.args.bf16,
            ds_config=strategy.get_ds_eval_config(offload=False),
            temperature=strategy.args.temperature,
        )
        strategy.print(ref_actor)

        self.model = self.strategy.prepare(ref_actor, is_rlhf=True)
        self.model.eval()

    def forward(
        self,
        sequences: torch.LongTensor,
        action_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_output=False,
        packed_seq_lens: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Return per-token log-probs for ``sequences`` (moved to CPU)."""
        dev = torch.cuda.current_device()
        with torch.no_grad():
            log_probs = self.model(
                sequences.to(dev),
                action_mask.to(dev),
                attention_mask.to(dev),
                ring_attn_group=self.strategy.ring_attn_group,
                packed_seq_lens=packed_seq_lens,
            )
        return log_probs.to("cpu")


class ActorPPOTrainer(ABC):
    def __init__(
        self,
        strategy,
        actor: Actor,
        ema_model: Actor,
        actor_optim: Optimizer,
        actor_scheduler,
        ema_beta: float = 0.992,
        micro_train_batch_size: int = 8,
        eps_clip: float = 0.2,
        tokenizer=None,
        vllm_engines: List = None,
        **kwargs,
    ):
        """PPOTrainer for ray.

        Args:
            vllm_engines (List, optional): vllm engines for text generation,
                if not specified, generate text by actor model directly.
                Defaults to None.
        """
        self.strategy = strategy
        self.args = strategy.args
        self.tokenizer = tokenizer
        self.generate_kwargs = kwargs
        self.micro_train_batch_size = micro_train_batch_size
        self.ema_beta = ema_beta

        self.actor = actor
        self.ema_model = ema_model
        self.actor_optim = actor_optim
        self.actor_scheduler = actor_scheduler
        self.vllm_engines = vllm_engines

        self.actor_loss_fn = PolicyLoss(
            clip_eps_low=eps_clip,
            clip_eps_high=eps_clip,
        )

        # Weight-sync transport selection: with a single trainer GPU and NCCL
        # we can share weights through CUDA IPC instead of a process group.
        sync_backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl")
        self.use_cuda_ipc = sync_backend == "nccl" and self.args.policy_model_num_gpus == 1

        # Otherwise, build a torch group containing deepspeed rank 0 plus every
        # vLLM worker rank, used to broadcast weights after each training stage.
        # Layout for 3 engines x 4 GPUs:
        #   [ 0, 1..4, 5..8, 9..12 ] = | ds rank 0 | engine-0 | engine-1 | engine-2 |
        # ZeRO-1/2: broadcast from rank 0.  ZeRO-3: allgather to rank 0 first.
        if self.vllm_engines is not None and not self.use_cuda_ipc and torch.distributed.get_rank() == 0:
            master_address = ray._private.services.get_node_ip_address()
            # Grab a free port for the rendezvous.
            with socket.socket() as sock:
                sock.bind(("", 0))
                master_port = sock.getsockname()[1]

            vllm_num_engines = self.strategy.args.vllm_num_engines
            vllm_tensor_parallel_size = self.strategy.args.vllm_tensor_parallel_size
            world_size = vllm_num_engines * vllm_tensor_parallel_size + 1

            use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
            group_name = "openrlhf"
            init_refs = [
                engine.init_process_group.remote(
                    master_address,
                    master_port,
                    i * vllm_tensor_parallel_size + 1,
                    world_size,
                    group_name,
                    backend=sync_backend,
                    use_ray=use_ray,
                )
                for i, engine in enumerate(self.vllm_engines)
            ]
            if use_ray:
                import ray.util.collective as collective

                collective.init_collective_group(
                    world_size=world_size, rank=0, backend=sync_backend, group_name=group_name
                )
                self._model_update_group = group_name
            else:
                self._model_update_group = stateless_init_process_group(
                    master_address, master_port, 0, world_size, torch.cuda.current_device()
                )

            ray.get(init_refs)

        torch_dist_barrier_and_cuda_sync()

    def ppo_train(self, kl_ctl: float):
        # NOTE(review): stub — the PPO optimization loop is not implemented
        # here, yet PolicyModel.fit() calls it and then touches
        # self.replay_buffer, which is never created in this class. Confirm
        # against the intended subclass/implementation.
        pass

    def training_step(self, experience, kl_ctl: float, step: int) -> Dict[str, float]:
        # NOTE(review): stub — see ppo_train above.
        pass

    def _broadcast_to_vllm(self):
        """Mirror the trained actor weights into every vLLM engine.

        One parameter at a time: ZeRO-3 shards are gathered first, then sent
        either over the update process group or via CUDA IPC handles.
        """
        use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
        cache_reset_refs = []
        if use_prefix_cache and torch.distributed.get_rank() == 0:
            # Weights change, so any cached prefixes are stale.
            for engine in self.vllm_engines:
                cache_reset_refs.append(engine.reset_prefix_cache.remote())

        torch.cuda.empty_cache()
        model = self.actor.model.module
        count, num_params = 0, len(list(model.named_parameters()))

        def _broadcast_param(param, count, num_params):
            # NOTE(review): reconstructed nesting — rank 0 drives both the
            # engine-side update RPCs and the group broadcast; confirm
            # against the upstream OpenRLHF implementation.
            use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
            if torch.distributed.get_rank() == 0:
                shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
                refs = [
                    engine.update_weight.remote(
                        name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
                    )
                    for engine in self.vllm_engines
                ]

                if use_ray:
                    import ray.util.collective as collective

                    collective.broadcast(param.data, 0, group_name=self._model_update_group)
                else:
                    self._model_update_group.broadcast(
                        param.data, src=0, stream=torch.cuda.current_stream()
                    )
                ray.get(refs)

        def _handle_cuda_ipc(param, count, num_params):
            from torch.multiprocessing.reductions import reduce_tensor

            # Clone so the IPC handle refers to a stable buffer, then gather
            # one handle per physical GPU onto rank 0.
            weight = param.data.clone()
            ipc_handle = {get_physical_gpu_id(): reduce_tensor(weight)}
            ipc_handle_list = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

            if torch.distributed.get_rank() == 0:
                ipc_handles = {}
                for d in ipc_handle_list:
                    ipc_handles.update(d)

                shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
                refs = [
                    engine.update_weight_cuda_ipc.remote(
                        name,
                        dtype=param.dtype,
                        shape=shape,
                        ipc_handles=ipc_handles,
                        empty_cache=count == num_params,
                    )
                    for engine in self.vllm_engines
                ]
                ray.get(refs)
            torch_dist_barrier_and_cuda_sync()

        for name, param in model.named_parameters():
            count += 1  # empty_cache is requested only on the last parameter

            if not self.use_cuda_ipc:
                # ZeRO-3: gather the sharded parameter before broadcasting.
                if self.strategy.args.ds_tensor_parallel_size > 1:
                    with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True):
                        _broadcast_param(param, count, num_params)
                else:
                    with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
                        _broadcast_param(param, count, num_params)
            else:
                if self.strategy.args.ds_tensor_parallel_size > 1:
                    with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True):
                        _handle_cuda_ipc(param, count, num_params)
                else:
                    with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
                        _handle_cuda_ipc(param, count, num_params)

        if cache_reset_refs:
            ray.get(cache_reset_refs)
        torch.cuda.empty_cache()
        torch_dist_barrier_and_cuda_sync()


@ray.remote(num_gpus=1)
class PolicyModel(BaseModelActor):
    """Trainable policy actor: owns the DeepSpeed engine and the PPO trainer."""

    def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_steps=None, vllm_engines=None):
        args = strategy.args
        self.vllm_engines = vllm_engines
        self.max_steps = max_steps

        if getattr(args, "vllm_num_engines", 0) > 0:
            # To prevent hanging during NCCL synchronization of weights between DeepSpeed and vLLM.
            # see https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445
            if getattr(args, "vllm_sync_backend", "nccl") == "nccl":
                os.environ["NCCL_CUMEM_ENABLE"] = "0"

        self._setup_distributed(strategy)

        actor = Actor(
            pretrain,
            attn_implementation=strategy.args.attn_implementation,
            bf16=strategy.args.bf16,
            ds_config=strategy.get_ds_train_config(is_actor=True),
            temperature=strategy.args.temperature,
        )
        strategy.print(actor)

        # Tokenizer shares the left-padding convention used at generation time.
        self.tokenizer = get_tokenizer(pretrain, actor.model, "left", strategy)

        actor_optim = strategy.create_optimizer(
            actor, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.weight_decay
        )

        # actor_scheduler = get_scheduler(args.lr_scheduler, actor_optim, num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
        #                                 num_training_steps=max_steps,
        #                                 scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1},
        #                                 )
        actor_scheduler = None

        if args.gradient_checkpointing:
            actor.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )

        # Wrap (model, optim, scheduler) in a DeepSpeed engine.
        self.actor, self.actor_optim, self.actor_scheduler = strategy.prepare(
            (actor, actor_optim, actor_scheduler),
            is_rlhf=True,
        )

        # Start offloaded when sleep mode is enabled; reload_states() wakes us.
        if strategy.args.deepspeed_enable_sleep:
            offload_deepspeed_states(self.actor.model)

        self.trainer = ActorPPOTrainer(
            strategy,
            self.actor,
            ema_model=None,
            actor_optim=self.actor_optim,
            actor_scheduler=self.actor_scheduler,
            micro_train_batch_size=args.micro_train_batch_size,
            tokenizer=self.tokenizer,
            eps_clip=args.eps_clip,
            vllm_engines=self.vllm_engines,
        )

    def fit(self, kl_ctl: float = 0):
        """Train actor model with the replay buffer."""
        torch.cuda.empty_cache()
        self.actor.train()
        status = self.trainer.ppo_train(kl_ctl)
        self.trainer.replay_buffer.clear()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        return status

    def forward(
        self,
        sequences: torch.LongTensor,
        action_mask: Optional[Union[int, list[int]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        packed_seq_lens=None,
    ) -> torch.Tensor:
        """Generates actor values."""
        device = torch.cuda.current_device()
        self.actor.eval()
        with torch.no_grad():
            action_log_probs = self.actor(
                sequences.to(device),
                action_mask.to(device),
                attention_mask.to(device),
                ring_attn_group=self.strategy.ring_attn_group,
            )
        self.actor.train()  # reset model state
        return action_log_probs.to("cpu")

    def broadcast_to_vllm(self):
        self.trainer._broadcast_to_vllm()

    def append(self, experience):
        self.trainer.replay_buffer.append(experience)

    def reload_states(self):
        reload_deepspeed_states(self.actor.model)

    def offload_states(self):
        offload_deepspeed_states(self.actor.model)
import os
import shutil
from abc import ABC
from collections import defaultdict
from datetime import timedelta
from typing import List, Tuple, Union
import math

import deepspeed
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from peft import PeftModel, get_peft_model_state_dict
from torch import distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.optim import Optimizer

from utils import torch_dist_barrier_and_cuda_sync
from models.actor import Actor
from packaging import version

# A (model, optimizer) pair as accepted/returned by DeepspeedStrategy.prepare.
ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]


def get_train_ds_config(
    offload,
    adam_offload=True,
    stage=2,
    bf16=True,
    max_norm=1.0,
    zpg=8,
    grad_accum_dtype=None,
    overlap_comm=False,
    use_ds_universal_ckpt=False,
    deepcompile=False,
    tensor_parallel_size=1,
):
    """Assemble a DeepSpeed *training* config dict.

    Args:
        offload: offload parameters to CPU when True.
        adam_offload: offload the Adam optimizer state to CPU.
        stage: ZeRO optimization stage.
        bf16: enable bfloat16 training.
        max_norm: gradient clipping norm.
        zpg: ZeRO++ hierarchical partition size.
        grad_accum_dtype: dtype used for gradient accumulation (None = default).
        overlap_comm: overlap communication with computation.
        use_ds_universal_ckpt: load universal checkpoints.
        deepcompile: enable DeepCompile.
        tensor_parallel_size: DeepSpeed AutoTP size.
    """
    param_device = "cpu" if offload else "none"
    optim_device = "cpu" if adam_offload else "none"

    zero_section = {
        "stage": stage,
        "offload_param": {"device": param_device},
        "offload_optimizer": {
            "device": optim_device,
            "pin_memory": True,
        },
        "sub_group_size": "auto",
        "stage3_max_live_parameters": "auto",
        "stage3_max_reuse_distance": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "reduce_bucket_size": "auto",
        # ZeRO++
        "zero_hpz_partition_size": zpg,
        "zero_quantized_weights": False,
        "zero_quantized_gradients": False,
    }
    if overlap_comm:
        zero_section["overlap_comm"] = True
        zero_section["contiguous_gradients"] = True
        if stage == 3:
            zero_section["reduce_scatter"] = True

    return {
        "steps_per_print": 100,
        "zero_optimization": zero_section,
        "bf16": {
            "enabled": bf16,
        },
        "gradient_clipping": max_norm,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "data_types": {"grad_accum_dtype": grad_accum_dtype},
        "checkpoint": {
            "load_universal": use_ds_universal_ckpt,
        },
        "compile": {
            "deepcompile": deepcompile,
        },
        "tensor_parallel": {
            "autotp_size": tensor_parallel_size,
        },
    }
def get_eval_ds_config(
    offload,
    stage=0,
    bf16=True,
    deepcompile=False,
    tensor_parallel_size=1,
):
    """Assemble a DeepSpeed *inference/eval* config dict.

    Args:
        offload: offload parameters to CPU when True.
        stage: ZeRO stage (typically 0, or 3 to match a ZeRO-3 trainer).
        bf16: enable bfloat16.
        deepcompile: requested DeepCompile flag — forced off below.
        tensor_parallel_size: DeepSpeed AutoTP size.
    """
    # At least for 0.16.6, DeepCompile hasn't support pure inference mode
    # https://github.com/deepspeedai/DeepSpeed/pull/7225
    deepcompile = False

    zero_opt_dict = {
        "stage": stage,
        "stage3_max_live_parameters": "auto",
        "stage3_max_reuse_distance": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "offload_param": {
            "device": "cpu" if offload else "none",
            "pin_memory": True,
        },
    }
    return {
        "steps_per_print": 100,
        "zero_optimization": zero_opt_dict,
        "bf16": {
            "enabled": bf16,
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "compile": {
            "deepcompile": deepcompile,
        },
        "tensor_parallel": {
            "autotp_size": tensor_parallel_size,
        },
    }


def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    # Tuple instead of a mutable default list: the old list default was
    # shared across calls and could be mutated by a caller.
    no_decay_name_list=("bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"),
):
    """Split trainable parameters into decayed and non-decayed groups.

    Parameters whose name contains any entry of ``no_decay_name_list``
    (biases, norm weights) get ``weight_decay=0.0``.
    """
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if (any(nd in n for nd in no_decay_name_list) and p.requires_grad)
            ],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def offload_deepspeed_states(model, pin_memory=True, non_blocking=True):
    """Move DeepSpeed optimizer/gradient/parameter states to CPU.

    No-op when the Adam optimizer is already CPU-offloaded. Only supported
    for ZeRO-3 on DeepSpeed <= 0.17.5.
    """
    zero_stage = model.zero_optimization_stage()  # config['zero_optimization']['stage']
    adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu"

    # state offloading not required when using Adam optimizer offloading
    if adam_offload:
        return

    if zero_stage != 3 and version.parse(deepspeed.__version__) <= version.parse("0.17.5"):
        raise NotImplementedError(
            "Only Zero stage 3 is currently supported when using DeepSpeed version 0.17.5 or lower"
        )

    from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

    offload_state_types = [
        OffloadStateTypeEnum.optim_states,
        OffloadStateTypeEnum.contiguous_grad_buffer,
        OffloadStateTypeEnum.hp_params,
    ]

    if version.parse(deepspeed.__version__) >= version.parse("0.16.5"):
        # These offload types are fixed in https://github.com/deepspeedai/DeepSpeed/pull/7050
        offload_state_types += [
            OffloadStateTypeEnum.lp_grads,
            # OffloadStateTypeEnum.lp_params,
        ]

    model.optimizer.offload_states(
        include=offload_state_types,
        device=OffloadDeviceEnum.cpu,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
    )
    model.empty_partition_cache()
    torch.cuda.empty_cache()
    torch.distributed.barrier()
    torch.cuda.synchronize()


def reload_deepspeed_states(model, non_blocking=True):
    """Reload previously offloaded DeepSpeed states back onto the GPU.

    Mirror of :func:`offload_deepspeed_states`; no-op under Adam offloading.
    """
    zero_stage = model.zero_optimization_stage()  # config['zero_optimization']['stage']
    adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu"

    # state offloading not required when using Adam optimizer offloading
    if adam_offload:
        return

    if zero_stage != 3 and version.parse(deepspeed.__version__) <= version.parse("0.17.5"):
        raise NotImplementedError(
            "Only Zero stage 3 is currently supported when using DeepSpeed version 0.17.5 or lower"
        )
    model.reload_states(non_blocking=non_blocking)
    torch.cuda.empty_cache()
    torch.distributed.barrier()
    torch.cuda.synchronize()


def _z3_params_to_fetch(param_list):
    """Return the ZeRO-3 sharded params from ``param_list`` that need gathering."""
    # Imported lazily so this module's helpers remain usable without a
    # working deepspeed install at import time.
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]


def get_strategy(args):
    """Build a :class:`DeepspeedStrategy` from an argparse-style ``args``."""
    strategy = DeepspeedStrategy(
        seed=getattr(args, "seed", 42),
        max_norm=getattr(args, "max_norm", 1.0),
        micro_train_batch_size=getattr(args, "micro_train_batch_size", 1),
        train_batch_size=getattr(args, "train_batch_size", 128),
        zero_stage=args.zero_stage,
        bf16=getattr(args, "bf16", True),
        args=args,
    )
    return strategy
class DeepspeedStrategy(ABC):
    """
    The strategy for training with Accelerator.

    Wraps DeepSpeed engine creation, optimizer setup, distributed
    collectives, and model/checkpoint (de)serialization behind one object.
    """

    def __init__(
        self,
        seed: int = 42,
        max_norm: float = 0.0,
        micro_train_batch_size=1,
        train_batch_size=1,
        zero_stage=2,
        bf16=True,
        args=None,
    ) -> None:
        super().__init__()

        self.args = args
        self.stage = zero_stage
        self.train_batch_size = train_batch_size
        self.micro_train_batch_size = micro_train_batch_size
        self.bf16 = bf16
        self.seed = seed
        self.max_norm = max_norm

        # Optional knobs picked off args with conservative defaults.
        self.adam_offload = getattr(args, "adam_offload", False)
        self.zpg = getattr(args, "zpg", 1)
        self.grad_accum_dtype = getattr(args, "grad_accum_dtype", None)
        self.overlap_comm = getattr(args, "overlap_comm", False)
        self.deepcompile = getattr(args, "deepcompile", False)
        self.ds_tensor_parallel_size = getattr(args, "ds_tensor_parallel_size", 1)
        self.use_dynamic_batch = getattr(self.args, "use_dynamic_batch", False)

        if self.ds_tensor_parallel_size > 1:
            assert deepspeed.version >= "0.16.4", "DeepSpeed version must be >= 0.16.4 for tensor parallel training"
            assert bf16, "BF16 is required for tensor parallel training"

        self.is_rlhf = False
        self.time_steps = defaultdict(int)

    def setup_distributed(self, timeout=timedelta(minutes=60)) -> None:
        """Initialize torch.distributed, the DP/TP device mesh, and grad accumulation."""
        transformers.set_seed(self.seed)

        local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
        if local_rank != -1:
            torch.cuda.set_device(local_rank)

        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        # deepspeed.init_distributed(dist_backend="nccl", timeout=timeout)
        if not dist.is_initialized():
            print(f"[System] Initializing Distributed Process Group via torch.distributed...")
            dist.init_process_group(backend="nccl", timeout=timeout)

        # 2D mesh: data-parallel x tensor-parallel.
        self.world_size = dist.get_world_size()
        dp_size = self.world_size // self.ds_tensor_parallel_size
        self.ds_device_mesh = init_device_mesh(
            "cuda", (dp_size, self.ds_tensor_parallel_size), mesh_dim_names=("dp", "tp")
        )

        self.accumulated_gradient = (
            self.train_batch_size
            * self.ds_tensor_parallel_size
            // self.micro_train_batch_size
            // self.world_size
        )

    def create_optimizer(self, model, **kwargs) -> Optimizer:
        """Create a (CPU-offloaded or fused) Adam over decay/no-decay groups."""
        if isinstance(model, Actor):
            model = model.model
        adam_cls = DeepSpeedCPUAdam if self.adam_offload else FusedAdam
        grouped = get_optimizer_grouped_parameters(model, kwargs["weight_decay"])
        return adam_cls(grouped, **kwargs)

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        """Delegate the backward pass to the DeepSpeed engine."""
        if isinstance(model, Actor):
            model = model.model
        model.backward(loss)

    def optimizer_step(
        self,
        optimizer: optim.Optimizer,
        model: nn.Module,
        scheduler,
        name="model",
        **kwargs,
    ) -> None:
        """Delegate the optimizer step to the DeepSpeed engine."""
        if isinstance(model, Actor):
            model = model.model
        model.step()

    def _unwrap_model(self, model) -> nn.Module:
        """Strip Actor wrappers and DeepSpeed `.module` indirection."""
        if isinstance(model, Actor):
            return self._unwrap_model(model.model)
        if hasattr(model, "module"):
            return model.module
        return model

    def prepare(
        self, *models_or_model_optim_pairs: "ModelOrModelOptimPair", is_rlhf=False
    ) -> "Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]":
        """Wrap each model (or (model, optim, scheduler) triple) in a DeepSpeed engine."""
        prepared = []
        self.is_rlhf = is_rlhf
        for arg in models_or_model_optim_pairs:
            if isinstance(arg, tuple):
                assert len(arg) == 3, f'Expect (model, optimizer, scheduler) pair, got a tuple with size "{len(arg)}"'
                if arg[0] is not None:
                    prepared.append(self._ds_init_train_model(*arg))
                else:
                    prepared.append((None, None, None))
            else:
                prepared.append(self._ds_init_eval_model(arg))

        return prepared[0] if len(prepared) == 1 else prepared

    def _ds_init_train_model(self, model, optim, scheduler):
        """Build a trainable DeepSpeed engine (with optional AutoTP)."""
        is_actor = isinstance(model, Actor)
        ds_config = self.get_ds_train_config(is_actor)

        if self.ds_tensor_parallel_size > 1:
            tp_model = deepspeed.tp_model_init(
                model=model.model if is_actor else model, tp_size=self.ds_tensor_parallel_size, dtype=torch.bfloat16
            )
            if is_actor:
                model.model = tp_model
            else:
                model = tp_model

        engine, optim, _, scheduler = deepspeed.initialize(
            model=model.model if is_actor else model,
            optimizer=optim,
            lr_scheduler=scheduler,
            config=ds_config,
            args={"local_rank": int(os.environ.get("LOCAL_RANK", "-1"))},
            dist_init_required=True,
        )
        if self.deepcompile:
            engine.compile()
        if is_actor:
            model.model = engine
        else:
            model = engine

        return model, optim, scheduler

    def get_ds_train_config(self, is_actor):
        """Training DS config; batch sizes differ under dynamic batching."""
        ds_config = get_train_ds_config(
            offload=False,
            adam_offload=self.adam_offload,
            stage=self.stage,
            bf16=self.bf16,
            max_norm=self.max_norm,
            zpg=self.zpg,
            grad_accum_dtype=self.grad_accum_dtype,
            overlap_comm=self.overlap_comm,
            deepcompile=self.deepcompile,
            tensor_parallel_size=self.ds_tensor_parallel_size,
        )
        if self.use_dynamic_batch:
            # Dynamic batching handles accumulation itself.
            ds_config["train_micro_batch_size_per_gpu"] = 1
            ds_config["gradient_accumulation_steps"] = 1
        else:
            ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size
            ds_config["train_batch_size"] = self.train_batch_size * self.ds_tensor_parallel_size

        return ds_config

    def _ds_init_eval_model(self, model):
        """Build an inference-only DeepSpeed engine (with optional AutoTP)."""
        if not model:
            return model
        is_actor = isinstance(model, Actor)
        ds_config = self.get_ds_eval_config(offload=getattr(model, "_offload", False))

        if self.ds_tensor_parallel_size > 1:
            tp_model = deepspeed.tp_model_init(
                model=model.model if is_actor else model, tp_size=self.ds_tensor_parallel_size, dtype=torch.bfloat16
            )
            if is_actor:
                model.model = tp_model
            else:
                model = tp_model

        engine, *_ = deepspeed.initialize(
            model=model.model if is_actor else model,
            args={"local_rank": int(os.environ.get("LOCAL_RANK", "-1"))},
            config=ds_config,
            dist_init_required=True,
        )
        if self.deepcompile:
            engine.compile()
        if is_actor:
            model.model = engine
        else:
            model = engine
        return model

    def get_ds_eval_config(self, offload=False):
        """Eval DS config; ZeRO is only kept when the trainer runs stage 3."""
        ds_config = get_eval_ds_config(
            offload=offload,
            stage=self.stage if self.stage == 3 else 0,
            bf16=self.bf16,
            deepcompile=self.deepcompile,
            tensor_parallel_size=self.ds_tensor_parallel_size,
        )
        ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size
        ds_config["train_batch_size"] = self.train_batch_size * self.ds_tensor_parallel_size

        return ds_config

    def moving_average(self, model, model_ema, beta=0.992, device="cpu"):
        """EMA-update ``model_ema`` from ``model`` once per accumulation window."""
        self.time_steps["ema"] += 1
        if self.time_steps["ema"] % self.accumulated_gradient == 0 or self.use_dynamic_batch:
            with torch.no_grad():
                for param, param_ema in zip(model.parameters(), model_ema.parameters()):
                    if not param.requires_grad:
                        continue
                    if self.stage != 3:
                        data = param.data.to(device)
                        param_ema.data.copy_((1 - beta) * data + beta * param_ema.data)
                    else:
                        # TODO: use prefiltering for efficiency
                        params_to_fetch = _z3_params_to_fetch([param, param_ema])
                        with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
                            data = param.data.to(device)
                            param_ema.data.copy_((1 - beta) * data + beta * param_ema.data)

    def load_model(
        self,
        model: nn.Module,
        path: str,
        map_location="cpu",
        strict: bool = False,
        key_replace_fn=None,
    ) -> None:
        """Load a raw state_dict into the unwrapped model."""
        unwrapped_model = self._unwrap_model(model)
        state_dict = torch.load(path, map_location=map_location)
        if key_replace_fn:
            state_dict = key_replace_fn(state_dict)
        unwrapped_model.load_state_dict(state_dict, strict=strict)

    def save_model(self, model: nn.Module, tokenizer, output_dir, **kwargs) -> None:
        """Consolidate (ZeRO-3 / TP aware) and save model + tokenizer on rank 0."""
        if self.is_rank_0():
            os.makedirs(output_dir, exist_ok=True)

        model_to_save = self._unwrap_model(model)

        # Gather a full fp16/bf16 state dict when parameters are sharded.
        if self.args.zero_stage > 2 or self.args.ds_tensor_parallel_size > 1:
            output_state_dict = (
                model.model._consolidated_16bit_state_dict()
                if isinstance(model, Actor)
                else model._consolidated_16bit_state_dict()
            )
        else:
            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

            output_state_dict = clone_tensors_for_torch_save(model_to_save.state_dict())

        if self.is_rank_0():
            state_dict_keys = set(model_to_save.state_dict().keys())
            output_state_dict_keys = set(output_state_dict.keys())

            # corner case for tie_word_embeddings, such as Qwen2-0.5B
            if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys:
                state_dict_keys.remove("lm_head.weight")

            assert state_dict_keys.issubset(
                output_state_dict_keys
            ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}"

            # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295
            if isinstance(model_to_save, PeftModel):
                model_to_save.save_pretrained(output_dir, **kwargs)
                if self.ds_tensor_parallel_size > 1 or self.stage == 3:
                    torch.save(
                        get_peft_model_state_dict(model_to_save, output_state_dict),
                        os.path.join(output_dir, "adapter_model.bin"),
                    )
                    filename = os.path.join(output_dir, "adapter_model.safetensors")
                    if os.path.exists(filename):
                        os.remove(filename)
            else:
                model_to_save.save_pretrained(output_dir, state_dict=output_state_dict, **kwargs)

            # save config
            output_config_file = os.path.join(output_dir, "config.json")
            model_to_save.config.to_json_file(output_config_file)
            # save tokenizer
            tokenizer.save_pretrained(output_dir)

        del output_state_dict
        # Explicitly release memory
        import gc

        gc.collect()

        torch_dist_barrier_and_cuda_sync()

    def all_reduce(self, data, op="mean"):
        """All-reduce a scalar/tensor (or dict thereof) across the world."""
        assert op in ("mean", "max", "sum")
        if isinstance(data, dict):
            return {k: self.all_reduce(v, op) for k, v in data.items()}

        is_tensor = isinstance(data, torch.Tensor)
        if not is_tensor:
            data = torch.Tensor([data])
        is_cpu_tensor = data.device.type == "cpu"

        if is_cpu_tensor:
            data = data.to(torch.cuda.current_device())
        if op == "mean":
            data /= self.world_size
        dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM)
        if is_cpu_tensor:
            data = data.cpu()
        return data.item() if not is_tensor else data

    def all_gather(self, data):
        """All-gather a scalar/tensor (or dict thereof); returns concatenation."""
        if isinstance(data, dict):
            return {k: self.all_gather(v) for k, v in data.items()}

        if not isinstance(data, torch.Tensor):
            data = torch.Tensor([data])
        is_cpu_tensor = data.device.type == "cpu"

        gathered = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)]
        dist.all_gather(gathered, data.to(torch.cuda.current_device()))
        return torch.cat(gathered).cpu() if is_cpu_tensor else torch.cat(gathered)

    def print(self, *msg):
        """Print only on rank 0."""
        if self.is_rank_0():
            print(*msg)

    def is_rank_0(self) -> bool:
        """True on rank 0, or whenever torch.distributed is uninitialized."""
        return (not dist.is_initialized()) or dist.get_rank() == 0

    def get_rank(self) -> int:
        """Current rank; 0 when torch.distributed is uninitialized."""
        return dist.get_rank() if dist.is_initialized() else 0

    def save_ckpt(self, model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True):
        """Save a DeepSpeed checkpoint, pruning oldest dirs beyond count/size caps."""
        assert isinstance(model, deepspeed.DeepSpeedEngine)
        if self.is_rank_0():
            os.makedirs(save_dir, exist_ok=True)
            MAX_SIZE = max_mem * 1024**3  # Convert GB to bytes

            while True:
                subdirs = sorted(
                    [
                        (os.path.join(save_dir, d), os.path.getmtime(os.path.join(save_dir, d)))
                        for d in os.listdir(save_dir)
                        if os.path.isdir(os.path.join(save_dir, d))
                    ],
                    key=lambda x: x[1],
                )
                total_size = sum(
                    os.path.getsize(os.path.join(dirpath, f))
                    for subdir, _ in subdirs
                    for dirpath, _, filenames in os.walk(subdir)
                    for f in filenames
                )

                if len(subdirs) >= max_num or total_size > MAX_SIZE:
                    oldest_dir = subdirs[0][0]
                    if os.path.exists(oldest_dir):
                        shutil.rmtree(oldest_dir)
                        self.print(f"Deleted oldest ckpt {oldest_dir}")
                else:
                    break

        torch_dist_barrier_and_cuda_sync()
        model.save_checkpoint(save_dir, tag=tag, client_state=client_state, save_latest=save_latest)

        # Explicitly release memory
        import gc

        gc.collect()

    def load_ckpt(
        self,
        model,
        load_dir,
        tag=None,
        load_module_strict=True,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
        load_module_only=False,
    ):
        """Resume a DeepSpeed checkpoint; raise if nothing could be loaded."""
        assert isinstance(model, deepspeed.DeepSpeedEngine)
        load_path, states = model.load_checkpoint(
            load_dir,
            tag,
            load_module_strict=load_module_strict,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states,
            load_module_only=load_module_only,
        )
        if load_path is None:
            raise Exception(f"[deepspeed] failed to resume from checkpoint {load_dir}")
        return load_path, states
"w", encoding="utf-8") as f: + f.write("\n".join(lines)) + return + +def torch_dist_barrier_and_cuda_sync(): + """Synchronize distributed training and CUDA operations. + This function ensures that: + 1. All distributed processes reach this point (barrier) + 2. All CUDA operations are completed (synchronize) + """ + import torch + + torch.distributed.barrier() + torch.cuda.synchronize() + + +def get_tokenizer(pretrain, model, padding_side="left", use_fast=True): + tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) + tokenizer.padding_side = padding_side + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + if model is not None: + model.config.pad_token_id = tokenizer.pad_token_id + + return tokenizer + +@torch.compile +def compute_entropy(logits: torch.Tensor): + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + return entropy + + +def compute_approx_kl( + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + kl_estimator: str = "k1", +) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + """ + + if kl_estimator == "k1": + log_ratio = log_probs.float() - log_probs_base.float() + + # The k2 estimator is the non negative kl approximation in + # http://joschu.net/blog/kl-approx.html + # The k2_loss is approximately equivalent to the + # one-step KL divergence penalty with the k1 estimator + # used in https://arxiv.org/pdf/2310.10505. 
def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
    """Mean of ``tensor`` over ``dim``, counting only positions where ``mask`` is set.

    With ``mask=None`` this is a plain mean. NOTE(review): an all-zero mask
    divides by zero and yields nan/inf — callers must guarantee a non-empty mask.
    """
    if mask is None:
        return tensor.mean(dim=dim)
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim)


def _logsumexp_by_chunk(logits: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    """Row-wise logsumexp over the last dim, processed in chunks of rows.

    Chunking keeps peak memory bounded for very long (rows, vocab) inputs.
    """
    seq_len = logits.shape[0]
    logsumexp_values = torch.zeros((seq_len), device=logits.device, dtype=logits.dtype)
    for s_idx in range(0, seq_len, chunk_size):
        end_idx = min(s_idx + chunk_size, seq_len)
        logsumexp_values[s_idx:end_idx] = torch.logsumexp(logits[s_idx:end_idx], dim=-1)

    return logsumexp_values


def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Log-probability of each ``labels`` token under ``logits``.

    For fp32/fp64 logits, uses flash-attn's fused cross-entropy when available,
    otherwise gather + chunked logsumexp (log_softmax(x_i) = x_i - logsumexp(x)).
    For lower-precision logits, falls back to a per-row log_softmax loop to
    reduce peak memory.

    BUG FIX: the original applied ``logits.div_(temperature)`` in place, silently
    mutating the caller's tensor; this version divides out of place.

    See https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
    """
    if temperature != 1.0:
        logits = logits / temperature
    if logits.dtype in [torch.float32, torch.float64]:
        batch_dim = logits.shape[:-1]
        last_dim = logits.shape[-1]
        try:
            from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

            output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
            log_probs_labels = -output[0].view(*batch_dim)
        except ImportError:
            logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
            logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))
            logsumexp_values = logsumexp_values.view(*batch_dim)
            log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    else:
        # Loop per batch row to reduce peak memory consumption.
        log_probs_labels = []
        for row_logits, row_labels in zip(logits, labels):
            row_log_probs = F.log_softmax(row_logits, dim=-1)
            row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            log_probs_labels.append(row_log_probs_labels)
        log_probs_labels = torch.stack(log_probs_labels)
    return log_probs_labels
import time
from contextlib import contextmanager
from collections import defaultdict


class Profiler:
    """Lightweight wall-clock profiler for named code blocks.

    Per name it accumulates count / total seconds / max seconds in memory,
    and every ``log_interval`` completions appends one TSV row to
    ``stats_file``. When disabled (``enable_profile=False``) or on a non-zero
    rank, ``block`` is a near-zero-overhead no-op.
    """

    def __init__(self, log_interval: int = 10, stats_file: str = None, enable_profile: bool = False):
        self.log_interval = max(1, int(log_interval))  # guard against 0/negative intervals
        self.stats_file = stats_file
        self.stats = defaultdict(lambda: {"count": 0, "total": 0.0, "max": 0.0})
        self._inited = False
        self.enable_profile = enable_profile

    def _init_once(self):
        # Append the TSV header once per instance.
        # BUG FIX: skip file I/O entirely when stats_file is None instead of
        # crashing in open(None); in-memory stats still work.
        if self._inited or self.stats_file is None:
            return
        with open(self.stats_file, "a", encoding="utf-8") as f:
            f.write("ts\tname\tcount\ttotal_s\tavg_s\tmax_s\n")
        self._inited = True

    def _record(self, name: str, elapsed: float):
        # Update in-memory aggregates; flush a TSV row every log_interval hits.
        s = self.stats[name]
        s["count"] += 1
        s["total"] += elapsed
        s["max"] = max(s["max"], elapsed)
        if s["count"] % self.log_interval == 0 and self.stats_file is not None:
            avg = s["total"] / s["count"]
            with open(self.stats_file, "a", encoding="utf-8") as f:
                f.write(f"{time.time():.3f}\t{name}\t{s['count']}\t{s['total']:.6f}\t{avg:.6f}\t{s['max']:.6f}\n")

    @contextmanager
    def block(self, name: str, rank: int = 0):
        """Context manager timing the enclosed code under ``name``.

        Active only when profiling is enabled and ``rank == 0``; the elapsed
        time is recorded even if the body raises (``finally``).
        """
        if not self.enable_profile or rank != 0:
            yield None
            return
        self._init_once()
        t0 = time.perf_counter()
        try:
            yield None
        finally:
            self._record(name, time.perf_counter() - t0)
import vllm


class LLMActor:
    """Thin wrapper around a ``vllm.LLM`` engine for weight sync and generation.

    Only rank 0 submits requests (see ``add_requests``), so no per-rank request
    bookkeeping is needed.
    """

    def __init__(self, model: str = None, **kwargs):
        self.requests = {}
        self.kwargs = kwargs
        self.llm = vllm.LLM(model=model, **self.kwargs)

    def update_weight(self, name, dtype, shape, weight, empty_cache=False):
        """Push one named weight tensor to every vLLM worker via collective_rpc."""
        return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, weight, empty_cache))

    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
        """Push one named weight to the workers through CUDA IPC handles."""
        return self.llm.collective_rpc("update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache))

    def reset_prefix_cache(self):
        self.llm.llm_engine.reset_prefix_cache()

    def sleep(self, level=1):
        # Releases engine GPU memory until wake_up().
        self.llm.sleep(level=level)

    def wake_up(self):
        self.llm.wake_up()

    def add_requests(self, sampling_params, prompt_token_ids):
        """Stage token-id prompts for the next ``get_responses`` call.

        Since only rank 0 sends requests, actor ranks are not tracked.
        """
        from vllm.inputs import TokensPrompt

        self.sampling_params = sampling_params
        self.requests = [TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids]

    def get_responses(self):
        """Generate for all staged requests, clear the queue, return outputs."""
        responses = self.llm.generate(
            prompts=self.requests,
            sampling_params=self.sampling_params,
            use_tqdm=False,
        )
        self.requests = {}
        return responses


def create_vllm_engine(
    tensor_parallel_size: int,
    pretrain: str,
    enable_prefix_caching: bool,
    max_model_len: int,
    gpu_memory_utilization=None,
    vllm_enable_sleep=False,
):
    """Construct an ``LLMActor`` using the external_launcher backend.

    Loads ``pretrain`` in bfloat16 with the WorkerWrap extension installed on
    every worker; the engine is put to sleep immediately when
    ``vllm_enable_sleep`` is set. (Removed the unused
    ``from packaging import version`` import.)
    """
    vllm_engine = LLMActor(
        model=pretrain,
        worker_extension_cls="vllm_utils.worker.WorkerWrap",
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend="external_launcher",
        max_model_len=max_model_len,
        enable_prefix_caching=enable_prefix_caching,
        dtype="bfloat16",
        gpu_memory_utilization=gpu_memory_utilization,
        enable_sleep_mode=vllm_enable_sleep,
    )
    if vllm_enable_sleep:
        vllm_engine.sleep()
    return vllm_engine


def get_physical_gpu_id():
    """UUID of the current CUDA device, used to key IPC handles across processes."""
    import torch

    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    return str(props.uuid)
class WorkerWrap:
    """vLLM worker extension that receives weight updates inside worker processes.

    NOTE(review): relies on attributes provided by the vLLM worker it is mixed
    into (``self.model_config``, ``self.model_runner``, ``self.device``) —
    it is installed via ``worker_extension_cls``, not instantiated directly.
    (Removed a commented-out broadcast-based ``update_weight`` variant.)
    """

    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False):
        """Rebuild one weight from a CUDA IPC handle and load it into the model."""
        import torch
        from vllm_utils.vllm_engine import get_physical_gpu_id

        if torch.distributed.get_rank() == 0:
            print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")

        assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"

        handle = ipc_handles[get_physical_gpu_id()]
        device_id = self.device.index
        func, args = handle
        list_args = list(args)
        # The key is to change the device id to the current device id,
        # in case the two processes have different CUDA_VISIBLE_DEVICES.
        # NOTE(review): index 6 is assumed to be the device-id slot in the
        # IPC rebuild args tuple — confirm against torch's rebuild signature.
        list_args[6] = device_id
        weight = func(*list_args)
        self.model_runner.model.load_weights(weights=[(name, weight)])
        torch.cuda.synchronize()

    def update_weight(self, name, dtype, shape, weight, empty_cache=False):  # pylint: disable=R0917, W0613
        """Load one already-materialized weight tensor into the worker's model."""
        import torch

        if torch.distributed.get_rank() == 0:
            print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")

        assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight