Skip to content

Commit cc31995

Browse files
authored
Merge pull request #8 from Arena-Rosnav/marl
remove laser and robot state variable for network init
2 parents 23df6fa + 40500c1 commit cc31995

File tree

5 files changed

+173
-126
lines changed

5 files changed

+173
-126
lines changed

rosnav/model/agent_factory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,23 @@ def inner_wrapper(wrapped_class) -> Callable:
4040
# end register()
4141

4242
@classmethod
43-
def instantiate(cls, name: str, **kwargs) -> Union[Type[BaseAgent], Type[BasePolicy]]:
43+
def instantiate(
44+
cls, name: str, **kwargs
45+
) -> Union[Type[BaseAgent], Type[BasePolicy]]:
4446
"""Factory command to create the agent.
4547
This method gets the appropriate agent class from the registry
4648
and creates an instance of it, while passing in the parameters
4749
given in ``kwargs``.
4850
4951
Args:
50-
name (str): The name of the agent to create.
52+
name (str): The name of the agent to create.agent_class
5153
5254
Returns:
5355
An instance of the agent that is created.
5456
"""
5557
assert name in cls.registry, f"Agent '{name}' is not registered!"
5658
agent_class = cls.registry[name]
57-
59+
5860
if issubclass(agent_class, BaseAgent):
5961
return agent_class(**kwargs)
6062
else:

rosnav/model/base_agent.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Type, List
2-
31
from abc import ABC, abstractmethod
42
from enum import Enum
5-
from torch.nn.modules.module import Module
3+
from typing import List, Type
4+
65
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
6+
from torch.nn.modules.module import Module
77

88

99
class PolicyType(Enum):
@@ -47,14 +47,17 @@ def activation_fn(self) -> Type[Module]:
4747
pass
4848

4949
def get_kwargs(self):
50+
fe_kwargs = self.features_extractor_kwargs
51+
fe_kwargs["robot_model"] = self.robot_model
52+
5053
kwargs = {
5154
"features_extractor_class": self.features_extractor_class,
52-
"features_extractor_kwargs": self.features_extractor_kwargs,
55+
"features_extractor_kwargs": fe_kwargs,
5356
"net_arch": self.net_arch,
5457
"activation_fn": self.activation_fn,
5558
}
56-
if not kwargs['features_extractor_class']:
57-
del kwargs['features_extractor_class']
58-
if not kwargs['features_extractor_kwargs']:
59-
del kwargs['features_extractor_kwargs']
59+
if not kwargs["features_extractor_class"]:
60+
del kwargs["features_extractor_class"]
61+
if not kwargs["features_extractor_kwargs"]:
62+
del kwargs["features_extractor_kwargs"]
6063
return kwargs

rosnav/model/custom_sb3_policy.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class AGENT_6(BaseAgent):
1414
net_arch = [128, 64, 64, dict(pi=[64, 64], vf=[64, 64])]
1515
activation_fn = nn.ReLU
1616

17+
def __init__(self, robot_model: str = None):
18+
self.robot_model = robot_model
19+
1720

1821
@AgentFactory.register("AGENT_7")
1922
class AGENT_7(BaseAgent):
@@ -23,6 +26,9 @@ class AGENT_7(BaseAgent):
2326
net_arch = [128, 128, 128, dict(pi=[64, 64], vf=[64, 64])]
2427
activation_fn = nn.ReLU
2528

29+
def __init__(self, robot_model: str = None):
30+
self.robot_model = robot_model
31+
2632

2733
@AgentFactory.register("AGENT_8")
2834
class AGENT_8(BaseAgent):
@@ -32,6 +38,9 @@ class AGENT_8(BaseAgent):
3238
net_arch = [64, 64, 64, 64, dict(pi=[64, 64], vf=[64, 64])]
3339
activation_fn = nn.ReLU
3440

41+
def __init__(self, robot_model: str = None):
42+
self.robot_model = robot_model
43+
3544

3645
@AgentFactory.register("AGENT_9")
3746
class AGENT_9(BaseAgent):
@@ -41,6 +50,9 @@ class AGENT_9(BaseAgent):
4150
net_arch = [64, 64, 64, 64, dict(pi=[64, 64, 64], vf=[64, 64, 64])]
4251
activation_fn = nn.ReLU
4352

53+
def __init__(self, robot_model: str = None):
54+
self.robot_model = robot_model
55+
4456

