Skip to content

Commit 7d9c151

Browse files
pseudo-rnd-thoughtsMark Towers
andauthored
Optimise Box __init__ (#1529)
Co-authored-by: Mark Towers <mark@anyscale.com>
1 parent 05b3e37 commit 7d9c151

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

gymnasium/spaces/box.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def array_short_repr(arr: NDArray[Any]) -> str:
3131

3232
def is_float_integer(var: Any) -> bool:
3333
"""Checks if a scalar variable is an integer or float (does not include bool)."""
34-
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
34+
return isinstance(var, (int, float, np.integer, np.floating)) and not isinstance(
35+
var, bool
36+
)
3537

3638

3739
class Box(Space[NDArray[Any]]):
@@ -88,11 +90,7 @@ def __init__(
8890
self.dtype = np.dtype(dtype)
8991

9092
# * check that dtype is an accepted dtype
91-
if not (
92-
np.issubdtype(self.dtype, np.integer)
93-
or np.issubdtype(self.dtype, np.floating)
94-
or self.dtype == np.bool_
95-
):
93+
if self.dtype.kind not in ("i", "u", "f", "b"):
9694
raise ValueError(
9795
f"Invalid Box dtype ({self.dtype}), must be an integer, floating, or bool dtype"
9896
)
@@ -103,7 +101,7 @@ def __init__(
103101
raise TypeError(
104102
f"Expected Box shape to be an iterable, actual type={type(shape)}"
105103
)
106-
elif not all(np.issubdtype(type(dim), np.integer) for dim in shape):
104+
elif not all(isinstance(dim, (int, np.integer)) for dim in shape):
107105
raise TypeError(
108106
f"Expected all Box shape elements to be integer, actual type={tuple(type(dim) for dim in shape)}"
109107
)
@@ -135,9 +133,9 @@ def __init__(
135133
# * nan - must be done beforehand as int dtype can cast `nan` to another value
136134
# * unsign int inf and -inf - special case that is disallowed
137135

138-
if self.dtype == np.bool_:
136+
if self.dtype.kind == "b":
139137
dtype_min, dtype_max = 0, 1
140-
elif np.issubdtype(self.dtype, np.floating):
138+
elif self.dtype.kind == "f":
141139
dtype_min = float(np.finfo(self.dtype).min)
142140
dtype_max = float(np.finfo(self.dtype).max)
143141
else:
@@ -164,8 +162,8 @@ def __init__(
164162
f"Box all low values must be less than or equal to high (some values break this), low={self.low}, high={self.high}"
165163
)
166164

167-
self.low_repr = array_short_repr(self.low)
168-
self.high_repr = array_short_repr(self.high)
165+
self.low_repr = None
166+
self.high_repr = None
169167

170168
super().__init__(self.shape, self.dtype, seed)
171169

@@ -180,7 +178,7 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
180178
The updated low value and for what values the input is bounded (below)
181179
"""
182180
if is_float_integer(low):
183-
bounded_below = -np.inf < np.full(self.shape, low, dtype=float)
181+
bounded_below = np.full(self.shape, -np.inf < low)
184182

185183
if np.isnan(low):
186184
raise ValueError(f"No low value can be equal to `np.nan`, low={low}")
@@ -203,11 +201,7 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
203201
raise ValueError(
204202
f"Box low must be a np.ndarray, integer, or float, actual type={type(low)}"
205203
)
206-
elif not (
207-
np.issubdtype(low.dtype, np.floating)
208-
or np.issubdtype(low.dtype, np.integer)
209-
or low.dtype == np.bool_
210-
):
204+
elif low.dtype.kind not in ("f", "i", "u", "b"):
211205
raise ValueError(
212206
f"Box low must be a floating, integer, or bool dtype, actual dtype={low.dtype}"
213207
)
@@ -216,9 +210,10 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
216210

217211
bounded_below = -np.inf < low
218212

219-
if np.any(np.isneginf(low)):
213+
neginf = np.isneginf(low)
214+
if np.any(neginf):
220215
if self.dtype.kind == "i": # signed int
221-
low[np.isneginf(low)] = dtype_min
216+
low[neginf] = dtype_min
222217
elif self.dtype.kind in {"u", "b"}: # unsigned int and bool
223218
raise ValueError(
224219
f"Box unsigned int dtype don't support `-np.inf`, low={low}"
@@ -229,8 +224,8 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
229224
)
230225

231226
if (
232-
np.issubdtype(low.dtype, np.floating)
233-
and np.issubdtype(self.dtype, np.floating)
227+
low.dtype.kind == "f"
228+
and self.dtype.kind == "f"
234229
and np.finfo(self.dtype).precision < np.finfo(low.dtype).precision
235230
):
236231
gym.logger.warn(
@@ -249,7 +244,7 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
249244
The updated high value and for what values the input is bounded (above)
250245
"""
251246
if is_float_integer(high):
252-
bounded_above = np.full(self.shape, high, dtype=float) < np.inf
247+
bounded_above = np.full(self.shape, high < np.inf)
253248

254249
if np.isnan(high):
255250
raise ValueError(f"No high value can be equal to `np.nan`, high={high}")
@@ -272,11 +267,7 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
272267
raise ValueError(
273268
f"Box high must be a np.ndarray, integer, or float, actual type={type(high)}"
274269
)
275-
elif not (
276-
np.issubdtype(high.dtype, np.floating)
277-
or np.issubdtype(high.dtype, np.integer)
278-
or high.dtype == np.bool_
279-
):
270+
elif high.dtype.kind not in ("f", "i", "u", "b"):
280271
raise ValueError(
281272
f"Box high must be a floating or integer dtype, actual dtype={high.dtype}"
282273
)
@@ -299,8 +290,8 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
299290
)
300291

301292
if (
302-
np.issubdtype(high.dtype, np.floating)
303-
and np.issubdtype(self.dtype, np.floating)
293+
high.dtype.kind == "f"
294+
and self.dtype.kind == "f"
304295
and np.finfo(self.dtype).precision < np.finfo(high.dtype).precision
305296
):
306297
gym.logger.warn(
@@ -451,6 +442,10 @@ def __repr__(self) -> str:
451442
Returns:
452443
A representation of the space
453444
"""
445+
if self.low_repr is None:
446+
self.low_repr = array_short_repr(self.low)
447+
if self.high_repr is None:
448+
self.high_repr = array_short_repr(self.high)
454449
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
455450

456451
def __eq__(self, other: Any) -> bool:

0 commit comments

Comments
 (0)