|
6 | 6 | # SPDX-License-Identifier: MPL-2.0 |
7 | 7 | # This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems. |
8 | 8 |
|
9 | | -import grid2op |
10 | 9 | import unittest |
11 | 10 | import warnings |
12 | 11 |
|
|
17 | 16 | CAN_TEST_ALL = True |
18 | 17 | if GYMNASIUM_AVAILABLE: |
19 | 18 | from gymnasium.utils.env_checker import check_env |
20 | | - from gymnasium.utils.env_checker import check_reset_return_type, check_reset_options, check_reset_seed |
| 19 | + from gymnasium.utils.env_checker import check_reset_return_type, check_reset_options |
| 20 | + try: |
| 21 | + from gymnasium.utils.env_checker import check_reset_seed |
| 22 | + except ImportError: |
| 23 | + # not present in most recent version of gymnasium, I copy pasted |
| 24 | + # it from an oldest version |
| 25 | + import gymnasium |
| 26 | + from logging import getLogger |
| 27 | + import inspect |
| 28 | + from copy import deepcopy |
| 29 | + import numpy as np |
| 30 | + logger = getLogger() |
| 31 | + |
| 32 | + |
| 33 | + def data_equivalence(data_1, data_2) -> bool: |
| 34 | + """Assert equality between data 1 and 2, i.e observations, actions, info. |
| 35 | +
|
| 36 | + Args: |
| 37 | + data_1: data structure 1 |
| 38 | + data_2: data structure 2 |
| 39 | +
|
| 40 | + Returns: |
| 41 | + If observation 1 and 2 are equivalent |
| 42 | + """ |
| 43 | + if type(data_1) == type(data_2): |
| 44 | + if isinstance(data_1, dict): |
| 45 | + return data_1.keys() == data_2.keys() and all( |
| 46 | + data_equivalence(data_1[k], data_2[k]) for k in data_1.keys() |
| 47 | + ) |
| 48 | + elif isinstance(data_1, (tuple, list)): |
| 49 | + return len(data_1) == len(data_2) and all( |
| 50 | + data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) |
| 51 | + ) |
| 52 | + elif isinstance(data_1, np.ndarray): |
| 53 | + return data_1.shape == data_2.shape and np.allclose( |
| 54 | + data_1, data_2, atol=0.00001 |
| 55 | + ) |
| 56 | + else: |
| 57 | + return data_1 == data_2 |
| 58 | + else: |
| 59 | + return False |
| 60 | + |
| 61 | + |
| 62 | + def check_reset_seed(env: gymnasium.Env): |
| 63 | + """Check that the environment can be reset with a seed. |
| 64 | +
|
| 65 | + Args: |
| 66 | + env: The environment to check |
| 67 | +
|
| 68 | + Raises: |
| 69 | + AssertionError: The environment cannot be reset with a random seed, |
| 70 | + even though `seed` or `kwargs` appear in the signature. |
| 71 | + """ |
| 72 | + signature = inspect.signature(env.reset) |
| 73 | + if "seed" in signature.parameters or ( |
| 74 | + "kwargs" in signature.parameters |
| 75 | + and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD |
| 76 | + ): |
| 77 | + try: |
| 78 | + obs_1, info = env.reset(seed=123) |
| 79 | + assert ( |
| 80 | + obs_1 in env.observation_space |
| 81 | + ), "The observation returned by `env.reset(seed=123)` is not within the observation space." |
| 82 | + assert ( |
| 83 | + env.unwrapped._np_random # pyright: ignore [reportPrivateUsage] |
| 84 | + is not None |
| 85 | + ), "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`." |
| 86 | + seed_123_rng = deepcopy( |
| 87 | + env.unwrapped._np_random # pyright: ignore [reportPrivateUsage] |
| 88 | + ) |
| 89 | + |
| 90 | + obs_2, info = env.reset(seed=123) |
| 91 | + assert ( |
| 92 | + obs_2 in env.observation_space |
| 93 | + ), "The observation returned by `env.reset(seed=123)` is not within the observation space." |
| 94 | + if env.spec is not None and env.spec.nondeterministic is False: |
| 95 | + assert data_equivalence( |
| 96 | + obs_1, obs_2 |
| 97 | + ), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent." |
| 98 | + assert ( |
| 99 | + env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage] |
| 100 | + == seed_123_rng.bit_generator.state |
| 101 | + ), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`." |
| 102 | + |
| 103 | + obs_3, info = env.reset(seed=456) |
| 104 | + assert ( |
| 105 | + obs_3 in env.observation_space |
| 106 | + ), "The observation returned by `env.reset(seed=456)` is not within the observation space." |
| 107 | + assert ( |
| 108 | + env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage] |
| 109 | + != seed_123_rng.bit_generator.state |
| 110 | + ), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`." |
| 111 | + |
| 112 | + except TypeError as e: |
| 113 | + raise AssertionError( |
| 114 | + "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. " |
| 115 | + f"This should never happen, please report this issue. The error was: {e}" |
| 116 | + ) from e |
| 117 | + |
| 118 | + seed_param = signature.parameters.get("seed") |
| 119 | + # Check the default value is None |
| 120 | + if seed_param is not None and seed_param.default is not None: |
| 121 | + logger.warning( |
| 122 | + "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. " |
| 123 | + f"Actual default: {seed_param.default}" |
| 124 | + ) |
| 125 | + else: |
| 126 | + raise gymnasium.error.Error( |
| 127 | + "The `reset` method does not provide a `seed` or `**kwargs` keyword argument." |
| 128 | + ) |
| 129 | + |
| 130 | + |
21 | 131 | elif GYM_AVAILABLE: |
22 | 132 | from gym.utils.env_checker import check_env |
23 | 133 | from gym.utils.env_checker import check_reset_return_type, check_reset_options, check_reset_seed |
|
0 commit comments