diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index f17126527..a59944967 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -10,4 +10,5 @@
 from .train_rezero import train_rezero
 from .train_unizero import train_unizero
 from .train_unizero_segment import train_unizero_segment
+from .train_unizero_segment_async import train_unizero_segment_async
 from .utils import *
diff --git a/lzero/entry/async_training_guide.md b/lzero/entry/async_training_guide.md
new file mode 100644
index 000000000..9197105ab
--- /dev/null
+++ b/lzero/entry/async_training_guide.md
@@ -0,0 +1,127 @@
+# LightZero Asynchronous Training Guide
+
+## Overview
+
+This document explains how to convert LightZero's collector, learner, and evaluator from a synchronous, serial architecture into an asynchronous, parallel one in order to improve training efficiency.
+
+## Analysis of the Current Architecture
+
+### Characteristics of the synchronous architecture
+- **Serial execution**: collector → learner → evaluator run strictly in sequence
+- **Tight coupling**: the components depend heavily on one another
+- **Blocking waits**: each component must wait for the previous one to finish
+
+### Performance bottlenecks
+1. **Low CPU utilization**: the CPU idles while the GPU trains, and the GPU idles while the CPU collects
+2. **Wasted resources**: multi-core CPUs and multiple GPUs cannot be fully utilized
+3. **Low training efficiency**: total training time = collector time + learner time + evaluator time
+
+## Asynchronous Redesign
+
+### Core idea
+Decouple the three components and run them in parallel through worker threads and a message queue:
+
+```
+┌─────────────┐    ┌─────────────┐    ┌─────────────┐
+│  Collector  │    │   Learner   │    │  Evaluator  │
+│   Thread    │    │   Thread    │    │   Thread    │
+└─────────────┘    └─────────────┘    └─────────────┘
+       │                  │                  │
+       ▼                  ▼                  ▼
+┌─────────────────────────────────────────────────────┐
+│              Data Queue & Policy Lock               │
+└─────────────────────────────────────────────────────┘
+```
+
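+A minimal, self-contained sketch of this worker/queue/event pattern is shown below. It uses only the Python standard library; every name in it is illustrative and none of it is LightZero code — it is only meant to make the coordination between the three threads concrete.
+
+```python
+import queue
+import threading
+import time
+
+stop_event = threading.Event()     # global stop signal
+eval_event = threading.Event()     # learner -> evaluator notification
+policy_lock = threading.Lock()     # protects the shared "policy" (here just a version counter)
+data_queue = queue.Queue(maxsize=10)
+
+policy_version = 0                 # stands in for the shared policy weights
+
+def collector():
+    step = 0
+    while not stop_event.is_set():
+        with policy_lock:
+            version = policy_version              # read the latest policy
+        try:
+            data_queue.put((f"segment@{version}", step), timeout=1.0)
+        except queue.Full:
+            continue                              # learner is busy; retry
+        step += 1
+
+def learner():
+    global policy_version
+    for it in range(100):                         # a fixed number of training iterations
+        try:
+            batch, env_step = data_queue.get(timeout=1.0)
+        except queue.Empty:
+            continue
+        time.sleep(0.01)                          # pretend to run a gradient step on `batch`
+        with policy_lock:
+            policy_version += 1                   # publish the updated policy
+        if it % 20 == 0:
+            eval_event.set()                      # ask the evaluator to run
+    stop_event.set()
+
+def evaluator():
+    while not stop_event.is_set():
+        if eval_event.wait(timeout=1.0):          # wake up only when signalled
+            eval_event.clear()
+            with policy_lock:
+                version = policy_version
+            print(f"evaluating policy version {version}")
+
+threads = [threading.Thread(target=f, name=f.__name__) for f in (collector, learner, evaluator)]
+for t in threads:
+    t.start()
+for t in threads:
+    t.join()
+```
+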
+### Key techniques
+
+#### 1. Thread-safe data sharing
+```python
+# Buffered data queue
+self.data_queue = queue.Queue(maxsize=10)
+
+# Lock for policy updates
+self.policy_lock = threading.Lock()
+
+# Stop signal
+self.stop_event = threading.Event()
+```
+
+#### 2. Asynchronous data collection
+```python
+def _collector_worker(self):
+    while not self.stop_event.is_set():
+        # Grab the latest policy (thread-safe)
+        with self.policy_lock:
+            current_policy = self.policy.collect_mode
+
+        # Collect data
+        new_data = self.collector.collect(...)
+
+        # Put the data into the queue for the learner
+        self.data_queue.put((new_data, self.env_step))
+```
+
+#### 3. Asynchronous model training
+```python
+def _learner_worker(self):
+    while not self.stop_event.is_set():
+        # Fetch collected data from the queue
+        new_data, data_env_step = self.data_queue.get(timeout=1.0)
+
+        # Push it into the replay buffer and sample a training batch
+        self.replay_buffer.push_game_segments(new_data)
+        train_data = self.replay_buffer.sample(self.cfg.policy.batch_size, self.policy)
+
+        # Train the model
+        log_vars = self.learner.train(train_data, data_env_step)
+
+        # Update the policy (thread-safe)
+        with self.policy_lock:
+            # make sure the policy update itself is thread-safe
+            pass
+```
+
+#### 4. Asynchronous evaluation
+```python
+def _evaluator_worker(self):
+    while not self.stop_event.is_set():
+        if self.evaluator.should_eval(self.train_iter):
+            # Evaluate with the latest policy
+            with self.policy_lock:
+                current_policy = self.policy.eval_mode
+
+            stop, episode_info = self.evaluator.eval(...)
+
+        # Check periodically without blocking the main flow
+        time.sleep(1.0)
+```
+
+## Minimal-Change Implementation
+
+### 1. New asynchronous training entry
+- File: `lzero/entry/train_unizero_segment_async.py`
+- Purpose: provides the main logic for asynchronous training
+
+### 2. Configuration support
+- File: `zoo/classic_control/cartpole/config/cartpole_unizero_segment_async_config.py`
+- New configuration keys:
+  ```python
+  enable_async_training = True
+  data_queue_size = 10
+  enable_async_debug_log = True  # controls verbose debug output
+  ```
+
+### 3. Usage
+```python
+from lzero.entry import train_unizero_segment_async
+from zoo.classic_control.cartpole.config.cartpole_unizero_segment_async_config import main_config, create_config
+
+# Launch asynchronous training
+policy = train_unizero_segment_async(
+    [main_config, create_config],
+    seed=0,
+    max_env_step=int(2e5)
+)
+
+# Controlling debug output
+# Set enable_async_debug_log = True/False in the config file
+# True: print detailed async-training debug information
+# False: print only basic training logs
+```
\ No newline at end of file
diff --git a/lzero/entry/train_unizero_segment_async.py b/lzero/entry/train_unizero_segment_async.py
new file mode 100644
index 000000000..8abb5bc27
--- /dev/null
+++ b/lzero/entry/train_unizero_segment_async.py
@@ -0,0 +1,467 @@
+import logging
+import os
+import threading
+import queue
+import time
+from functools import partial
+from typing import Tuple, Optional
+
+import torch
+import wandb
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import EasyTimer
+from ding.utils import set_pkg_seed, get_rank, get_world_size
+from ding.worker import BaseLearner
+from torch.utils.tensorboard import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroEvaluator as Evaluator
+from lzero.worker import MuZeroSegmentCollector as Collector
+from .utils import random_collect, calculate_update_per_collect
+import copy
+timer = EasyTimer()
+
+
+class AsyncTrainer:
+    """
+    Asynchronous trainer that runs the collector, learner, and evaluator in parallel threads.
+    """
+
+    def __init__(self, cfg, create_cfg, model=None, model_path=None, max_train_iter=int(1e10), max_env_step=int(1e10)):
+        self.cfg = cfg
+        self.create_cfg = create_cfg
+        self.model = model
+        self.model_path = model_path
+        self.max_train_iter = max_train_iter
+        self.max_env_step = max_env_step
+
+        # --- Optimization: take the queue size from the config ---
+        queue_size = getattr(self.cfg.policy, 'data_queue_size', 10)
+        self.data_queue = queue.Queue(maxsize=queue_size)
+
+        # Asynchronous primitives
+        self.policy_lock = threading.Lock()
+        self.stop_event = threading.Event()
+
+        # --- Optimization: event used to notify the evaluator ---
+        self.eval_event = threading.Event()
+
+        # Training state
+        self.train_iter = 0
+        self.env_step = 0
+        self.best_reward = float('-inf')
+
+        # Buffer-reanalyze state
+        self.buffer_reanalyze_count = 0
+        self.train_epoch = 0
+        self.reanalyze_batch_size = getattr(self.cfg.policy, 'reanalyze_batch_size', 2000)
+
+        # Initialize all components
+        self._init_components()
+
+    def _init_components(self):
+        """Initialize the training components."""
+        try:
+            self.cfg = compile_config(self.cfg, seed=0, env=None, auto=True, create_cfg=self.create_cfg, save_cfg=True)
+
+            if self.cfg.policy.enable_async_debug_log:
+                logging.info("[ASYNC_DEBUG] Configuration compiled successfully")
+
+            env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
+            self.collector_env = create_env_manager(
+                self.cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
+            )
+            self.evaluator_env = create_env_manager(
+                self.cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]
+            )
+
+            self.collector_env.seed(self.cfg.seed)
+            self.evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
+            set_pkg_seed(self.cfg.seed, use_cuda=torch.cuda.is_available())
+
+            if self.cfg.policy.enable_async_debug_log:
+                logging.info("[ASYNC_DEBUG] Environments created and seeded successfully")
+
+            self.policy = create_policy(self.cfg.policy, 
model=self.model, enable_field=['learn', 'collect', 'eval']) + + # TODO(pu): share model, current have Race Condition, 导致KV-Cache 相关的并发读写不一致 + self.policy._collect_model = copy.deepcopy(self.policy._model) + self.policy._eval_model = copy.deepcopy(self.policy._model) + + if not hasattr(self.policy, 'collect_mode') or self.policy.collect_mode is None: + raise RuntimeError("Policy collect_mode is None after creation") + if not hasattr(self.policy, 'eval_mode') or self.policy.eval_mode is None: + raise RuntimeError("Policy eval_mode is None after creation") + if not hasattr(self.policy, 'learn_mode') or self.policy.learn_mode is None: + raise RuntimeError("Policy learn_mode is None after creation") + + if self.model_path is not None: + self.policy.learn_mode.load_state_dict(torch.load(self.model_path, map_location=self.cfg.policy.device)) + # --- 优化: 初始化时就同步策略到collector和evaluator --- + self.policy.collect_mode.load_state_dict(self.policy.learn_mode.state_dict()) + self.policy.eval_mode.load_state_dict(self.policy.learn_mode.state_dict()) + + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Policy created and validated successfully") + + self.tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'async')) if get_rank() == 0 else None + self.learner = BaseLearner(self.cfg.policy.learn.learner, self.policy.learn_mode, self.tb_logger, exp_name=self.cfg.exp_name) + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=['UniZeroGameBuffer']), 'UniZeroGameBuffer') + self.replay_buffer = GameBuffer(self.cfg.policy) + + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Learner and replay buffer created successfully") + + self.collector = Collector( + env=self.collector_env, + policy=self.policy.collect_mode, + tb_logger=self.tb_logger, + exp_name=self.cfg.exp_name, + policy_config=self.cfg.policy + ) + self.evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=self.evaluator_env, + policy=self.policy.eval_mode, + tb_logger=self.tb_logger, + exp_name=self.cfg.exp_name, + policy_config=self.cfg.policy + ) + + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Collector and evaluator created successfully") + + self.learner.call_hook('before_run') + + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] All components initialized successfully") + + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Component initialization error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Initialization error traceback: {traceback.format_exc()}") + raise + + def _should_reanalyze_buffer(self, update_per_collect: int, training_iteration: int) -> bool: + if not hasattr(self.cfg.policy, 'buffer_reanalyze_freq') or self.cfg.policy.buffer_reanalyze_freq <= 0: + return False + if self.cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // self.cfg.policy.buffer_reanalyze_freq + return training_iteration > 0 and training_iteration % reanalyze_interval == 0 + else: + if self.train_epoch > 0 and self.train_epoch % int(1/self.cfg.policy.buffer_reanalyze_freq) == 0: + min_transitions = int(self.reanalyze_batch_size / getattr(self.cfg.policy, 'reanalyze_partition', 0.75)) + return self.replay_buffer.get_num_of_transitions() // self.cfg.policy.num_unroll_steps > min_transitions + return False + + def _perform_buffer_reanalyze(self): + try: + with timer: + 
self.replay_buffer.reanalyze_buffer(self.reanalyze_batch_size, self.policy) + self.buffer_reanalyze_count += 1 + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Buffer reanalyze #{self.buffer_reanalyze_count} completed, " + f"time={timer.value:.3f}s, buffer_transitions={self.replay_buffer.get_num_of_transitions()}") + else: + logging.info(f'Buffer reanalyze count: {self.buffer_reanalyze_count}, time: {timer.value:.3f}s') + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Buffer reanalyze error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Buffer reanalyze error traceback: {traceback.format_exc()}") + + def _collector_worker(self): + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Collector worker started") + collection_count = 0 + while not self.stop_event.is_set(): + try: + collection_start_time = time.time() + + collect_kwargs = { + 'temperature': visit_count_temperature( + self.cfg.policy.manual_temperature_decay, + self.cfg.policy.fixed_temperature_value, + self.cfg.policy.threshold_training_steps_for_final_temperature, + trained_steps=self.train_iter + ), + 'epsilon': 0.0 + } + if self.cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=self.cfg.policy.eps.start, + end=self.cfg.policy.eps.end, + decay=self.cfg.policy.eps.decay, + type_=self.cfg.policy.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(self.env_step) + + # TODO(pu): collector是收集一个game_segment就返回, model 更新后 kv_cache 也应该更新? 应该是在trainer向collector同步模型后clear kv cache + try: + if hasattr(self.policy._collect_model, 'world_model'): + with self.policy_lock: + self.policy._collect_model.world_model.clear_caches() + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Error clearing evaluator caches: {e}", exc_info=True) + + new_data = self.collector.collect(train_iter=self.train_iter, policy_kwargs=collect_kwargs) + + if new_data is None or len(new_data) == 0: + if self.cfg.policy.enable_async_debug_log: + logging.warning("[ASYNC_DEBUG] Collector: collected data is None or empty, retrying...") + time.sleep(0.5) # 短暂等待以避免空转 + continue + + collection_time = time.time() - collection_start_time + collection_count += 1 + + try: + self.data_queue.put((new_data, self.env_step), timeout=5.0) + except queue.Full: + if self.cfg.policy.enable_async_debug_log: + logging.warning("[ASYNC_DEBUG] Collector: data queue is full, dropping data") + continue + + self.env_step = self.collector.envstep + + if self.env_step >= self.max_env_step: + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Collector reached max_env_step: {self.env_step}") + self.stop_event.set() + break + + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Collector worker error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Collector error traceback: {traceback.format_exc()}") + time.sleep(1.0) + continue + + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Collector worker stopped, total collections: {collection_count}") + + def _learner_worker(self): + """学习器工作线程""" + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Learner worker started") + + training_count = 0 + while not self.stop_event.is_set(): + try: + try: + new_data, data_env_step = self.data_queue.get(timeout=1.0) + except queue.Empty: + if self.stop_event.is_set(): + break + continue + + self.replay_buffer.push_game_segments(new_data) + self.replay_buffer.remove_oldest_data_to_fit() + + 
update_per_collect = calculate_update_per_collect(self.cfg, new_data, 1) + + batch_size = self.cfg.policy.batch_size + for i in range(update_per_collect): + if self.stop_event.is_set(): + break + + if self._should_reanalyze_buffer(update_per_collect, i): + self._perform_buffer_reanalyze() + + if self.replay_buffer.get_num_of_transitions() > batch_size: + try: + train_data = self.replay_buffer.sample(batch_size, self.policy) + if train_data is None: + logging.warning("[ASYNC_DEBUG] Learner: sampled train_data is None") + break + train_data.append(self.train_iter) + except Exception as sample_error: + logging.error(f"[ASYNC_DEBUG] Learner sampling error: {sample_error}") + break + + log_vars = self.learner.train(train_data, data_env_step) + self.train_iter += 1 + training_count += 1 + + if self.cfg.policy.use_priority: + self.replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + # --- 触发评估事件 --- + if self.evaluator.should_eval(self.train_iter): + if not self.eval_event.is_set(): + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Learner is signaling evaluator at train_iter: {self.train_iter}") + self.eval_event.set() + else: + break + + + # learner 更新完一个epoch的后:# TODO(pu) + with self.policy_lock: + new_sd = self.policy._model.state_dict() + self.policy._collect_model.load_state_dict(new_sd, strict=False) + self.policy._eval_model.load_state_dict(new_sd, strict=False) + + # precompute positional embedding matrices in the learn model + self.policy.recompute_pos_emb_diff_for_async() # TODO(pu) + + # TODO(pu): collector是收集一个game_segment就返回, model 更新后 kv_cache 也应该更新? + # 应该是在trainer向collector同步模型后clear kv cache, 但目前打开会导致多线程读写错误 + # self.policy.recompute_pos_emb_diff_and_clear_cache_for_async() # TODO(pu) + + + self.train_epoch += 1 + + if self.train_iter >= self.max_train_iter: + self.stop_event.set() + break + + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Learner worker error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Learner error traceback: {traceback.format_exc()}") + time.sleep(2.0) + continue + + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Learner worker stopped, total training iterations: {training_count}") + + def _evaluator_worker(self): + """评估器工作线程""" + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Evaluator worker started") + + while not self.stop_event.is_set(): + try: + # --- 优化: 使用事件等待,而不是固定时间休眠 --- + # 等待学习器发出评估信号,设置超时以防万一 + eval_triggered = self.eval_event.wait(timeout=60.0) + + if eval_triggered: + self.eval_event.clear() # 清除事件,等待下一次信号 + + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Evaluator received signal, starting evaluation at train_iter: {self.train_iter}") + + # TODO(pu): model 更新后 kv_cache 也应该更新, evaluator是评估完整的几局 + try: + if hasattr(self.policy._eval_model, 'world_model'): + with self.policy_lock: + self.policy._eval_model.world_model.clear_caches() + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Error clearing evaluator caches: {e}", exc_info=True) + + stop, episode_info = self.evaluator.eval( + self.learner.save_checkpoint, + self.train_iter, + self.env_step + ) + + if episode_info is not None: + if isinstance(episode_info, dict) and 'eval_episode_return_mean' in episode_info: + reward = episode_info['eval_episode_return_mean'] + if reward > self.best_reward: + self.best_reward = reward + else: + logging.warning(f"[ASYNC_DEBUG] Evaluator: unexpected episode_info format: {type(episode_info)}") + 
logging.warning(f"[ASYNC_DEBUG] Evaluator: episode_info: {episode_info}") + + if stop: + self.stop_event.set() + break + else: + # 超时后,检查stop_event是否被设置 + if self.stop_event.is_set(): + break + + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Evaluator worker error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Evaluator error traceback: {traceback.format_exc()}") + time.sleep(1.0) + continue + + if self.cfg.policy.enable_async_debug_log: + logging.info(f"[ASYNC_DEBUG] Evaluator worker stopped.") + + def train(self): + """开始异步训练""" + if self.cfg.policy.enable_async_debug_log: + logging.info("[ASYNC_DEBUG] Starting async training...") + + collector_thread = threading.Thread(target=self._collector_worker, name="Collector") + learner_thread = threading.Thread(target=self._learner_worker, name="Learner") + evaluator_thread = threading.Thread(target=self._evaluator_worker, name="Evaluator") + + threads = [collector_thread, learner_thread, evaluator_thread] + + try: + for t in threads: + t.start() + + # 主线程可以监控线程状态或等待中断 + for t in threads: + # 使用 join 来等待线程结束,可以设置一个很长的超时 + # 无限期等待,直到线程自己结束或被中断 + t.join() + + except KeyboardInterrupt: + logging.info("[ASYNC_DEBUG] KeyboardInterrupt received, shutting down workers...") + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Main thread error: {e}") + import traceback + logging.error(f"[ASYNC_DEBUG] Main thread error traceback: {traceback.format_exc()}") + finally: + # 确保设置停止事件,通知所有线程退出 + self.stop_event.set() + # 唤醒可能在等待的评估器线程,使其能检查到stop_event + self.eval_event.set() + + logging.info("[ASYNC_DEBUG] Waiting for worker threads to terminate...") + for t in threads: + if t.is_alive(): + t.join(timeout=5.0) # 给5秒钟时间正常退出 + if t.is_alive(): + logging.warning(f"[ASYNC_DEBUG] Thread {t.name} did not terminate gracefully.") + + # --- 优化: 显式关闭资源 --- + if self.tb_logger: + self.tb_logger.close() + + try: + self.learner.call_hook('after_run') + except Exception as e: + logging.error(f"[ASYNC_DEBUG] After run hook error: {e}") + + if self.cfg.policy.use_wandb: + try: + wandb.finish() + except Exception as e: + logging.error(f"[ASYNC_DEBUG] Wandb finish error: {e}") + + logging.info(f"[ASYNC_DEBUG] Async training finished. 
Final stats: " + f"train_iter={self.train_iter}, env_step={self.env_step}, best_reward={self.best_reward}") + + return self.policy + +# train_unizero_segment_async 函数保持不变,因为它只是入口点 +def train_unizero_segment_async( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" + trainer = AsyncTrainer(cfg, create_cfg, model, model_path, max_train_iter, max_env_step) + return trainer.train() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index e8df2a6e0..eb6d87e95 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -15,6 +15,7 @@ from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +import threading logging.getLogger().setLevel(logging.DEBUG) @@ -39,6 +40,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: - tokenizer (:obj:`Tokenizer`): The tokenizer. """ super().__init__() + # self.cache_lock = threading.Lock() self.tokenizer = tokenizer self.config = config self.transformer = Transformer(self.config) @@ -152,6 +154,18 @@ def custom_init(module): self.reanalyze_phase = False + # # ---------- 关键补丁 ---------- + # def __getstate__(self): + # state = self.__dict__.copy() + # # lock 对象无法被 pickle,序列化时去掉 + # state.pop('cache_lock', None) + # return state + + # def __setstate__(self, state): + # self.__dict__.update(state) + # # 反序列化后重新创建 lock + # self.cache_lock = threading.Lock() + def _get_final_norm(self, norm_option: str) -> nn.Module: """ Return the corresponding normalization module based on the specified normalization option. @@ -480,7 +494,7 @@ def forward( - logits for policy. - logits for value. """ - + # with self.cache_lock: # Calculate previous steps based on key-value caching configuration if kvcache_independent: # If kv caching is independent, compute previous steps for each past key-value pair. 
@@ -655,8 +669,24 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) else: valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + + # if valid_context_lengths is None: + # raise RuntimeError("valid_context_lengths should not be None here") + # # 保险:保证长度与 batch 一致 + # if valid_context_lengths.numel() != embeddings.size(0): + # valid_context_lengths = valid_context_lengths[:embeddings.size(0)] + position_embeddings = self.pos_emb( valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + try: + embeddings + position_embeddings + except Exception as e: + print(f'Error: {e}') + print(f'position_embeddings.shape: {position_embeddings.shape}, embeddings.shape: {embeddings.shape}') + print(f'valid_context_lengths: {valid_context_lengths}') + print(f'prev_steps: {prev_steps}') + print(f'num_steps: {num_steps}') + raise e return embeddings + position_embeddings def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): @@ -1254,7 +1284,9 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, # If not found, try to retrieve from past_kv_cache_recurrent_infer if matched_value is None: - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_recur_infer[cache_index] if matched_value is not None: # If a matching cache is found, add it to the lists @@ -1821,11 +1853,16 @@ def clear_caches(self): """ Clears the caches of the world model. """ + # with self.cache_lock: for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() + + self.keys_values_wm_size_list_current = [] + print(f'Cleared {self.__class__.__name__} past_kv_cache.') + def __repr__(self) -> str: return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index f2bfc48f9..dde9a6390 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -341,8 +341,8 @@ def _init_learn(self) -> None: ) self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) - assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... - assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + # assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... + # assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @@ -1030,4 +1030,32 @@ def recompute_pos_emb_diff_and_clear_cache(self) -> None: # If rotary_emb is False, nn.Embedding is used for absolute position encoding. 
model.world_model.precompute_pos_emb_diff_kv() model.world_model.clear_caches() + torch.cuda.empty_cache() + + def recompute_pos_emb_diff_for_async(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + # for model in [self._collect_model, self._target_model, self._eval_model, self._learn_model]: + for model in [self._learn_model]: + if not self._cfg.model.world_model_cfg.rotary_emb: + # If rotary_emb is False, nn.Embedding is used for absolute position encoding. + model.world_model.precompute_pos_emb_diff_kv() + + torch.cuda.empty_cache() + + def recompute_pos_emb_diff_and_clear_cache_for_async(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + # for model in [self._collect_model, self._target_model, self._eval_model, self._learn_model]: + for model in [self._learn_model]: + if not self._cfg.model.world_model_cfg.rotary_emb: + # If rotary_emb is False, nn.Embedding is used for absolute position encoding. + model.world_model.precompute_pos_emb_diff_kv() + + for model in [self._collect_model, self._target_model]: + model.world_model.clear_caches() torch.cuda.empty_cache() \ No newline at end of file diff --git a/zoo/classic_control/cartpole/config/cartpole_unizero_segment_async_config.py b/zoo/classic_control/cartpole/config/cartpole_unizero_segment_async_config.py new file mode 100644 index 000000000..501731802 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_unizero_segment_async_config.py @@ -0,0 +1,152 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +# collector_env_num = 8 +# num_segments = 8 +# n_episode = 8 + +collector_env_num = 3 +num_segments = 3 +n_episode = 3 + +game_segment_length = 20 +evaluator_env_num = 3 +num_simulations = 25 +# num_simulations = 10 # TODO + +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e5) +# max_env_step = int(2e3) # TODO + +batch_size = 256 +num_unroll_steps = 5 +reanalyze_ratio = 0. + +# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. +# buffer_reanalyze_freq = 1/50 +buffer_reanalyze_freq = 1/50000000 +# buffer_reanalyze_freq = 1 # TODO +# Each reanalyze process will reanalyze sequences ( transitions per sequence) +reanalyze_batch_size = 160 +# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+reanalyze_partition = 0.75 + +# ============= 异步训练相关配置 ============= +# 是否启用异步训练 +enable_async_training = True +# 数据缓冲队列大小 +data_queue_size = 20 +# 是否输出异步训练的详细调试信息 +# enable_async_debug_log = True +enable_async_debug_log = False + +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +cartpole_unizero_config = dict( + exp_name=f'data_unizero_async/cartpole_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000, ), ), ), + model=dict( + observation_shape=4, + action_space_size=2, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + norm_type='BN', + model_type='mlp', + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + max_blocks=10, + max_tokens=2 * 10, + context_length=2 * 4, + context_length_for_recurrent=2 * 4, + # device='cuda', + device='cpu', + action_space_size=2, + num_layers=2, + num_heads=2, + embed_dim=64, + env_num=max(collector_env_num, evaluator_env_num), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + obs_type='vector', + norm_type='BN', + # rotary_emb=True, + rotary_emb=False, + ), + ), + use_wandb=False, + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + # cuda=True, + # device='cuda', + cuda=False, + device='cpu', + use_augmentation=False, + env_type='not_board_games', + num_segments=num_segments, + game_segment_length=game_segment_length, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + piecewise_decay_lr_scheduler=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + # eval_freq=int(1e4), + eval_freq=int(1e3), + # eval_freq=int(50), # TODO + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + # ============= 异步训练配置 ============= + enable_async_training=enable_async_training, + data_queue_size=data_queue_size, + enable_async_debug_log=enable_async_debug_log, + ), +) + +cartpole_unizero_config = EasyDict(cartpole_unizero_config) +main_config = cartpole_unizero_config + +cartpole_unizero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +cartpole_unizero_create_config = EasyDict(cartpole_unizero_create_config) +create_config = cartpole_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero_segment_async + train_unizero_segment_async([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/classic_control/cartpole/config/cartpole_unizero_segment_config.py b/zoo/classic_control/cartpole/config/cartpole_unizero_segment_config.py new file mode 100644 index 000000000..774e114f2 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_unizero_segment_config.py @@ -0,0 +1,120 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +num_segments = 8 +game_segment_length = 20 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(2e5) +batch_size = 256 +num_unroll_steps = 5 +reanalyze_ratio = 0. +# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. +buffer_reanalyze_freq = 1/50 +# Each reanalyze process will reanalyze sequences ( transitions per sequence) +reanalyze_batch_size = 160 +# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. +reanalyze_partition = 0.75 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +cartpole_unizero_config = dict( + exp_name=f'data_unizero/cartpole_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000, ), ), ), + model=dict( + observation_shape=4, + action_space_size=2, + self_supervised_learning_loss=True, # NOTE: default is False. 
+ discrete_action_encoding_type='one_hot', + norm_type='BN', + model_type='mlp', + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + max_blocks=10, + max_tokens=2 * 10, + context_length=2 * 4, + context_length_for_recurrent=2 * 4, + device='cuda', + action_space_size=2, + num_layers=2, + num_heads=2, + embed_dim=64, + env_num=max(collector_env_num, evaluator_env_num), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + obs_type='vector', + norm_type='BN', + # rotary_emb=True, + rotary_emb=False, + ), + ), + use_wandb=False, + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + num_segments=num_segments, + game_segment_length=game_segment_length, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + piecewise_decay_lr_scheduler=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), +) + +cartpole_unizero_config = EasyDict(cartpole_unizero_config) +main_config = cartpole_unizero_config + +cartpole_unizero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +cartpole_unizero_create_config = EasyDict(cartpole_unizero_create_config) +create_config = cartpole_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero_segment + train_unizero_segment([main_config, create_config], seed=0, max_env_step=max_env_step)