Skip to content

Commit 7669886

Browse files
committed
new way to load models with different action spaces
1 parent cc31995 commit 7669886

File tree

4 files changed

+72
-64
lines changed

4 files changed

+72
-64
lines changed
Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
{
2-
"agent_name": "AGENT_24_2022_02_04__11_36",
3-
"robot": "rto_real",
4-
"actions_in_observationspace": true,
5-
"reward_fnc": "rule_05",
6-
"discrete_action_space": false,
7-
"normalize": false,
8-
"task_mode": "staged",
9-
"train_max_steps_per_episode": 120,
10-
"eval_max_steps_per_episode": 170,
11-
"goal_radius": 0.7,
12-
"curr_stage": 10,
13-
"batch_size": 19200,
14-
"gamma": 0.99,
15-
"n_steps": 2400,
16-
"ent_coef": 0.005,
17-
"learning_rate": 0.0003,
18-
"vf_coef": 0.22,
19-
"max_grad_norm": 0.5,
20-
"gae_lambda": 0.95,
21-
"m_batch_size": 15,
22-
"n_epochs": 3,
23-
"clip_range": 0.22
24-
}
2+
"agent_name": "AGENT_24_2022_02_04__11_36",
3+
"robot": "rto_real",
4+
"actions_in_observationspace": true,
5+
"reward_fnc": "rule_05",
6+
"discrete_action_space": false,
7+
"normalize": false,
8+
"task_mode": "staged",
9+
"train_max_steps_per_episode": 120,
10+
"eval_max_steps_per_episode": 170,
11+
"goal_radius": 0.7,
12+
"curr_stage": 10,
13+
"batch_size": 19200,
14+
"gamma": 0.99,
15+
"n_steps": 2400,
16+
"ent_coef": 0.005,
17+
"learning_rate": 0.0003,
18+
"vf_coef": 0.22,
19+
"max_grad_norm": 0.5,
20+
"gae_lambda": 0.95,
21+
"m_batch_size": 15,
22+
"n_epochs": 3,
23+
"clip_range": 0.22,
24+
"observation_space": ["laser_scan", "goal_in_robot_frame", "last_action"]
25+
}
Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
{
2-
"agent_name": "AGENT_24_2022_02_04__11_36",
3-
"robot": "rto_real",
4-
"actions_in_observationspace": true,
5-
"reward_fnc": "rule_06",
6-
"discrete_action_space": false,
7-
"normalize": false,
8-
"task_mode": "staged",
9-
"train_max_steps_per_episode": 400,
10-
"eval_max_steps_per_episode": 1100,
11-
"goal_radius": 0.7,
12-
"curr_stage": 1,
13-
"batch_size": 38400,
14-
"gamma": 0.99,
15-
"n_steps": 3200,
16-
"ent_coef": 0.005,
17-
"learning_rate": 0.0003,
18-
"vf_coef": 0.22,
19-
"max_grad_norm": 0.5,
20-
"gae_lambda": 0.95,
21-
"m_batch_size": 15,
22-
"n_epochs": 3,
23-
"clip_range": 0.22
24-
}
2+
"agent_name": "AGENT_24_2022_02_04__11_36",
3+
"robot": "rto_real",
4+
"actions_in_observationspace": true,
5+
"reward_fnc": "rule_06",
6+
"discrete_action_space": false,
7+
"normalize": false,
8+
"task_mode": "staged",
9+
"train_max_steps_per_episode": 400,
10+
"eval_max_steps_per_episode": 1100,
11+
"goal_radius": 0.7,
12+
"curr_stage": 1,
13+
"batch_size": 38400,
14+
"gamma": 0.99,
15+
"n_steps": 3200,
16+
"ent_coef": 0.005,
17+
"learning_rate": 0.0003,
18+
"vf_coef": 0.22,
19+
"max_grad_norm": 0.5,
20+
"gae_lambda": 0.95,
21+
"m_batch_size": 15,
22+
"n_epochs": 3,
23+
"clip_range": 0.22,
24+
"observation_space": ["laser_scan", "goal_in_robot_frame", "last_action"]
25+
}

