
Commit 0187f7e

Rebuild RLCodebase
2 parents 531b7dc + 5948958

36 files changed: +941, -920 lines changed

README.md

Lines changed: 9 additions & 8 deletions
@@ -1,18 +1,19 @@
 # RLCodebase
 RLCodebase is a modularized codebase for deep reinforcement learning algorithms based on PyTorch. This repo aims to provide an user-friendly reinforcement learning codebase for beginners to get started and for researchers to try their ideas quickly and efficiently.
 
-For now, it has implemented DQN(PER), A2C, PPO, DDPG, TD3 and SAC algorithms, and tested on OpenAI Gym, Procgen, PyBullet and DMControl Suite environments.
+For now, it has implemented DQN(PER), A2C, PPO, DDPG, TD3 and SAC algorithms, and has been tested on Atari, Procgen, Mujoco, PyBullet and DMControl Suite environments.
 
 ## Introduction
 The design of RLCodebase is shown as below.
 
 
 ![RLCodebase](imgs/RLCodebase.png)
-* Config: Config is a class that contains parameters for reinforcement learning algorithms such as discount factor, learning rate, etc. and general configurations such as random seed, saving path, etc.
-* Agent: Agent is a wrapped class that controls the workflow of reinforcement learning algorithms like a manager. It's responsible for the interactions among submodules (policy, environment, memory).
-* Policy: Policy tells us what action to taken given a state. It also implements a function that defines how to update the model given a batch of data.
-* Environment: Environment is designed to be a vectorized gym environment. Here we use gym wrappers from OpenAI baselines for convenient implementations.
-* Memory: Memory stores data needed for improving our model.
+* **Config**: Config is a class that contains parameters for reinforcement learning algorithms such as discount factor, learning rate, etc. and general configurations such as random seed, saving path, etc.
+* **Trainer**: Trainer is a wrapped class that controls the workflow of reinforcement learning training. It manages the interactions between submodules (Agent, Env, Memory).
+* **Agent**: Agent chooses actions to take given states. It also defines how to update the model given a batch of data.
+* **Model**: Model gathers all neural networks to train.
+* **Env**: Env is a vectorized gym environment.
+* **Memory**: Memory stores experiences utilized for RL training.
 
 ## Installtion
 All required packages have been included in setup.py and requirements.txt. Mujoco is needed for mujoco_py and dm control suite. To support mujoco_py and dm control, please refer to https://github.com/openai/mujoco-py and https://github.com/deepmind/dm_control. For mujoco_py 2.1.2.14 and dm_control (commit fe44496), you may download mujoco like below
@@ -43,7 +44,7 @@ pip install -e .
 pip install -r requirements.txt
 
 # try it
-python example_a2c.py
+python examples/example_ppo.py
 ````
 
 ## Supported Algorithms
@@ -64,7 +65,7 @@ python example_a2c.py
 ### 1. PPO & A2C In Atari Games
 <img src="https://github.com/KarlXing/RLCodebase/blob/master/imgs/A2C&PPO.png">
 
-### 2. DDPG & TD3 & SAC In Pybullet Environments
+### 2. DDPG & TD3 & SAC In PyBullet Environments
 <img src="https://github.com/KarlXing/RLCodebase/blob/master/imgs/DDPG&TD3&SAC.png">
 
 ### 3. DQN & DQN+PER In PongNoFrameskip-v4
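
Read together, the README changes above and the example-script diffs below describe one workflow: build a Config, a vectorized Env, a Model and a Logger, then hand them to a Trainer (formerly called Agent) whose run() method drives training. The sketch below assembles that workflow from the A2C example in this commit; how Config is populated and which arguments make_vec_envs accepts are not shown in this diff, so those parts are assumptions for illustration only.

```python
# Minimal sketch of the post-rename workflow, assembled from examples/example_a2c.py
# in this commit. Config construction and make_vec_envs arguments are NOT shown in
# this diff and are assumed here for illustration.
import rlcodebase
from rlcodebase.env import make_vec_envs
from rlcodebase.trainer import A2CTrainer            # renamed from rlcodebase.agent.A2CAgent
from rlcodebase.utils import get_action_dim, Config, Logger
from rlcodebase.model import CatACConvNet
from torch.utils.tensorboard import SummaryWriter

config = Config()                                     # assumed: default-constructed Config
config.save_path = './runs/a2c-demo'                  # assumed value

env = make_vec_envs('BreakoutNoFrameskip-v4', num_envs=16)    # assumed arguments
model = CatACConvNet(input_channels=env.observation_space.shape[0],
                     action_dim=get_action_dim(env.action_space)).to(config.device)
logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)

# The Trainer (formerly Agent) owns the training loop and coordinates Agent, Env and Memory.
trainer = A2CTrainer(config, env, model, logger)
trainer.run()
```

The off-policy examples changed in this commit (DQN, DDPG, TD3, SAC) follow the same pattern, but additionally pass an eval_env and a target_model to their trainers.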

example_a2c.py renamed to examples/example_a2c.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import A2CAgent
+from rlcodebase.trainer import A2CTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import CatACConvNet
 from torch.utils.tensorboard import SummaryWriter
@@ -44,9 +44,9 @@ def main():
     model = CatACConvNet(input_channels = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = A2CAgent(config, env, model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = A2CTrainer(config, env, model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()
example_ddpg.py renamed to examples/example_ddpg.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import DDPGAgent
+from rlcodebase.trainer import DDPGTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ConDetACLinearNet
 from torch.utils.tensorboard import SummaryWriter
@@ -47,9 +47,9 @@ def main():
     target_model = ConDetACLinearNet(input_dim = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = DDPGAgent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = DDPGTrainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()

example_dqn.py renamed to examples/example_dqn.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import DQNAgent
+from rlcodebase.trainer import DQNTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import CatQConvNet
 from torch.utils.tensorboard import SummaryWriter
@@ -54,9 +54,9 @@ def main():
     target_model = CatQConvNet(input_channels = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes, config.log_episodes_avg_window)
 
-    # create agent and run
-    agent = DQNAgent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = DQNTrainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()

example_ppo.py renamed to examples/example_ppo.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import PPOAgent
+from rlcodebase.trainer import PPOTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import CatACConvNet
 from torch.utils.tensorboard import SummaryWriter
@@ -47,9 +47,9 @@ def main():
     model = CatACConvNet(input_channels = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = PPOAgent(config, env, model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = PPOTrainer(config, env, model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()
Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs_procgen
-from rlcodebase.agent import PPOAgent
+from rlcodebase.trainer import PPOTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ImpalaCNN, SeparateImpalaCNN
 from torch.utils.tensorboard import SummaryWriter
@@ -62,9 +62,9 @@ def main():
     model = Model(input_channels = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = PPOAgent(config, env, model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = PPOTrainer(config, env, model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()

example_sac.py renamed to examples/example_sac.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import SACAgent
+from rlcodebase.trainer import SACTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ConStoSGADCLinearNet
 from torch.utils.tensorboard import SummaryWriter
@@ -48,9 +48,9 @@ def main():
     target_model = ConStoSGADCLinearNet(input_dim = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = SACAgent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = SACTrainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()
Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs_dmcontrol
-from rlcodebase.agent import SACAgent
+from rlcodebase.trainer import SACTrainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ConStoSGADCLinearNet
 from torch.utils.tensorboard import SummaryWriter
@@ -49,9 +49,9 @@ def main():
     target_model = ConStoSGADCLinearNet(input_dim = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = SACAgent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = SACTrainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()

example_td3.py renamed to examples/example_td3.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import rlcodebase
 from rlcodebase.env import make_vec_envs
-from rlcodebase.agent import TD3Agent
+from rlcodebase.trainer import TD3Trainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ConDetADCLinearNet
 from torch.utils.tensorboard import SummaryWriter
@@ -50,9 +50,9 @@ def main():
     target_model = ConDetADCLinearNet(input_dim = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = TD3Agent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = TD3Trainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()
Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 from ast import arg
 import rlcodebase
 from rlcodebase.env import make_vec_envs_dmcontrol
-from rlcodebase.agent import TD3Agent
+from rlcodebase.trainer import TD3Trainer
 from rlcodebase.utils import get_action_dim, init_parser, Config, Logger
 from rlcodebase.model import ConDetADCLinearNet
 from torch.utils.tensorboard import SummaryWriter
@@ -53,9 +53,9 @@ def main():
     target_model = ConDetADCLinearNet(input_dim = env.observation_space.shape[0], action_dim = get_action_dim(env.action_space)).to(config.device)
     logger = Logger(SummaryWriter(config.save_path), config.num_echo_episodes)
 
-    # create agent and run
-    agent = TD3Agent(config, env, eval_env, model, target_model, logger)
-    agent.run()
+    # create trainer and run
+    trainer = TD3Trainer(config, env, eval_env, model, target_model, logger)
+    trainer.run()
 
 if __name__ == '__main__':
     main()
