Skip to content

Commit a24a54d

Browse files
authored
polish(xjy): update configuration and log instructions in tutorials (#330)
* modify config_zh.md and config.md * Updated config and log documentation * polish(xjy): polish config_zh.md and config.md
1 parent ec60f8d commit a24a54d

File tree

4 files changed

+231
-75
lines changed

4 files changed

+231
-75
lines changed

docs/source/tutorials/config/config.md

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,92 @@ The `main_config` dictionary contains the main parameter settings for running th
1313
### 1.1 Main Parameters in the `env` Part
1414

1515
- `env_id`: Specifies the environment to be used.
16-
- `obs_shape`: The dimension of the environment observation.
16+
- `observation_shape`: The dimension of the environment's observations.
1717
- `collector_env_num`: The number of parallel environments used to collect data in the experience replay collector.
1818
- `evaluator_env_num`: The number of parallel environments used to evaluate policy performance in the evaluator.
19-
- `n_evaluator_episode`: The number of episodes run by each environment in the evaluator.
19+
- `n_evaluator_episode`: The total number of episodes run across all environments in the evaluator.
20+
- `collect_max_episode_steps`: The maximum number of steps allowed per episode during data collection.
21+
- `eval_max_episode_steps`: The maximum number of steps allowed per episode during evaluation.
22+
- `frame_stack_num`: The number of consecutive frames stacked together as input.
23+
- `gray_scale`: Whether to use grayscale images.
24+
- `scale`: Whether to scale the input data.
25+
- `clip_rewards`: Whether to clip reward values.
26+
- `episode_life`: If True, the game ends when the agent loses a life, otherwise, the game only ends when all lives are lost.
27+
- `env_type`: The type of environment.
28+
- `frame_skip`: The number of frames to repeat the same action.
29+
- `stop_value`: The target score that stops the training.
30+
- `replay_path`: Path to store the replay.
31+
- `save_replay`: Whether to save the replay video.
32+
- `channel_last`: Whether to put the channel dimension in the last dimension of the input data.
33+
- `warp_frame`: Whether to crop each frame of the picture.
2034
- `manager`: Specifies the type of environment manager, mainly used to control the parallelization mode of the environment.
2135

2236
### 1.2 Main Parameters in the `policy` Part
2337

24-
- `model`: Specifies the neural network model used by the policy, including the input dimension of the model, the number of frame stacking, the action space dimension of the model output, whether the model needs to use downsampling, whether to use self-supervised learning auxiliary loss, the action encoding type, the Normalization mode used in the network, etc.
25-
- `cuda`: Specifies whether to migrate the model to the GPU for training.
26-
- `reanalyze_noise`: Whether to introduce noise during MCTS reanalysis, which can increase exploration.
27-
- `env_type`: Marks the environment type faced by the MuZero algorithm. According to different environment types, the MuZero algorithm will have some differences in detail processing.
28-
- `game_segment_length`: The length of the sequence (game segment) used for self-play.
29-
- `random_collect_episode_num`: The number of randomly collected episodes, providing initial data for exploration.
30-
- `eps`: Exploration control parameters, including whether to use the epsilon-greedy method for control, the update method of control parameters, the starting value, the termination value, the decay rate, etc.
38+
- `model`: Specifies the neural network model used by the policy.
39+
- `model_type`: The type of model to use.
40+
- `observation_shape`: The dimensions of the observation space.
41+
- `action_space_size`: The size of the action space.
42+
- `continuous_action_space`: Whether the action space is continuous.
43+
- `num_res_blocks`: The number of residual blocks in the model.
44+
- `downsample`: Whether to downsample the input.
45+
- `norm_type`: The type of normalization used.
46+
- `num_channels`: The number of channels in the convolutional layers (number of features extracted).
47+
- `support_scale`: The range of the value support set (`-support_scale` to `support_scale`).
48+
- `bias`: Whether to use bias terms in the layers.
49+
- `discrete_action_encoding_type`: How discrete actions are encoded.
50+
- `self_supervised_learning_loss`: Whether to use a self-supervised learning loss (as in EfficientZero).
51+
- `image_channel`: The number of channels in the input image.
52+
- `frame_stack_num`: Number of frames stacked.
53+
- `gray_scale`: Whether to use gray images.
54+
- `use_sim_norm`: Whether to use SimNorm after the Latent State.
55+
- `use_sim_norm_kl_loss`: Whether the obs_loss corresponding to the Latent State after SimNorm uses KL divergence loss, which is often used together with SimNorm.
56+
- `res_connection_in_dynamics`: Whether to use the residual connection in the dynamics model.
57+
- `learn`: Configuration for the learning process.
58+
- `learner`: Configuration for the learner (dictionary type), including train iterations and checkpoint saving.
59+
- `resume_training`: Whether to resume training.
60+
- `collect`: Configuration for the collect process.
61+
- `collector`: Collector configuration (dictionary type), including type and print frequency.
62+
- `eval`: Configuration for the evaluation process
63+
- `evaluator`: Evaluator configuration (dictionary type), including evaluation frequency, number of episodes to evaluate, and path to save images.
64+
- `other`: Other configurations.
65+
- `replay_buffer`: Replay buffer configuration (dictionary type), including buffer size, maximum usage and staleness of experiences, and parameters for throughput control and monitoring.
66+
- `cuda`: Whether to use CUDA (GPU) for training.
67+
- `multi_gpu`: Whether to enable multi-GPU training.
68+
- `use_wandb`: Whether to use Weights & Biases (wandb) for logging.
69+
- `mcts_ctree`: Whether to use the C++ version of Monte Carlo Tree Search.
70+
- `collector_env_num`: The number of collection environments.
71+
- `evaluator_env_num`: The number of evaluation environments.
72+
- `env_type`: The type of environment (board game or non-board game).
73+
- `action_type`: The type of action space (fixed or other).
74+
- `game_segment_length`: The length corresponding to the basic unit game segment during collection.
75+
- `cal_dormant_ratio`: Whether to calculate the ratio of dormant neurons.
3176
- `use_augmentation`: Whether to use data augmentation.
32-
- `update_per_collect`: The number of updates after each data collection.
33-
- `batch_size`: The batch size sampled during the update.
34-
- `optim_type`: Optimizer type.
77+
- `augmentation`: The data augmentation methods to use.
78+
- `update_per_collect`: The number of model updates after each data collection phase.
79+
- `batch_size`: The batch size used for training updates.
80+
- `optim_type`: The type of optimizer.
81+
- `reanalyze_ratio`: The reanalyze ratio, which controls the probability to conduct reanalyze.
82+
- `reanalyze_noise`: Whether to introduce noise during MCTS reanalysis (for exploration).
83+
- `reanalyze_batch_size`: Reanalyze batch size.
84+
- `reanalyze_partition`: The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
85+
-`random_collect_episode_num`: Number of episodes of random collection, to provide initial exploration data.
86+
- `eps`: Parameters for exploration control, including whether to use epsilon-greedy, update schedules, start/end values, and decay rate.
3587
- `piecewise_decay_lr_scheduler`: Whether to use piecewise constant learning rate decay.
36-
- `learning_rate`: Initial learning rate.
88+
- `learning_rate`: The initial learning rate.
3789
- `num_simulations`: The number of simulations used in the MCTS algorithm.
38-
- `reanalyze_ratio`: Reanalysis coefficient, controlling the probability of reanalysis.
39-
- `ssl_loss_weight`: The weight of the self-supervised learning loss function.
40-
- `n_episode`: The number of episodes run by each environment in the parallel collector.
41-
- `eval_freq`: Policy evaluation frequency (measured by training steps).
42-
- `replay_buffer_size`: The capacity of the experience replay buffer.
90+
- `reward_loss_weight`: Weight for the reward loss.
91+
- `policy_loss_weight`: Weight for the policy loss.
92+
- `value_loss_weight`: Weight for the value loss.
93+
- `ssl_loss_weight`: The weight of the self-supervised learning loss.
94+
- `n_episode`: The number of episodes in parallel collector.
95+
- `eval_freq`: The frequency of policy evaluation (in terms of training steps).
96+
- `replay_buffer_size`: The capacity of the replay buffer.
97+
- `target_update_freq`: How often to update the target network.
98+
- `grad_clip_value`: Value to clip gradient.
99+
- `discount_factor`: Discount factor.
100+
- `td_steps`: TD steps.
101+
- `num_unroll_steps`: The number of rollout steps during MuZero training.
43102

44103
Two frequently changed parameter setting areas are also specially mentioned here, annotated by comments:
45104

docs/source/tutorials/config/config_zh.md

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,93 @@
1313
### 1.1 `env`部分的主要参数
1414

1515
- `env_id`: 指定要使用的环境。
16-
- `obs_shape`: 环境观测的维度。
16+
- `observation_shape`: 环境观测的维度。
1717
- `collector_env_num`: 经验回放采集器(collector)中并行用于收集数据的环境数目。
1818
- `evaluator_env_num`: 评估器(evaluator)中并行用于评估策略性能的环境数目。
19-
- `n_evaluator_episode`: 评估器中每个环境运行的episode数目。
19+
- `n_evaluator_episode`: 评估器中所有环境运行的总的 episode 数目。
20+
- `collect_max_episode_steps`: 收集数据时单个 episode 允许的最大步数。
21+
- `eval_max_episode_steps`: 评估时单个 episode 允许的最大步数。
22+
- `frame_stack_num`: 叠帧数。
23+
- `gray_scale`: 是否使用灰度图像。
24+
- `scale`: 是否缩放输入数据。
25+
- `clip_rewards`: 是否裁剪奖励值。
26+
- `episode_life`: 如果为 True,代理丢失一条生命时游戏结束;否则,游戏将在所有生命丢失时结束。
27+
- `env_type`: 环境类型。
28+
- `frame_skip`: 动作重复的帧数。
29+
- `stop_value`: 训练停止的目标分数。
30+
- `replay_path`: 经验回放的存储路径。
31+
- `save_replay`: 是否存储回放视频。
32+
- `channel_last`: 是否将 channel 维度放在输入数据的最后一维。
33+
- `warp_frame`: 是否裁剪每一帧的图片。
2034
- `manager`: 指定环境管理器的类型,主要用于控制环境的并行化方式。
2135

2236
### 1.2 `policy`部分的主要参数
23-
24-
- `model`: 指定策略所使用的神经网络模型,包含模型的输入维度、叠帧数、模型输出的动作空间维度、模型是否需要使用降采样、是否使用自监督学习辅助损失、动作编码类型、网络中使用的Normalization模式等。
25-
- `cuda`: 指定是否将模型迁移到GPU上进行训练。
26-
- `reanalyze_noise`: 是否在MCTS重分析时引入噪声,可以增加探索。
27-
- `env_type`: 标记MuZero算法所面对的环境类型,根据不同的环境类型,MuZero算法会在细节处理上有所不同。
28-
- `game_segment_length`: 用于自我博弈的序列(game segment)长度。
29-
- `random_collect_episode_num`: 随机采集的episode数量,为探索提供初始数据。
30-
- `eps`: 探索控制参数,包括是否使用epsilon-greedy方法进行控制,控制参数的更新方式、起始值、终止值、衰减速度等。
37+
- `model`: 指定策略所使用的神经网络模型。
38+
- `model_type`: 选择使用的模型类型。
39+
- `observation_shape`: 观测空间的维度。
40+
- `action_space_size`: 动作空间大小。
41+
- `continuous_action_space`: 动作空间是否是连续的。
42+
- `num_res_blocks`: 残差块的数量。
43+
- `downsample`: 是否进行降采样。
44+
- `norm_type`: 归一化使用的方法。
45+
- `num_channels`: 卷积层提取的特征个数。
46+
- `support_scale`: 价值支持集的范围 (-support_scale, support_scale)。
47+
- `bias`: 是否使用偏置。
48+
- `discrete_action_encoding_type`: 离散化动作空间使用的编码类型。
49+
- `self_supervised_learning_loss`: 是否使用自监督学习损失(参照EfficientZero的实现)。
50+
- `image_channel`: 输入图像通道数。
51+
- `frame_stack_num`: 堆叠帧数。
52+
- `gray_scale`: 是否使用灰度图像。
53+
- `use_sim_norm`: Latent State 后面是否使用 SimNorm。
54+
- `use_sim_norm_kl_loss`: Latent State 经过 SimNorm 后,对应的 obs_loss 是否使用 KL 散度损失,往往与 SimNorm 配合使用。
55+
- `res_connection_in_dynamics`: 动力学模型中是否使用残差连接。
56+
- `learn`: 学习过程配置
57+
- `learner`: 学习器配置(字典类型),包括训练迭代次数,检查点保存策略等信息。
58+
- `resume_training`: 是否恢复训练。
59+
- `collect`: 收集过程配置
60+
- `collector`: 收集器配置(字典类型),包括类型和输出频率等信息。
61+
- `eval`: 收集过程配置
62+
- `evaluator`: 评估器配置(字典类型),包括评估频率、评估的episode数量和图片保存路径等。
63+
- `other`: 其它配置
64+
- `replay_buffer`: 经验回放器配置(字典类型),包括存储大小,经验的最大使用次数和最大陈旧度以及吞吐量控制和监控配置相关的参数。
65+
- `cuda`: 指定是否将模型迁移到 GPU 上进行训练。
66+
- `multi_gpu`: 是否开启多 GPU 训练。
67+
- `use_wandb`: 是否使用 wandb。
68+
- `mcts_ctree`: 是否使用蒙特卡洛树搜索的cpp版本。
69+
- `collector_env_num`: 收集环境的数量。
70+
- `evaluator_env_num`: 评估环境的数量。
71+
- `env_type`: 环境类型(棋盘游戏或非棋盘游戏)。
72+
- `action_type`: 动作类型 (固定动作空间或其他)。
73+
- `game_segment_length`: 收集时的基本单元 game segment 对应的长度。
74+
- `cal_dormant_ratio`: 是否计算休眠神经元比率。
3175
- `use_augmentation`: 是否使用数据增强。
32-
- `update_per_collect`: 每次数据收集后更新的次数。
76+
- `augmentation`: 数据增强方法。
77+
- `update_per_collect`: 每次数据收集完以后模型更新的次数。
3378
- `batch_size`: 更新时采样的批量大小。
3479
- `optim_type`: 优化器类型。
80+
- `reanalyze_ratio`: 重分析系数,控制进行重分析的概率。
81+
- `reanalyze_noise`: 是否在 MCTS 重分析时引入噪声,可以增加探索。
82+
- `reanalyze_batch_size`: 重分析批量大小。
83+
- `reanalyze_partition`: 重分析的比例。例如,1 表示从整个缓冲区重新分析批次样本,0.5 表示从缓冲区的前一半采样。
84+
- `random_collect_episode_num`: 随机采集的 episode 数量,为探索提供初始数据。
85+
- `eps`: 探索控制参数,包括是否使用 epsilon-greedy 方法进行控制,控制参数的更新方式、起始值、终止值、衰减速度等。
3586
- `piecewise_decay_lr_scheduler`: 是否使用分段常数学习率衰减。
3687
- `learning_rate`: 初始学习率。
37-
- `num_simulations`: MCTS算法中使用的模拟次数。
38-
- `reanalyze_ratio`: 重分析系数,控制进行重分析的概率。
88+
- `num_simulations`: MCTS 算法中使用的模拟次数。
89+
- `reward_loss_weight`: 奖励损失函数的权重。
90+
- `policy_loss_weight`: 策略损失函数的权重。
91+
- `value_loss_weight`: 价值损失函数的权重。
3992
- `ssl_loss_weight`: 自监督学习损失函数的权重。
40-
- `n_episode`: 并行采集器中每个环境运行的episode数量
93+
- `n_episode`: 并行采集器中所有环境运行的总 episode 数量
4194
- `eval_freq`: 策略评估频率(按照训练步数计)。
4295
- `replay_buffer_size`: 经验回放器的容量。
96+
- `target_update_freq`: 目标网络更新频率。
97+
- `grad_clip_value`: 梯度裁剪值。
98+
- `discount_factor`: 折扣因子。
99+
- `td_steps`: TD 步数。
100+
- `num_unroll_steps`: MuZero 训练时展开的步数。
101+
102+
43103

44104
这里还特别提到了两个易变参数设定区域,通过注释
45105

@@ -66,7 +126,7 @@
66126
env=dict(
67127
type='atari_lightzero',
68128
import_names=['zoo.atari.envs.atari_lightzero_env'],
69-
),
129+
)
70130
```
71131

72132
其中`type`指定了要使用的环境名,`env_name`则指定了该环境类所在的引用路径。这里使用的是预定义的`atari_lightzero_env`。如果要使用自定义的环境类,则需要将`type`改为自定义环境类名,并相应修改`import_names`参数。
@@ -77,7 +137,7 @@ env=dict(
77137
policy=dict(
78138
type='muzero',
79139
import_names=['lzero.policy.muzero'],
80-
),
140+
)
81141
```
82142

83143
其中`type`指定了要使用的策略名,`import_names`则指定了该策略类所在的引用路径。这里使用的是LightZero中预定义的MuZero算法。如果要使用自定义的策略类,则需要将`type`改为自定义策略类,并修改`import_names`参数为自定义策略所在的引用路径。

0 commit comments

Comments
 (0)