feature(xjy): add the rnd-related features #438
base: main
@@ -0,0 +1,258 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()


def train_unizero_segment_with_reward_model(
        input_cfg: Tuple[dict, dict],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[str] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
    """
    Overview:
        The train entry for UniZero (with muzero_segment_collector and the buffer reanalyze trick), proposed in our paper
        UniZero: Generalized and Efficient Planning with Scalable Latent World Models. UniZero aims to enhance the
        planning capabilities of reinforcement learning agents by addressing the limitations of MuZero-style algorithms,
        particularly in environments that require capturing long-term dependencies. More details can be found at
        https://arxiv.org/abs/2406.10667.
    Arguments:
        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the
            pretrained model; an absolute path is recommended. In LightZero, the path is usually something like
            ``exp_name/ckpt/ckpt_best.pth.tar``.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """

    cfg, create_cfg = input_cfg

    # Ensure the specified policy type is supported
    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], \
        "train_unizero entry now only supports the following algorithms: 'unizero', 'sampled_unizero'"
    assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use the RND reward model"

    # Import the correct GameBuffer class based on the policy type
    game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
    GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
                         game_buffer_classes[create_cfg.policy.type])

    # Set device based on CUDA availability
    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
    logging.info(f'cfg.policy.device: {cfg.policy.device}')

    # Compile the configuration
    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

    # Create main components: env, policy
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

    # Load pretrained model if specified
    if model_path is not None:
        logging.info(f'Loading model from {model_path} begin...')
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
        logging.info(f'Loading model from {model_path} end!')

    # Create worker components: learner, collector, evaluator, replay buffer, commander
    tb_logger = None
    if get_rank() == 0:
        tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

    # MCTS+RL algorithms related core code
    policy_config = cfg.policy
    replay_buffer = GameBuffer(policy_config)
    collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
                          policy_config=policy_config)
    evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                          stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)

    # ==============================================================
    # New: initialize the RND reward model.
    # RNDRewardModel needs the representation network from the policy model (used as the predictor)
    # and the target representation network (used as the fixed target).
    # For UniZero, the tokenizer plays the role of the representation network.
    # ==============================================================
    reward_model = RNDRewardModel(
        config=cfg.reward_model,
        device=policy.collect_mode.get_attribute('device'),
        tb_logger=tb_logger,
        exp_name=cfg.exp_name,
        representation_network=policy._learn_model.representation_network,
        target_representation_network=policy._target_model_for_intrinsic_reward.representation_network,
        use_momentum_representation_network=cfg.policy.use_momentum_representation_network,
        bp_update_sync=cfg.policy.bp_update_sync,
        multi_gpu=cfg.policy.multi_gpu,
    )
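
As context for the wiring above: RND derives an intrinsic reward from the prediction error between a trainable predictor and a fixed target network; in the classic formulation the target is randomly initialized and frozen, while in this PR the policy's target representation network plays that role. A minimal sketch of the idea, with hypothetical names (rnd_intrinsic_reward, predictor, target) that do not mirror the RNDRewardModel API:

import torch
import torch.nn as nn

def rnd_intrinsic_reward(obs: torch.Tensor, predictor: nn.Module, target: nn.Module) -> torch.Tensor:
    # The target network is kept fixed; only the predictor is trained to imitate it.
    with torch.no_grad():
        target_feat = target(obs)
    pred_feat = predictor(obs)
    # The per-sample prediction error serves as the novelty (intrinsic) reward.
    return ((pred_feat - target_feat) ** 2).mean(dim=-1)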

    # Learner's before_run hook
    learner.call_hook('before_run')

    if cfg.policy.use_wandb and get_rank() == 0:
        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

    # Collect random data before training
    if cfg.policy.random_collect_data:
        random_data = random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
        try:
            reward_model.warmup_with_random_segments(random_data)
        except Exception as e:
            logging.exception(f"Failed to warm up RND normalization using random data: {e}")
            raise
    batch_size = policy._cfg.batch_size

    buffer_reanalyze_count = 0
    train_epoch = 0
    reanalyze_batch_size = cfg.policy.reanalyze_batch_size

    if cfg.policy.multi_gpu:
        # Get current world size and rank
        world_size = get_world_size()
        rank = get_rank()
    else:
        world_size = 1
        rank = 0

    while True:
        # Log buffer memory usage
        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

        # Set temperature for visit count distributions
        collect_kwargs = {
            'temperature': visit_count_temperature(
                policy_config.manual_temperature_decay,
                policy_config.fixed_temperature_value,
                policy_config.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            ),
            'epsilon': 0.0  # Default epsilon value
        }

        # Configure epsilon for epsilon-greedy exploration
        if policy_config.eps.eps_greedy_exploration_in_collect:
            epsilon_greedy_fn = get_epsilon_greedy_fn(
                start=policy_config.eps.start,
                end=policy_config.eps.end,
                decay=policy_config.eps.decay,
                type_=policy_config.eps.type
            )
            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

        # Evaluate policy performance
        # if evaluator.should_eval(learner.train_iter):
        #     stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
        #     if stop:
        #         break

        # Collect new data
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

        # Determine updates per collection
        update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

        # Update replay buffer
        replay_buffer.push_game_segments(new_data)
        replay_buffer.remove_oldest_data_to_fit()

        # Periodically reanalyze buffer
        if cfg.policy.buffer_reanalyze_freq >= 1:
            # Reanalyze the buffer <buffer_reanalyze_freq> times in one train_epoch
            reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
        else:
            # Reanalyze the buffer once every <1/buffer_reanalyze_freq> train_epochs
            if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \
                    replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition):
                with timer:
                    # Each reanalyze process reanalyzes <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
                    replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                buffer_reanalyze_count += 1
                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                logging.info(f'Buffer reanalyze time: {timer.value}')

        # Train the policy if sufficient data is available
        if collector.envstep > cfg.policy.train_start_after_envsteps:
            if cfg.policy.sample_type == 'episode':
                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
            else:
                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
            if not data_sufficient:
                logging.warning(
                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
                )
                continue

            for i in range(update_per_collect):
                if cfg.policy.buffer_reanalyze_freq >= 1:
                    # Reanalyze the buffer <buffer_reanalyze_freq> times in one train_epoch
                    if i % reanalyze_interval == 0 and \
                            replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition):
                        with timer:
                            # Each reanalyze process reanalyzes <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
                            replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                        buffer_reanalyze_count += 1
                        logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                        logging.info(f'Buffer reanalyze time: {timer.value}')

                train_data = replay_buffer.sample(batch_size, policy)
                if cfg.policy.use_wandb:
                    policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

                # Augment the sampled batch with RND intrinsic rewards, then train the policy and the reward model.
                train_data_augmented = reward_model.estimate(train_data)
                train_data_augmented.append(learner.train_iter)

                log_vars = learner.train(train_data_augmented, collector.envstep)
                reward_model.train_with_policy_batch(train_data)
Collaborator:
The reward_model should probably be trained for some iterations first, and only then should UniZero use the trained RND network to estimate the fused reward and update its own networks. In the current version the fused reward effectively changes at every iteration, which seems too non-stationary for UniZero's learning, doesn't it?
Collaborator (Author):
Right. The parameter adaptation we discussed earlier has now been added: the intrinsic-reward weight is 0 in the initial phase and slowly ramps up after a while, so in the initial phase only the RND network is being trained and the intrinsic reward is not actually used.
Collaborator:
Are all of the new runs using this method?
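
The adaptation described in the reply above amounts to a weight on the intrinsic reward that stays at zero during an initial phase (so only the RND network is being trained) and then ramps up. A minimal sketch of such a schedule, assuming a linear ramp; the function name, warm-up lengths, and maximum weight are illustrative and not the values used in this PR:

def intrinsic_reward_weight(train_iter: int, warmup_iters: int = 10000,
                            ramp_iters: int = 40000, max_weight: float = 0.1) -> float:
    # Pure RND-training phase: the intrinsic reward is not mixed in yet.
    if train_iter < warmup_iters:
        return 0.0
    # Linear ramp from 0 to max_weight, then hold.
    progress = min(1.0, (train_iter - warmup_iters) / ramp_iters)
    return max_weight * progress

# Conceptually, the fused reward is then r_total = r_extrinsic + w * r_intrinsic.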

                logging.info(f'[{i}/{update_per_collect}]: learner and reward_model ended training step.')

                if cfg.policy.use_priority:
                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

        train_epoch += 1
        policy.recompute_pos_emb_diff_and_clear_cache()

        # Check stopping criteria
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    learner.call_hook('after_run')
    if cfg.policy.use_wandb:
        wandb.finish()
    return policy
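
A hedged usage sketch of this entry, assuming it is re-exported from lzero.entry and driven by a zoo config that defines cfg.reward_model and the RND-related policy flags; the config module path below is hypothetical:

# Hypothetical launcher; the config module path is illustrative.
from zoo.classic_control.cartpole.config.cartpole_unizero_segment_config import main_config, create_config

from lzero.entry import train_unizero_segment_with_reward_model  # assumes the entry is exported here

if __name__ == "__main__":
    # main_config must enable use_rnd_model and provide a reward_model section for RNDRewardModel.
    train_unizero_segment_with_reward_model([main_config, create_config], seed=0, max_env_step=int(5e5))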