|
4 | 4 | # SPDX-License-Identifier: BSD-3-Clause |
5 | 5 |
|
6 | 6 | import gymnasium as gym |
| 7 | +import json |
7 | 8 | import numpy as np |
8 | 9 | import torch |
9 | 10 | from typing import Any |
@@ -90,3 +91,131 @@ def tensorize(s, x): |
90 | 91 |
|
91 | 92 | sample = (gym.vector.utils.batch_space(space, batch_size) if batch_size > 0 else space).sample() |
92 | 93 | return tensorize(space, sample) |
| 94 | + |
| 95 | + |
def serialize_space(space: SpaceType) -> str:
    """Encode a space specification as a JSON string.

    Supported inputs are the Gymnasium spaces ``Discrete``, ``Box``, ``MultiDiscrete``, ``Tuple``
    and ``Dict``, and their plain-Python shorthands: an ``int`` or a list of ``int`` (Box), a
    single-element ``set`` (Discrete), a list of single-element ``set`` (MultiDiscrete), a
    ``tuple`` (Tuple) and a ``dict`` (Dict). Members of composite spaces are encoded recursively,
    each one stored as its own JSON string inside the outer document.

    Note:
        Only the fields written below are stored: ``n`` for ``Discrete``, ``low``/``high``/``shape``
        for ``Box`` and ``nvec`` for ``MultiDiscrete``. No other attribute of the space is encoded.

    Args:
        space: Space specification.

    Returns:
        Serialized JSON representation.

    Raises:
        ValueError: If the space specification is not one of the supported kinds.
    """

    def _dump(kind, name, **fields):
        # "type" and "space" always come first, followed by the space-specific fields in order
        return json.dumps({"type": kind, "space": name, **fields})

    # Gymnasium spaces
    if isinstance(space, gym.spaces.Discrete):
        return _dump("gymnasium", "Discrete", n=int(space.n))
    if isinstance(space, gym.spaces.Box):
        return _dump(
            "gymnasium",
            "Box",
            low=space.low.tolist(),
            high=space.high.tolist(),
            shape=space.shape,
        )
    if isinstance(space, gym.spaces.MultiDiscrete):
        return _dump("gymnasium", "MultiDiscrete", nvec=space.nvec.tolist())
    if isinstance(space, gym.spaces.Tuple):
        return _dump("gymnasium", "Tuple", spaces=[serialize_space(member) for member in space.spaces])
    if isinstance(space, gym.spaces.Dict):
        return _dump(
            "gymnasium",
            "Dict",
            spaces={key: serialize_space(member) for key, member in space.spaces.items()},
        )

    # Python data types
    # Box: a bare int, or a list made only of ints
    if isinstance(space, int) or (isinstance(space, list) and all(isinstance(item, int) for item in space)):
        return _dump("python", "Box", value=space)
    # Discrete: a set holding exactly one element
    if isinstance(space, set) and len(space) == 1:
        (only,) = space
        return _dump("python", "Discrete", value=only)
    # MultiDiscrete: a list made only of single-element sets
    if isinstance(space, list) and all(isinstance(item, set) and len(item) == 1 for item in space):
        return _dump("python", "MultiDiscrete", value=[only for (only,) in space])

    # composite spaces
    # Tuple
    if isinstance(space, tuple):
        return _dump("python", "Tuple", value=list(map(serialize_space, space)))
    # Dict
    if isinstance(space, dict):
        return _dump("python", "Dict", value={key: serialize_space(member) for key, member in space.items()})

    raise ValueError(f"Unsupported space ({space})")
