Skip to content

Commit 39960d8

Browse files
authored
Merge pull request #168 from edbeeching/code-quality
Adds code quality
2 parents cfaaf15 + ed8407f commit 39960d8

31 files changed

+427
-438
lines changed

.github/workflows/quality.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: Quality
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
11+
jobs:
12+
13+
check_code_quality:
14+
name: Check code quality
15+
runs-on: ubuntu-latest
16+
steps:
17+
- name: Checkout code
18+
uses: actions/checkout@v2
19+
- name: Setup Python environment
20+
uses: actions/setup-python@v2
21+
with:
22+
python-version: 3.10.10
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
python -m pip install ".[dev]"
27+
- name: Code quality
28+
run: |
29+
make quality

.vscode/settings.json

Lines changed: 0 additions & 12 deletions
This file was deleted.

Makefile

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
.PHONY: quality style test unity-test
22

3-
# Check that source code meets quality standards
4-
quality:
5-
black --check --line-length 119 --target-version py38 tests godot_rl
6-
isort --check-only tests godot_rl
7-
flake8 tests godot_rl
8-
93
# Format source code automatically
104
style:
11-
black --line-length 119 --target-version py38 tests godot_rl
12-
isort tests godot_rl
5+
black --line-length 120 --target-version py310 tests godot_rl examples
6+
isort -w 120 tests godot_rl examples
7+
# Check that source code meets quality standards
8+
quality:
9+
black --check --line-length 120 --target-version py310 tests godot_rl examples
10+
isort -w 120 --check-only tests godot_rl examples
11+
flake8 --max-line-length 120 tests godot_rl examples
1312

1413
# Run tests for the library
1514
test:

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,33 @@ Godot RL Agents supports 4 different RL training frameworks, the links below det
8383
- [CleanRL](docs/ADV_CLEAN_RL.md) (Windows, Mac, Linux)
8484
- [Ray rllib](docs/ADV_RLLIB.md) (Windows, Mac, Linux)
8585

86+
## Contributing
87+
We welcome new contributions to the library, such as:
88+
- New environments made in Godot
89+
- Improvements to the readme files
90+
- Additions to the python codebase
91+
92+
Start by forking the repo and then cloning it to your machine, creating a venv and performing an editable installation.
93+
94+
```
95+
# If you want to PR, you should fork the lib or ask to be a contributor
96+
git clone git@github.com:YOUR_USERNAME/godot_rl_agents.git
97+
cd godot_rl_agents
98+
python -m venv venv
99+
pip install -e ".[dev]"
100+
# check tests run
101+
make test
102+
```
103+
104+
Then add your features.
105+
Format your code with:
106+
```
107+
make style
108+
make quality
109+
```
110+
Then make a PR against main on the original repo.
111+
112+
86113
## FAQ
87114

88115
### Why have we developed Godot RL Agents?

examples/clean_rl_example.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import pathlib
55
import random
66
import time
7-
from distutils.util import strtobool
87
from collections import deque
8+
from distutils.util import strtobool
9+
910
import numpy as np
1011
import torch
1112
import torch.nn as nn
1213
import torch.optim as optim
1314
from torch.distributions.normal import Normal
1415
from torch.utils.tensorboard import SummaryWriter
16+
1517
from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv
1618

1719

@@ -167,8 +169,9 @@ def get_action_and_value(self, x, action=None):
167169

168170
# env setup
169171

170-
envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
171-
n_parallel=args.n_parallel)
172+
envs = env = CleanRLGodotEnv(
173+
env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel
174+
)
172175
args.num_envs = envs.num_envs
173176
args.batch_size = int(args.num_envs * args.num_steps)
174177
args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -334,7 +337,6 @@ def get_action_and_value(self, x, action=None):
334337

335338
agent.eval().to("cpu")
336339

337-
338340
class OnnxPolicy(torch.nn.Module):
339341
def __init__(self, actor_mean):
340342
super().__init__()
@@ -344,7 +346,6 @@ def forward(self, obs, state_ins):
344346
action_mean = self.actor_mean(obs)
345347
return action_mean, state_ins
346348

347-
348349
onnx_policy = OnnxPolicy(agent.actor_mean)
349350
dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)
350351

