
Commit d2818b2

committed: Update [ghstack-poisoned]
1 parent 6c5d8e5

File tree: 1 file changed, +139 -0 lines changed

# PPO Trainer Configuration for Pendulum-v1
# This configuration uses the new configurable trainer system

defaults:
  - transform@transform0: noop_reset
  - transform@transform1: step_counter
  - transform@transform2: reward_sum

  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose

  - model@models.policy_model: tanh_normal
  - model@models.value_model: value

  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp

  - collector@collector: multi_async

  - replay_buffer@replay_buffer: base
  - storage@replay_buffer.storage: lazy_tensor
  - writer@replay_buffer.writer: round_robin
  - sampler@replay_buffer.sampler: without_replacement
  - trainer@trainer: ppo
  - optimizer@optimizer: adam
  - loss@loss: ppo
  - logger@logger: wandb
  - _self_
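
# Each defaults entry above uses Hydra's group@package syntax: it pulls the
# named preset from a config group (e.g. the mlp network preset) and places it
# at the given package path (e.g. networks.policy_network). Listing _self_
# last means the values below override the presets.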

# Network configurations
networks:
  policy_network:
    out_features: 2  # loc and scale of the action distribution (Pendulum's action is 1-dimensional)
    in_features: 3   # Pendulum observation is 3-dimensional (cos theta, sin theta, angular velocity)
    num_cells: [128, 128]

  value_network:
    out_features: 1  # scalar state value
    in_features: 3   # Pendulum observation space
    num_cells: [128, 128]
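
# num_cells: [128, 128] gives each MLP two hidden layers of 128 units; the
# policy and value networks share this architecture but not their weights.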

# Model configurations
models:
  policy_model:
    return_log_prob: true
    in_keys: ["observation"]
    param_keys: ["loc", "scale"]
    out_keys: ["action"]
    network: ${networks.policy_network}

  value_model:
    in_keys: ["observation"]
    out_keys: ["state_value"]
    network: ${networks.value_network}
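
# The tanh_normal policy reads "observation", splits the network's two
# outputs into the ["loc", "scale"] parameters of a TanhNormal distribution,
# and samples "action"; return_log_prob: true also records the sample
# log-probability, which the PPO loss needs for its importance ratio.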

# Environment configuration
transform0:
  noops: 30
  random: true

transform1:
  max_steps: 200
  step_count_key: "step_count"

transform2:
  in_keys: ["reward"]
  out_keys: ["reward_sum"]

training_env:
  num_workers: 1
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
        - ${transform2}
    _partial_: true
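
# Note: _partial_: true makes Hydra's instantiate() return a functools.partial
# rather than building the environment eagerly, so the batched environment can
# call create_env_fn itself, once per worker.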

# Loss configuration
loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}
  entropy_coeff: 0.01

# Optimizer configuration
optimizer:
  lr: 0.001
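
# entropy_coeff: 0.01 weights PPO's entropy bonus, which discourages premature
# policy collapse; the single Adam optimizer (lr 0.001) drives both actor and
# critic updates.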

# Collector configuration
collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  total_frames: 1_000_000
  frames_per_batch: 1024
  num_workers: 2
  _partial_: true
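
# With frames_per_batch: 1024, total_frames: 1_000_000 works out to roughly
# 1_000_000 / 1024 (about 977) collection iterations, gathered asynchronously
# by the 2 collector workers.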

# Replay buffer configuration
replay_buffer:
  storage:
    max_size: 1024
    device: cpu
    ndim: 1
  sampler:
    drop_last: true
    shuffle: true
  writer:
    compilable: false
  batch_size: 128
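
# storage.max_size matches frames_per_batch (1024), so each collected batch
# fills the buffer exactly; sampling without replacement at batch_size: 128
# then yields 1024 / 128 = 8 disjoint minibatches per pass over the data.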

# Logger configuration
logger:
  exp_name: ppo_pendulum_v1
  offline: false
  project: torchrl-sota-implementations
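
# With offline: false, metrics stream to the wandb project named above under
# the experiment name ppo_pendulum_v1.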

# Trainer configuration
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  replay_buffer: ${replay_buffer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 1_000_000
  frame_skip: 1
  clip_grad_norm: true
  clip_norm: 100.0
  progress_bar: true
  seed: 42
  save_trainer_interval: 100
  log_interval: 100
  save_trainer_file: null
  optim_steps_per_batch: null
  num_epochs: 2
  async_collection: false
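
For context, a config like this is typically consumed through Hydra and instantiated into a TorchRL trainer. The sketch below is a minimal, hypothetical entry point, not the script from this commit: the file name, config path, and config name (config/ppo_pendulum.yaml) are assumptions, and it presumes the resolved trainer node carries a _target_ that instantiate() can build into an object exposing TorchRL's usual train() method.

# train_ppo_pendulum.py -- hypothetical entry point for this config (a sketch,
# not part of this commit). Assumes the YAML above is saved as
# config/ppo_pendulum.yaml.
import hydra
from hydra.utils import instantiate


@hydra.main(config_path="config", config_name="ppo_pendulum", version_base="1.3")
def main(cfg):
    # Hydra has already merged the defaults list and resolves the ${...}
    # interpolations on access; instantiate() then builds the trainer and,
    # through the interpolations, the collector, loss, replay buffer, and
    # logger it references.
    trainer = instantiate(cfg.trainer)
    # TorchRL trainers run the collect/optimize loop via train().
    trainer.train()


if __name__ == "__main__":
    main()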