| 144 | + |
| 145 | + |
def deserialize_space(string: str) -> gym.spaces.Space:
    """Deserialize a space specification encoded as JSON.

    The JSON document must be an object with a ``"type"`` entry (``"gymnasium"`` or ``"python"``)
    and a ``"space"`` entry naming the kind of space. Members of composite spaces (``Tuple`` and
    ``Dict``) are themselves JSON strings and are decoded recursively.

    Args:
        string: Serialized JSON representation.

    Returns:
        Space specification. For ``"gymnasium"`` documents this is a Gymnasium space. For
        ``"python"`` documents it is the plain-Python form: the stored value for ``Box``, a
        single-element ``set`` for ``Discrete``, a list of single-element ``set`` for
        ``MultiDiscrete``, a ``tuple`` for ``Tuple`` and a ``dict`` for ``Dict``.

    Raises:
        ValueError: If the ``"type"`` or the ``"space"`` entry is not supported.
    """
    obj = json.loads(string)
    # Gymnasium spaces
    if obj["type"] == "gymnasium":
        if obj["space"] == "Discrete":
            return gym.spaces.Discrete(n=obj["n"])
        elif obj["space"] == "Box":
            return gym.spaces.Box(low=np.array(obj["low"]), high=np.array(obj["high"]), shape=obj["shape"])
        elif obj["space"] == "MultiDiscrete":
            return gym.spaces.MultiDiscrete(nvec=np.array(obj["nvec"]))
        elif obj["space"] == "Tuple":
            return gym.spaces.Tuple(spaces=tuple(map(deserialize_space, obj["spaces"])))
        elif obj["space"] == "Dict":
            return gym.spaces.Dict(spaces={k: deserialize_space(v) for k, v in obj["spaces"].items()})
        else:
            # the kind name lives under "space"; "spaces" only exists for composite documents
            raise ValueError(f"Unsupported space ({obj['space']})")
    # Python data types
    elif obj["type"] == "python":
        if obj["space"] == "Discrete":
            return {obj["value"]}
        elif obj["space"] == "Box":
            return obj["value"]
        elif obj["space"] == "MultiDiscrete":
            return [{x} for x in obj["value"]]
        elif obj["space"] == "Tuple":
            return tuple(map(deserialize_space, obj["value"]))
        elif obj["space"] == "Dict":
            return {k: deserialize_space(v) for k, v in obj["value"].items()}
        else:
            # python-type documents never carry a "spaces" key, so report the "space" name
            raise ValueError(f"Unsupported space ({obj['space']})")
    else:
        raise ValueError(f"Unsupported type ({obj['type']})")
| 186 | + |
| 187 | + |
def replace_env_cfg_spaces_with_strings(env_cfg: object) -> object:
    """Swap the space attributes of an environment config for their serialized JSON strings.

    The attributes ``observation_space``, ``action_space`` and ``state_space`` are each replaced
    by one JSON string. The attributes ``observation_spaces`` and ``action_spaces`` are treated as
    mappings and replaced by a new dict with the same keys and serialized values. Attributes that
    the config does not have are skipped. The config is modified in place.

    Args:
        env_cfg: Environment config instance.

    Returns:
        The same environment config instance, with spaces replaced if any.
    """
    single_space_attrs = ("observation_space", "action_space", "state_space")
    mapping_space_attrs = ("observation_spaces", "action_spaces")

    for name in single_space_attrs:
        if not hasattr(env_cfg, name):
            continue
        setattr(env_cfg, name, serialize_space(getattr(env_cfg, name)))

    for name in mapping_space_attrs:
        if not hasattr(env_cfg, name):
            continue
        encoded = {}
        for key, member in getattr(env_cfg, name).items():
            encoded[key] = serialize_space(member)
        setattr(env_cfg, name, encoded)

    return env_cfg
| 204 | + |
| 205 | + |
def replace_strings_with_env_cfg_spaces(env_cfg: object) -> object:
    """Swap the serialized JSON strings in an environment config for space specifications.

    The attributes ``observation_space``, ``action_space`` and ``state_space`` are each passed
    through ``deserialize_space``. The attributes ``observation_spaces`` and ``action_spaces`` are
    treated as mappings and replaced by a new dict with the same keys and deserialized values.
    Attributes that the config does not have are skipped. The config is modified in place.

    Args:
        env_cfg: Environment config instance.

    Returns:
        The same environment config instance, with serialized strings replaced if any.
    """
    single_space_attrs = ("observation_space", "action_space", "state_space")
    mapping_space_attrs = ("observation_spaces", "action_spaces")

    for name in single_space_attrs:
        if not hasattr(env_cfg, name):
            continue
        setattr(env_cfg, name, deserialize_space(getattr(env_cfg, name)))

    for name in mapping_space_attrs:
        if not hasattr(env_cfg, name):
            continue
        decoded = {}
        for key, text in getattr(env_cfg, name).items():
            decoded[key] = deserialize_space(text)
        setattr(env_cfg, name, decoded)

    return env_cfg
0 commit comments