Skip to content

Commit eada47e

Browse files
authored
Merge pull request #54 from automl/context_mask_#53
Context mask #53
2 parents 49bb50f + bd84c25 commit eada47e

19 files changed

+80
-4
lines changed

carl/envs/box2d/carl_bipedal_walker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
scale_context_features: str = "no",
9090
default_context: Optional[Dict] = DEFAULT_CONTEXT,
9191
state_context_features: Optional[List[str]] = None,
92+
context_mask: Optional[List[str]] = None,
9293
dict_observation_space: bool = False,
9394
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
9495
context_selector_kwargs: Optional[Dict] = None,
@@ -119,7 +120,8 @@ def __init__(
119120
state_context_features=state_context_features,
120121
dict_observation_space=dict_observation_space,
121122
context_selector=context_selector,
122-
context_selector_kwargs=context_selector_kwargs
123+
context_selector_kwargs=context_selector_kwargs,
124+
context_mask=context_mask,
123125

124126
)
125127
self.whitelist_gaussian_noise = list(

carl/envs/box2d/carl_lunarlander.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
scale_context_features: str = "no",
114114
default_context: Optional[Dict] = DEFAULT_CONTEXT,
115115
state_context_features: Optional[List[str]] = None,
116+
context_mask: Optional[List[str]] = None,
116117
max_episode_length: int = 1000,
117118
high_gameover_penalty: bool = False,
118119
dict_observation_space: bool = False,
@@ -147,7 +148,8 @@ def __init__(
147148
max_episode_length=max_episode_length,
148149
dict_observation_space=dict_observation_space,
149150
context_selector=context_selector,
150-
context_selector_kwargs=context_selector_kwargs
151+
context_selector_kwargs=context_selector_kwargs,
152+
context_mask=context_mask,
151153
)
152154
self.whitelist_gaussian_noise = list(
153155
DEFAULT_CONTEXT.keys()

carl/envs/box2d/carl_vehicle_racing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
scale_context_features: str = "no",
195195
default_context: Optional[Dict] = DEFAULT_CONTEXT,
196196
state_context_features: Optional[List[str]] = None,
197+
context_mask: Optional[List[str]] = None,
197198
dict_observation_space: bool = False,
198199
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
199200
context_selector_kwargs: Optional[Dict] = None,
@@ -230,6 +231,7 @@ def __init__(
230231
dict_observation_space=dict_observation_space,
231232
context_selector=context_selector,
232233
context_selector_kwargs=context_selector_kwargs,
234+
context_mask=context_mask,
233235
)
234236
self.whitelist_gaussian_noise = [
235237
k for k in DEFAULT_CONTEXT.keys() if k not in CATEGORICAL_CONTEXT_FEATURES

carl/envs/brax/carl_ant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
scale_context_features: str = "no",
5050
default_context: Optional[Dict] = DEFAULT_CONTEXT,
5151
state_context_features: Optional[List[str]] = None,
52+
context_mask: Optional[List[str]] = None,
5253
dict_observation_space: bool = False,
5354
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5455
context_selector_kwargs: Optional[Dict] = None,
@@ -78,6 +79,7 @@ def __init__(
7879
dict_observation_space=dict_observation_space,
7980
context_selector=context_selector,
8081
context_selector_kwargs=context_selector_kwargs,
82+
context_mask=context_mask,
8183
)
8284
self.whitelist_gaussian_noise = list(
8385
DEFAULT_CONTEXT.keys()

carl/envs/brax/carl_fetch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
scale_context_features: str = "no",
5454
default_context: Optional[Dict] = DEFAULT_CONTEXT,
5555
state_context_features: Optional[List[str]] = None,
56+
context_mask: Optional[List[str]] = None,
5657
dict_observation_space: bool = False,
5758
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5859
context_selector_kwargs: Optional[Dict] = None,
@@ -81,6 +82,7 @@ def __init__(
8182
dict_observation_space=dict_observation_space,
8283
context_selector=context_selector,
8384
context_selector_kwargs=context_selector_kwargs,
85+
context_mask=context_mask,
8486
)
8587
self.whitelist_gaussian_noise = list(
8688
DEFAULT_CONTEXT.keys()

carl/envs/brax/carl_grasp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
scale_context_features: str = "no",
5454
default_context: Optional[Dict] = DEFAULT_CONTEXT,
5555
state_context_features: Optional[List[str]] = None,
56+
context_mask: Optional[List[str]] = None,
5657
dict_observation_space: bool = False,
5758
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5859
context_selector_kwargs: Optional[Dict] = None,
@@ -81,6 +82,7 @@ def __init__(
8182
dict_observation_space=dict_observation_space,
8283
context_selector=context_selector,
8384
context_selector_kwargs=context_selector_kwargs,
85+
context_mask=context_mask,
8486
)
8587
self.whitelist_gaussian_noise = list(
8688
DEFAULT_CONTEXT.keys()

carl/envs/brax/carl_halfcheetah.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
scale_context_features: str = "no",
4848
default_context: Optional[Dict] = DEFAULT_CONTEXT,
4949
state_context_features: Optional[List[str]] = None,
50+
context_mask: Optional[List[str]] = None,
5051
dict_observation_space: bool = False,
5152
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5253
context_selector_kwargs: Optional[Dict] = None,
@@ -75,6 +76,7 @@ def __init__(
7576
dict_observation_space=dict_observation_space,
7677
context_selector=context_selector,
7778
context_selector_kwargs=context_selector_kwargs,
79+
context_mask=context_mask,
7880
)
7981
self.whitelist_gaussian_noise = list(
8082
DEFAULT_CONTEXT.keys()

carl/envs/brax/carl_humanoid.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
scale_context_features: str = "no",
4949
default_context: Optional[Dict] = DEFAULT_CONTEXT,
5050
state_context_features: Optional[List[str]] = None,
51+
context_mask: Optional[List[str]] = None,
5152
dict_observation_space: bool = False,
5253
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5354
context_selector_kwargs: Optional[Dict] = None,
@@ -76,6 +77,7 @@ def __init__(
7677
dict_observation_space=dict_observation_space,
7778
context_selector=context_selector,
7879
context_selector_kwargs=context_selector_kwargs,
80+
context_mask=context_mask,
7981
)
8082
self.whitelist_gaussian_noise = list(
8183
DEFAULT_CONTEXT.keys()

carl/envs/brax/carl_ur5e.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
scale_context_features: str = "no",
5454
default_context: Optional[Dict] = DEFAULT_CONTEXT,
5555
state_context_features: Optional[List[str]] = None,
56+
context_mask: Optional[List[str]] = None,
5657
dict_observation_space: bool = False,
5758
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
5859
context_selector_kwargs: Optional[Dict] = None,
@@ -81,6 +82,7 @@ def __init__(
8182
dict_observation_space=dict_observation_space,
8283
context_selector=context_selector,
8384
context_selector_kwargs=context_selector_kwargs,
85+
context_mask=context_mask,
8486
)
8587
self.whitelist_gaussian_noise = list(
8688
DEFAULT_CONTEXT.keys()

carl/envs/carl_env.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class CARLEnv(Wrapper):
6363
If the context is visible to the agent (hide_context=False), the context features are appended to the state.
6464
state_context_features specifies which of the context features are appended to the state. The default is
6565
appending all context features.
66+
context_mask: Optional[List[str]]
67+
Name of context features to be ignored when appending context features to the state.
6668
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]]
6769
Context selector (object of) class, e.g., can be RoundRobinSelector (default) or RandomSelector.
6870
Should subclass AbstractSelector.
@@ -94,6 +96,7 @@ def __init__(
9496
scale_context_features: str = "no",
9597
default_context: Optional[Dict] = None,
9698
state_context_features: Optional[List[str]] = None,
99+
context_mask: Optional[List[str]] = None,
97100
dict_observation_space: bool = False,
98101
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
99102
context_selector_kwargs: Optional[Dict] = None,
@@ -104,6 +107,7 @@ def __init__(
104107
self._contexts: Optional[Dict[Any, Dict[Any, Any]]] = None # init for property
105108
self.default_context = default_context
106109
self.contexts = contexts
110+
self.context_mask = context_mask
107111
self.hide_context = hide_context
108112
self.dict_observation_space = dict_observation_space
109113
self.cutoff = max_episode_length
@@ -153,7 +157,14 @@ def __init__(
153157
json.dump(data, file, indent="\t")
154158
else:
155159
state_context_features = []
156-
self.state_context_features = state_context_features
160+
else:
161+
state_context_features = list(self.contexts[list(self.contexts.keys())[0]].keys())
162+
self.state_context_features: List[str] = state_context_features
163+
# state_context_features contains the names of the context features that should be appended to the state
164+
# However, if context_mask is set, we want to update staet_context_feature_names so that the context features
165+
# in context_mask are not appended to the state anymore.
166+
if self.context_mask:
167+
self.state_context_features = [s for s in self.state_context_features if s not in self.context_mask]
157168

158169
self.step_counter = 0 # type: int # increased in/after step
159170
self.total_timestep_counter = 0 # type: int

0 commit comments

Comments
 (0)