@@ -63,6 +63,8 @@ class CARLEnv(Wrapper):
6363 If the context is visible to the agent (hide_context=False), the context features are appended to the state.
6464 state_context_features specifies which of the context features are appended to the state. The default is
6565 appending all context features.
66+ context_mask: Optional[List[str]]
67+ Name of context features to be ignored when appending context features to the state.
6668 context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]]
6769 Context selector (object of) class, e.g., can be RoundRobinSelector (default) or RandomSelector.
6870 Should subclass AbstractSelector.
@@ -94,6 +96,7 @@ def __init__(
9496 scale_context_features : str = "no" ,
9597 default_context : Optional [Dict ] = None ,
9698 state_context_features : Optional [List [str ]] = None ,
99+ context_mask : Optional [List [str ]] = None ,
97100 dict_observation_space : bool = False ,
98101 context_selector : Optional [Union [AbstractSelector , type (AbstractSelector )]] = None ,
99102 context_selector_kwargs : Optional [Dict ] = None ,
@@ -104,6 +107,7 @@ def __init__(
104107 self ._contexts : Optional [Dict [Any , Dict [Any , Any ]]] = None # init for property
105108 self .default_context = default_context
106109 self .contexts = contexts
110+ self .context_mask = context_mask
107111 self .hide_context = hide_context
108112 self .dict_observation_space = dict_observation_space
109113 self .cutoff = max_episode_length
@@ -153,7 +157,14 @@ def __init__(
153157 json .dump (data , file , indent = "\t " )
154158 else :
155159 state_context_features = []
156- self .state_context_features = state_context_features
160+ else :
161+ state_context_features = list (self .contexts [list (self .contexts .keys ())[0 ]].keys ())
162+ self .state_context_features : List [str ] = state_context_features
163+ # state_context_features contains the names of the context features that should be appended to the state
164+ # However, if context_mask is set, we want to update staet_context_feature_names so that the context features
165+ # in context_mask are not appended to the state anymore.
166+ if self .context_mask :
167+ self .state_context_features = [s for s in self .state_context_features if s not in self .context_mask ]
157168
158169 self .step_counter = 0 # type: int # increased in/after step
159170 self .total_timestep_counter = 0 # type: int
0 commit comments