@@ -31,7 +31,9 @@ def array_short_repr(arr: NDArray[Any]) -> str:
3131
3232def is_float_integer (var : Any ) -> bool :
3333 """Checks if a scalar variable is an integer or float (does not include bool)."""
34- return np .issubdtype (type (var ), np .integer ) or np .issubdtype (type (var ), np .floating )
34+ return isinstance (var , (int , float , np .integer , np .floating )) and not isinstance (
35+ var , bool
36+ )
3537
3638
3739class Box (Space [NDArray [Any ]]):
@@ -88,11 +90,7 @@ def __init__(
8890 self .dtype = np .dtype (dtype )
8991
9092 # * check that dtype is an accepted dtype
91- if not (
92- np .issubdtype (self .dtype , np .integer )
93- or np .issubdtype (self .dtype , np .floating )
94- or self .dtype == np .bool_
95- ):
93+ if self .dtype .kind not in ("i" , "u" , "f" , "b" ):
9694 raise ValueError (
9795 f"Invalid Box dtype ({ self .dtype } ), must be an integer, floating, or bool dtype"
9896 )
@@ -103,7 +101,7 @@ def __init__(
103101 raise TypeError (
104102 f"Expected Box shape to be an iterable, actual type={ type (shape )} "
105103 )
106- elif not all (np . issubdtype ( type ( dim ), np .integer ) for dim in shape ):
104+ elif not all (isinstance ( dim , ( int , np .integer ) ) for dim in shape ):
107105 raise TypeError (
108106 f"Expected all Box shape elements to be integer, actual type={ tuple (type (dim ) for dim in shape )} "
109107 )
@@ -135,9 +133,9 @@ def __init__(
135133 # * nan - must be done beforehand as int dtype can cast `nan` to another value
136134 # * unsign int inf and -inf - special case that is disallowed
137135
138- if self .dtype == np . bool_ :
136+ if self .dtype . kind == "b" :
139137 dtype_min , dtype_max = 0 , 1
140- elif np . issubdtype ( self .dtype , np . floating ) :
138+ elif self .dtype . kind == "f" :
141139 dtype_min = float (np .finfo (self .dtype ).min )
142140 dtype_max = float (np .finfo (self .dtype ).max )
143141 else :
@@ -164,8 +162,8 @@ def __init__(
164162 f"Box all low values must be less than or equal to high (some values break this), low={ self .low } , high={ self .high } "
165163 )
166164
167- self .low_repr = array_short_repr ( self . low )
168- self .high_repr = array_short_repr ( self . high )
165+ self .low_repr = None
166+ self .high_repr = None
169167
170168 super ().__init__ (self .shape , self .dtype , seed )
171169
@@ -180,7 +178,7 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
180178 The updated low value and for what values the input is bounded (below)
181179 """
182180 if is_float_integer (low ):
183- bounded_below = - np .inf < np . full (self .shape , low , dtype = float )
181+ bounded_below = np .full (self .shape , - np . inf < low )
184182
185183 if np .isnan (low ):
186184 raise ValueError (f"No low value can be equal to `np.nan`, low={ low } " )
@@ -203,11 +201,7 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
203201 raise ValueError (
204202 f"Box low must be a np.ndarray, integer, or float, actual type={ type (low )} "
205203 )
206- elif not (
207- np .issubdtype (low .dtype , np .floating )
208- or np .issubdtype (low .dtype , np .integer )
209- or low .dtype == np .bool_
210- ):
204+ elif low .dtype .kind not in ("f" , "i" , "u" , "b" ):
211205 raise ValueError (
212206 f"Box low must be a floating, integer, or bool dtype, actual dtype={ low .dtype } "
213207 )
@@ -216,9 +210,10 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
216210
217211 bounded_below = - np .inf < low
218212
219- if np .any (np .isneginf (low )):
213+ neginf = np .isneginf (low )
214+ if np .any (neginf ):
220215 if self .dtype .kind == "i" : # signed int
221- low [np . isneginf ( low ) ] = dtype_min
216+ low [neginf ] = dtype_min
222217 elif self .dtype .kind in {"u" , "b" }: # unsigned int and bool
223218 raise ValueError (
224219 f"Box unsigned int dtype don't support `-np.inf`, low={ low } "
@@ -229,8 +224,8 @@ def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
229224 )
230225
231226 if (
232- np . issubdtype ( low .dtype , np . floating )
233- and np . issubdtype ( self .dtype , np . floating )
227+ low .dtype . kind == "f"
228+ and self .dtype . kind == "f"
234229 and np .finfo (self .dtype ).precision < np .finfo (low .dtype ).precision
235230 ):
236231 gym .logger .warn (
@@ -249,7 +244,7 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
249244 The updated high value and for what values the input is bounded (above)
250245 """
251246 if is_float_integer (high ):
252- bounded_above = np .full (self .shape , high , dtype = float ) < np .inf
247+ bounded_above = np .full (self .shape , high < np .inf )
253248
254249 if np .isnan (high ):
255250 raise ValueError (f"No high value can be equal to `np.nan`, high={ high } " )
@@ -272,11 +267,7 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
272267 raise ValueError (
273268 f"Box high must be a np.ndarray, integer, or float, actual type={ type (high )} "
274269 )
275- elif not (
276- np .issubdtype (high .dtype , np .floating )
277- or np .issubdtype (high .dtype , np .integer )
278- or high .dtype == np .bool_
279- ):
270+ elif high .dtype .kind not in ("f" , "i" , "u" , "b" ):
280271 raise ValueError (
281272 f"Box high must be a floating or integer dtype, actual dtype={ high .dtype } "
282273 )
@@ -299,8 +290,8 @@ def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
299290 )
300291
301292 if (
302- np . issubdtype ( high .dtype , np . floating )
303- and np . issubdtype ( self .dtype , np . floating )
293+ high .dtype . kind == "f"
294+ and self .dtype . kind == "f"
304295 and np .finfo (self .dtype ).precision < np .finfo (high .dtype ).precision
305296 ):
306297 gym .logger .warn (
@@ -451,6 +442,10 @@ def __repr__(self) -> str:
451442 Returns:
452443 A representation of the space
453444 """
445+ if self .low_repr is None :
446+ self .low_repr = array_short_repr (self .low )
447+ if self .high_repr is None :
448+ self .high_repr = array_short_repr (self .high )
454449 return f"Box({ self .low_repr } , { self .high_repr } , { self .shape } , { self .dtype } )"
455450
456451 def __eq__ (self , other : Any ) -> bool :
0 commit comments