rosnav/utils/utils.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
def get_robot_yaml_path(robot_model: str = None) -> str:
1313
if not robot_model:
14-
robot_model = rospy.get_param("model")
14+
robot_model = rospy.get_param("robot_model")
1515

1616
simulation_setup_path = rospkg.RosPack().get_path("arena-simulation-setup")
1717
return os.path.join(
18-
simulation_setup_path, "robot", robot_model, f"{robot_model}.model.yaml"
18+
simulation_setup_path, "robot", robot_model, f"model_params.yaml"
1919
)
2020

2121

@@ -24,26 +24,20 @@ def get_laser_from_robot_yaml(robot_model: str = None) -> Tuple[int, int, int, i
2424

2525
with open(robot_yaml_path, "r") as fd:
2626
robot_data = yaml.safe_load(fd)
27+
laser_data = robot_data["laser"]
2728

28-
for plugin in robot_data["plugins"]:
29-
if plugin["type"] == "Laser":
30-
laser_angle_min = plugin["angle"]["min"]
31-
laser_angle_max = plugin["angle"]["max"]
32-
laser_angle_increment = plugin["angle"]["increment"]
29+
rospy.set_param("laser/num_beams", laser_data["num_beams"])
3330

34-
_L = int(
35-
round((laser_angle_max - laser_angle_min) / laser_angle_increment)
36-
)
37-
38-
# Because RosnavEncoder ist weird...
39-
rospy.set_param("laser/num_beams", _L)
40-
41-
return _L, laser_angle_min, laser_angle_max, laser_angle_increment
31+
return (
32+
laser_data["num_beams"],
33+
laser_data["angle"]["min"],
34+
laser_data["angle"]["max"],
35+
laser_data["angle"]["increment"]
36+
)
4237

4338

4439
def get_observation_space_from_file(robot_model: str = None) -> Tuple[int, int]:
45-
actions_in_obs = rospy.get_param("/actions_in_obs", True)
46-
robot_state_size, action_state_size = 2, 3 if actions_in_obs else 0
40+
robot_state_size, action_state_size = 2, rospy.get_param(rospy.get_namespace() + "action_state_size", 3)
4741
num_beams, _, _, _ = get_laser_from_robot_yaml(robot_model)
4842

4943
return num_beams, action_state_size + robot_state_size

scripts/rosnav_node.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import rospkg
33
import os
44
import sys
5+
import traceback
56
import json
67
from stable_baselines3 import PPO
8+
from stable_baselines3.common.utils import get_device
79

810
from rosnav.srv import GetAction, GetActionResponse
911
from rosnav.rosnav_space_manager.rosnav_space_manager import RosnavSpaceManager
@@ -24,6 +26,7 @@ def __init__(self):
2426

2527
# Load hyperparams
2628
self._hyperparams = self._load_hyperparams(self.agent_path)
29+
# rospy.set_param("/actions_in_obs", self._hyperparams.get("actions_in_observationspace", False))
2730

2831
self._obs_structure = self._get_observation_space_structure(self._hyperparams)
2932

@@ -54,7 +57,16 @@ def _handle_next_action_srv(self, request):
5457
return response
5558

5659
def _get_model(self, agent_path):
57-
return PPO.load(os.path.join(agent_path, "best_model.zip")).policy
60+
action_state_sizes = [0, 3]
61+
62+
for size in action_state_sizes:
63+
rospy.set_param(rospy.get_namespace() + "action_state_size", size)
64+
try:
65+
return PPO.load(os.path.join(agent_path, "best_model.zip")).policy
66+
except:
67+
pass
68+
69+
rospy.signal_shutdown("")
5870

5971
def _get_model_path(self, model_name):
6072
return os.path.join(

0 commit comments

Comments
 (0)