@@ -355,9 +356,10 @@ def forward(self, obs, state_ins):
355356
opset_version=15,
356357
input_names=["obs", "state_ins"],
357358
output_names=["output", "state_outs"],
358-
dynamic_axes={'obs': {0: 'batch_size'},
359-
'state_ins': {0: 'batch_size'}, # variable length axes
360-
'output': {0: 'batch_size'},
361-
'state_outs': {0: 'batch_size'}}
362-
359+
dynamic_axes={
360+
"obs": {0: "batch_size"},
361+
"state_ins": {0: "batch_size"}, # variable length axes
362+
"output": {0: "batch_size"},
363+
"state_outs": {0: "batch_size"},
364+
},
363365
)

examples/sample_factory_example.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
2-
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy
2+
3+
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training
34

45

56
def get_args():
@@ -10,8 +11,12 @@ def get_args():
1011
parser.add_argument("--seed", default=0, type=int, help="environment seed")
1112
parser.add_argument("--export", default=False, action="store_true", help="whether to export the model")
1213
parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process")
13-
parser.add_argument("--experiment_dir", default="logs/sf", type=str,
14-
help="The name of the experiment directory, in which the tensorboard logs are getting stored")
14+
parser.add_argument(
15+
"--experiment_dir",
16+
default="logs/sf",
17+
type=str,
18+
help="The name of the experiment directory, in which the tensorboard logs are getting stored",
19+
)
1520
parser.add_argument(
1621
"--experiment_name",
1722
default="experiment",
@@ -22,14 +27,13 @@ def get_args():
2227
return parser.parse_known_args()
2328

2429

25-
2630
def main():
2731
args, extras = get_args()
2832
if args.eval:
2933
sample_factory_enjoy(args, extras)
3034
else:
3135
sample_factory_training(args, extras)
32-
33-
36+
37+
3438
if __name__ == "__main__":
3539
main()

examples/stable_baselines3_example.py

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import pathlib
44
from typing import Callable
55

6+
from stable_baselines3 import PPO
67
from stable_baselines3.common.callbacks import CheckpointCallback
8+
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
9+
710
from godot_rl.core.utils import can_import
8-
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
911
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
10-
from stable_baselines3 import PPO
11-
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
12+
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
1213

1314
# To download the env source and binary:
1415
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
@@ -28,42 +29,39 @@
2829
default="logs/sb3",
2930
type=str,
3031
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are "
31-
"getting stored."
32+
"getting stored.",
3233
)
3334
parser.add_argument(
3435
"--experiment_name",
3536
default="experiment",
3637
type=str,
3738
help="The name of the experiment, which will be displayed in tensorboard and "
38-
"for checkpoint directory and name (if enabled).",
39-
)
40-
parser.add_argument(
41-
"--seed",
42-
type=int,
43-
default=0,
44-
help="seed of the experiment"
39+
"for checkpoint directory and name (if enabled).",
4540
)
41+
parser.add_argument("--seed", type=int, default=0, help="seed of the experiment")
4642
parser.add_argument(
4743
"--resume_model_path",
4844
default=None,
4945
type=str,
5046
help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
51-
"--save_checkpoints_frequency. Use this to resume training or infer from a saved model.",
47+
"--save_checkpoints_frequency. Use this to resume training or infer from a saved model.",
5248
)
5349
parser.add_argument(
5450
"--save_model_path",
5551
default=None,
5652
type=str,
5753
help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
58-
"to resume training. Extension will be set to .zip",
54+
"to resume training. Extension will be set to .zip",
5955
)
6056
parser.add_argument(
6157
"--save_checkpoint_frequency",
6258
default=None,
6359
type=int,
64-
help=("If set, will save checkpoints every 'frequency' environment steps. "
65-
"Requires a unique --experiment_name or --experiment_dir for each run. "
66-
"Does not need --save_model_path to be set. "),
60+
help=(
61+
"If set, will save checkpoints every 'frequency' environment steps. "
62+
"Requires a unique --experiment_name or --experiment_dir for each run. "
63+
"Does not need --save_model_path to be set. "
64+
),
6765
)
6866
parser.add_argument(
6967
"--onnx_export_path",
@@ -76,34 +74,38 @@
7674
default=1_000_000,
7775
type=int,
7876
help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
79-
"it will continue training for this amount of steps from the saved state without counting previously trained "
80-
"steps",
77+
"it will continue training for this amount of steps from the saved state without counting previously trained "
78+
"steps",
8179
)
8280
parser.add_argument(
8381
"--inference",
8482
default=False,
8583
action="store_true",
8684
help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
87-
"Requires --resume_model_path to be set."
85+
"Requires --resume_model_path to be set.",
8886
)
8987
parser.add_argument(
9088
"--linear_lr_schedule",
9189
default=False,
9290
action="store_true",
9391
help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at "
94-
"--timesteps"
95-
"value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used."
92+
"--timesteps"
93+
"value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used.",
9694
)
9795
parser.add_argument(
9896
"--viz",
9997
action="store_true",
10098
help="If set, the simulation will be displayed in a window during training. Otherwise "
101-
"training will run without rendering the simulation. This setting does not apply to in-editor training.",
102-
default=False
99+
"training will run without rendering the simulation. This setting does not apply to in-editor training.",
100+
default=False,
103101
)
104102
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
105-
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to "
106-
"launch - requires --env_path to be set if > 1.")
103+
parser.add_argument(
104+
"--n_parallel",
105+
default=1,
106+
type=int,
107+
help="How many instances of the environment executable to " "launch - requires --env_path to be set if > 1.",
108+
)
107109
args, extras = parser.parse_known_args()
108110

