Skip to content

Commit cd928c7

Browse files
committed
Codes uploaded
1 parent f46df39 commit cd928c7

File tree

14 files changed

+2957
-0
lines changed

14 files changed

+2957
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*__pycache__*
2+
tb_logs
3+
saved_policy

config.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Modes: 1) depth, 2) single_rgb, 3) multi_rgb
2+
train_mode: "multi_rgb"
3+
test_mode: "multi_rgb"
4+
5+
# Test types: 1) sequential, 2) random
6+
test_type: "random"

inference.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from cgi import test
2+
import os
3+
import gym
4+
import yaml
5+
6+
from stable_baselines3 import PPO
7+
from stable_baselines3.common.monitor import Monitor
8+
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
9+
from scripts.network import NatureCNN
10+
11+
12+
# Load train environment configs
13+
with open('scripts/env_config.yml', 'r') as f:
14+
env_config = yaml.safe_load(f)
15+
16+
# Load inference configs
17+
with open('config.yml', 'r') as f:
18+
config = yaml.safe_load(f)
19+
20+
# Model name
21+
model_name = "best_model_" + config["test_mode"]
22+
23+
# Determine input image shape
24+
image_shape = (50,50,1) if config["test_mode"]=="depth" else (50,50,3)
25+
26+
# Create a DummyVecEnv
27+
env = DummyVecEnv([lambda: Monitor(
28+
gym.make(
29+
"scripts:test-env-v0",
30+
ip_address="127.0.0.1",
31+
image_shape=image_shape,
32+
# Train and test envs shares same config for the test
33+
env_config=env_config["TrainEnv"],
34+
input_mode=config["test_mode"],
35+
test_mode=config["test_type"]
36+
)
37+
)])
38+
39+
# Wrap env as VecTransposeImage (Channel last to channel first)
40+
env = VecTransposeImage(env)
41+
42+
policy_kwargs = dict(features_extractor_class=NatureCNN)
43+
44+
# Load an existing model
45+
model = PPO.load(
46+
env=env,
47+
path=os.path.join("saved_policy", model_name),
48+
policy_kwargs=policy_kwargs
49+
)
50+
51+
# Run the trained policy
52+
obs = env.reset()
53+
for i in range(2300):
54+
action, _ = model.predict(obs, deterministic=True)
55+
obs, _, dones, info = env.step(action)

requirements.txt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
absl-py==0.14.1
2+
airsim==1.6.0
3+
cachetools==4.2.4
4+
certifi==2021.5.30
5+
charset-normalizer==2.0.6
6+
cloudpickle==2.0.0
7+
cycler==0.10.0
8+
google-auth==1.35.0
9+
google-auth-oauthlib==0.4.6
10+
grpcio==1.41.0
11+
gym==0.21.0
12+
idna==3.2
13+
kiwisolver==1.3.2
14+
Markdown==3.3.4
15+
matplotlib==3.4.3
16+
msgpack-python==0.5.6
17+
msgpack-rpc-python==0.4.1
18+
numpy==1.21.2
19+
oauthlib==3.1.1
20+
opencv-contrib-python==4.5.3.56
21+
pandas==1.3.3
22+
Pillow==8.3.2
23+
protobuf==3.18.1
24+
pyasn1==0.4.8
25+
pyasn1-modules==0.2.8
26+
pyparsing==2.4.7
27+
python-dateutil==2.8.2
28+
pytz==2021.3
29+
PyYAML==5.4.1
30+
requests==2.26.0
31+
requests-oauthlib==1.3.0
32+
rsa==4.7.2
33+
six==1.16.0
34+
stable-baselines3==1.2.0
35+
tensorboard==2.6.0
36+
tensorboard-data-server==0.6.1
37+
tensorboard-plugin-wit==1.8.0
38+
torch==1.9.1
39+
tornado==4.5.3
40+
typing-extensions==3.10.0.2
41+
urllib3==1.26.7
42+
Werkzeug==2.0.2
43+
wincertstore==0.2

scripts/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .airsim_env import AirSimDroneEnv, TestEnv
2+
from gym.envs.registration import register
3+
4+
5+
# Register AirSim environment as a gym environment
6+
register(
7+
id="airsim-env-v0", entry_point="scripts:AirSimDroneEnv",
8+
)
9+
10+
# Register AirSim environment as a gym environment
11+
register(
12+
id="test-env-v0", entry_point="scripts:TestEnv",
13+
)

scripts/airsim/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .client import *
2+
from .utils import *
3+
from .types import *
4+
5+
__version__ = "1.5.0"

0 commit comments

Comments
 (0)