4557
@AgentFactory.register("AGENT_10")
4658
class AGENT_10(BaseAgent):
@@ -50,6 +62,9 @@ class AGENT_10(BaseAgent):
5062
net_arch = [128, 128, 128, 128, dict(pi=[64, 64, 64], vf=[64, 64, 64])]
5163
activation_fn = nn.ReLU
5264

65+
def __init__(self, robot_model: str = None):
66+
self.robot_model = robot_model
67+
5368

5469
@AgentFactory.register("AGENT_11")
5570
class AGENT_11(BaseAgent):
@@ -59,6 +74,9 @@ class AGENT_11(BaseAgent):
5974
net_arch = [512, 512, 512, 512, dict(pi=[64, 64], vf=[64, 64])]
6075
activation_fn = nn.ReLU
6176

77+
def __init__(self, robot_model: str = None):
78+
self.robot_model = robot_model
79+
6280

6381
@AgentFactory.register("AGENT_17")
6482
class AGENT_17(BaseAgent):
@@ -68,6 +86,9 @@ class AGENT_17(BaseAgent):
6886
net_arch = [dict(pi=[64, 64, 64], vf=[64, 64, 64])]
6987
activation_fn = nn.ReLU
7088

89+
def __init__(self, robot_model: str = None):
90+
self.robot_model = robot_model
91+
7192

7293
@AgentFactory.register("AGENT_18")
7394
class AGENT_18(BaseAgent):
@@ -77,6 +98,9 @@ class AGENT_18(BaseAgent):
7798
net_arch = [128, dict(pi=[64, 64, 64], vf=[64, 64, 64])]
7899
activation_fn = nn.ReLU
79100

101+
def __init__(self, robot_model: str = None):
102+
self.robot_model = robot_model
103+
80104

81105
@AgentFactory.register("AGENT_19")
82106
class AGENT_19(BaseAgent):
@@ -86,6 +110,9 @@ class AGENT_19(BaseAgent):
86110
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
87111
activation_fn = nn.ReLU
88112

113+
def __init__(self, robot_model: str = None):
114+
self.robot_model = robot_model
115+
89116

90117
@AgentFactory.register("AGENT_20")
91118
class AGENT_20(BaseAgent):
@@ -95,6 +122,9 @@ class AGENT_20(BaseAgent):
95122
net_arch = [dict(pi=[128], vf=[128])]
96123
activation_fn = nn.ReLU
97124

125+
def __init__(self, robot_model: str = None):
126+
self.robot_model = robot_model
127+
98128

99129
@AgentFactory.register("AGENT_21")
100130
class AGENT_21(BaseAgent):
@@ -104,6 +134,9 @@ class AGENT_21(BaseAgent):
104134
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
105135
activation_fn = nn.ReLU
106136

137+
def __init__(self, robot_model: str = None):
138+
self.robot_model = robot_model
139+
107140

108141
@AgentFactory.register("AGENT_22")
109142
class AGENT_22(BaseAgent):
@@ -113,6 +146,9 @@ class AGENT_22(BaseAgent):
113146
net_arch = [dict(pi=[64, 64, 64], vf=[64, 64, 64])]
114147
activation_fn = nn.ReLU
115148

149+
def __init__(self, robot_model: str = None):
150+
self.robot_model = robot_model
151+
116152

117153
@AgentFactory.register("AGENT_23")
118154
class AGENT_23(BaseAgent):
@@ -122,6 +158,9 @@ class AGENT_23(BaseAgent):
122158
net_arch = [128, dict(pi=[64, 64, 64], vf=[64, 64, 64])]
123159
activation_fn = nn.ReLU
124160

161+
def __init__(self, robot_model: str = None):
162+
self.robot_model = robot_model
163+
125164

126165
@AgentFactory.register("AGENT_24")
127166
class AGENT_24(BaseAgent):
@@ -131,11 +170,17 @@ class AGENT_24(BaseAgent):
131170
net_arch = [128, dict(pi=[64, 64], vf=[64, 64])]
132171
activation_fn = nn.ReLU
133172

173+
def __init__(self, robot_model: str = None):
174+
self.robot_model = robot_model
175+
134176

135177
@AgentFactory.register("AGENT_25")
136178
class AGENT_25(BaseAgent):
137179
type = PolicyType.MLP
138180
features_extractor_class = None
139181
features_extractor_kwargs = None
140182
net_arch = [512, 256, dict(pi=[128], vf=[128])]
141-
activation_fn = nn.ReLU
183+
activation_fn = nn.ReLU
184+
185+
def __init__(self, robot_model: str = None):
186+
self.robot_model = robot_model

0 commit comments

Comments
 (0)