109111

@@ -136,19 +138,22 @@ def close_env():
136138

137139
# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled
138140
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
139-
raise RuntimeError(abs_path_checkpoint + " folder already exists. "
140-
"Use a different --experiment_dir, or --experiment_name,"
141-
"or if previous checkpoints are not needed anymore, "
142-
"remove the folder containing the checkpoints. ")
141+
raise RuntimeError(
142+
abs_path_checkpoint + " folder already exists. "
143+
"Use a different --experiment_dir, or --experiment_name,"
144+
"or if previous checkpoints are not needed anymore, "
145+
"remove the folder containing the checkpoints. "
146+
)
143147

144148
if args.inference and args.resume_model_path is None:
145149
raise parser.error("Using --inference requires --resume_model_path to be set.")
146150

147151
if args.env_path is None and args.viz:
148152
print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.")
149153

150-
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel,
151-
speedup=args.speedup)
154+
env = StableBaselinesGodotEnv(
155+
env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, speedup=args.speedup
156+
)
152157
env = VecMonitor(env)
153158

154159

@@ -177,13 +182,15 @@ def func(progress_remaining: float) -> float:
177182

178183
if args.resume_model_path is None:
179184
learning_rate = 0.0003 if not args.linear_lr_schedule else linear_schedule(0.0003)
180-
model: PPO = PPO("MultiInputPolicy",
181-
env,
182-
ent_coef=0.0001,
183-
verbose=2,
184-
n_steps=32,
185-
tensorboard_log=args.experiment_dir,
186-
learning_rate=learning_rate)
185+
model: PPO = PPO(
186+
"MultiInputPolicy",
187+
env,
188+
ent_coef=0.0001,
189+
verbose=2,
190+
n_steps=32,
191+
tensorboard_log=args.experiment_dir,
192+
learning_rate=learning_rate,
193+
)
187194
else:
188195
path_zip = pathlib.Path(args.resume_model_path)
189196
print("Loading model: " + os.path.abspath(path_zip))
@@ -201,13 +208,16 @@ def func(progress_remaining: float) -> float:
201208
checkpoint_callback = CheckpointCallback(
202209
save_freq=(args.save_checkpoint_frequency // env.num_envs),
203210
save_path=path_checkpoint,
204-
name_prefix=args.experiment_name
211+
name_prefix=args.experiment_name,
205212
)
206-
learn_arguments['callback'] = checkpoint_callback
213+
learn_arguments["callback"] = checkpoint_callback
207214
try:
208215
model.learn(**learn_arguments)
209216
except KeyboardInterrupt:
210-
print("Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used.")
217+
print(
218+
"""Training interrupted by user. Will save if --save_model_path was
219+
used and/or export if --onnx_export_path was used."""
220+
)
211221

212222
close_env()
213223
handle_onnx_export()

0 commit comments

Comments
 (0)