|
6 | 6 | from __future__ import annotations
|
7 | 7 |
|
8 | 8 | import abc
|
9 |
| -import functools |
10 | 9 | import warnings
|
11 | 10 | from copy import deepcopy
|
| 11 | +from functools import partial, wraps |
12 | 12 | from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
13 | 13 |
|
14 | 14 | import numpy as np
|
|
33 | 33 | _StepMDP,
|
34 | 34 | _terminated_or_truncated,
|
35 | 35 | _update_during_reset,
|
| 36 | + check_env_specs as check_env_specs_func, |
36 | 37 | get_available_libraries,
|
37 | 38 | )
|
38 | 39 |
|
@@ -2035,7 +2036,7 @@ def _register_gym(
|
2035 | 2036 |
|
2036 | 2037 | if entry_point is None:
|
2037 | 2038 | entry_point = cls
|
2038 |
| - entry_point = functools.partial( |
| 2039 | + entry_point = partial( |
2039 | 2040 | _TorchRLGymWrapper,
|
2040 | 2041 | entry_point=entry_point,
|
2041 | 2042 | info_keys=info_keys,
|
@@ -2084,7 +2085,7 @@ def _register_gym( # noqa: F811
|
2084 | 2085 |
|
2085 | 2086 | if entry_point is None:
|
2086 | 2087 | entry_point = cls
|
2087 |
| - entry_point = functools.partial( |
| 2088 | + entry_point = partial( |
2088 | 2089 | _TorchRLGymWrapper,
|
2089 | 2090 | entry_point=entry_point,
|
2090 | 2091 | info_keys=info_keys,
|
@@ -2138,7 +2139,7 @@ def _register_gym( # noqa: F811
|
2138 | 2139 |
|
2139 | 2140 | if entry_point is None:
|
2140 | 2141 | entry_point = cls
|
2141 |
| - entry_point = functools.partial( |
| 2142 | + entry_point = partial( |
2142 | 2143 | _TorchRLGymWrapper,
|
2143 | 2144 | entry_point=entry_point,
|
2144 | 2145 | info_keys=info_keys,
|
@@ -2195,7 +2196,7 @@ def _register_gym( # noqa: F811
|
2195 | 2196 |
|
2196 | 2197 | if entry_point is None:
|
2197 | 2198 | entry_point = cls
|
2198 |
| - entry_point = functools.partial( |
| 2199 | + entry_point = partial( |
2199 | 2200 | _TorchRLGymWrapper,
|
2200 | 2201 | entry_point=entry_point,
|
2201 | 2202 | info_keys=info_keys,
|
@@ -2254,7 +2255,7 @@ def _register_gym( # noqa: F811
|
2254 | 2255 | )
|
2255 | 2256 | if entry_point is None:
|
2256 | 2257 | entry_point = cls
|
2257 |
| - entry_point = functools.partial( |
| 2258 | + entry_point = partial( |
2258 | 2259 | _TorchRLGymWrapper,
|
2259 | 2260 | entry_point=entry_point,
|
2260 | 2261 | info_keys=info_keys,
|
@@ -2293,7 +2294,7 @@ def _register_gym( # noqa: F811
|
2293 | 2294 | if entry_point is None:
|
2294 | 2295 | entry_point = cls
|
2295 | 2296 |
|
2296 |
| - entry_point = functools.partial( |
| 2297 | + entry_point = partial( |
2297 | 2298 | _TorchRLGymnasiumWrapper,
|
2298 | 2299 | entry_point=entry_point,
|
2299 | 2300 | info_keys=info_keys,
|
@@ -3422,11 +3423,11 @@ def _get_sync_func(policy_device, env_device):
|
3422 | 3423 | if policy_device is not None and policy_device.type == "cuda":
|
3423 | 3424 | if env_device is None or env_device.type == "cuda":
|
3424 | 3425 | return torch.cuda.synchronize
|
3425 |
| - return functools.partial(torch.cuda.synchronize, device=policy_device) |
| 3426 | + return partial(torch.cuda.synchronize, device=policy_device) |
3426 | 3427 | if env_device is not None and env_device.type == "cuda":
|
3427 | 3428 | if policy_device is None:
|
3428 | 3429 | return torch.cuda.synchronize
|
3429 |
| - return functools.partial(torch.cuda.synchronize, device=env_device) |
| 3430 | + return partial(torch.cuda.synchronize, device=env_device) |
3430 | 3431 | return torch.cuda.synchronize
|
3431 | 3432 | if torch.backends.mps.is_available():
|
3432 | 3433 | return torch.mps.synchronize
|
|
0 commit comments