|
6 | 6 | import brax |
7 | 7 | import numpy as np |
8 | 8 | from brax.envs.ant import _SYSTEM_CONFIG, Ant |
9 | | -from brax.envs.wrappers import GymWrapper, VectorWrapper, VectorGymWrapper |
| 9 | +from brax.envs.wrappers import GymWrapper, VectorGymWrapper, VectorWrapper |
10 | 10 | from google.protobuf import json_format, text_format |
11 | 11 | from google.protobuf.json_format import MessageToDict |
12 | 12 | from numpyencoder import NumpyEncoder |
13 | 13 |
|
| 14 | +from carl.context.selection import AbstractSelector |
14 | 15 | from carl.envs.carl_env import CARLEnv |
15 | 16 | from carl.utils.trial_logger import TrialLogger |
16 | | -from carl.context.selection import AbstractSelector |
17 | 17 |
|
18 | 18 | DEFAULT_CONTEXT = { |
19 | 19 | "joint_stiffness": 5000, |
|
38 | 38 |
|
39 | 39 | class CARLAnt(CARLEnv): |
40 | 40 | def __init__( |
41 | | - self, |
42 | | - env: Ant = Ant(), |
43 | | - n_envs: int = 1, |
44 | | - contexts: Dict[str, Dict] = {}, |
45 | | - hide_context=False, |
46 | | - add_gaussian_noise_to_context: bool = False, |
47 | | - gaussian_noise_std_percentage: float = 0.01, |
48 | | - logger: Optional[TrialLogger] = None, |
49 | | - scale_context_features: str = "no", |
50 | | - default_context: Optional[Dict] = DEFAULT_CONTEXT, |
51 | | - state_context_features: Optional[List[str]] = None, |
52 | | - context_mask: Optional[List[str]] = None, |
53 | | - dict_observation_space: bool = False, |
54 | | - context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, |
55 | | - context_selector_kwargs: Optional[Dict] = None, |
56 | | - |
| 41 | + self, |
| 42 | + env: Ant = Ant(), |
| 43 | + n_envs: int = 1, |
| 44 | + contexts: Dict[str, Dict] = {}, |
| 45 | + hide_context=False, |
| 46 | + add_gaussian_noise_to_context: bool = False, |
| 47 | + gaussian_noise_std_percentage: float = 0.01, |
| 48 | + logger: Optional[TrialLogger] = None, |
| 49 | + scale_context_features: str = "no", |
| 50 | + default_context: Optional[Dict] = DEFAULT_CONTEXT, |
| 51 | + state_context_features: Optional[List[str]] = None, |
| 52 | + context_mask: Optional[List[str]] = None, |
| 53 | + dict_observation_space: bool = False, |
| 54 | + context_selector: Optional[ |
| 55 | + Union[AbstractSelector, type(AbstractSelector)] |
| 56 | + ] = None, |
| 57 | + context_selector_kwargs: Optional[Dict] = None, |
57 | 58 | ): |
58 | 59 | if n_envs == 1: |
59 | 60 | env = GymWrapper(env) |
|
0 commit comments