@@ -24,38 +24,37 @@ def check_space_dtype(env: 'BaseEnv') -> None:
2424
2525
2626# Util function
27- def check_array_space (ndarray : Union [np .ndarray , Sequence , Dict ], space : Union ['Space' , Dict ], name : str ) -> None :
28- if isinstance (ndarray , np .ndarray ):
27+ def check_array_space (data : Union [np .ndarray , Sequence , Dict ], space : Union ['Space' , Dict ], name : str ) -> None :
28+ if isinstance (data , np .ndarray ):
2929 # print("{}'s type should be np.ndarray".format(name))
30- assert ndarray .dtype == space .dtype , "{}'s dtype is {}, but requires {}" .format (
31- name , ndarray .dtype , space .dtype
32- )
33- assert ndarray .shape == space .shape , "{}'s shape is {}, but requires {}" .format (
34- name , ndarray .shape , space .shape
35- )
30+ assert data .dtype == space .dtype , "{}'s dtype is {}, but requires {}" .format (name , data .dtype , space .dtype )
31+ assert data .shape == space .shape , "{}'s shape is {}, but requires {}" .format (name , data .shape , space .shape )
3632 if isinstance (space , Box ):
37- assert (space .low <= ndarray ).all () and (ndarray <= space .high ).all (
38- ), "{}'s value is {}, but requires in range ({},{})" .format (name , ndarray , space .low , space .high )
33+ assert (space .low <= data ).all () and (data <= space .high ).all (
34+ ), "{}'s value is {}, but requires in range ({},{})" .format (name , data , space .low , space .high )
3935 elif isinstance (space , (Discrete , MultiDiscrete , MultiBinary )):
40- print (space .start , space .n )
41- assert (ndarray >= space .start ) and (ndarray <= space .n )
42- elif isinstance (ndarray , Sequence ):
43- for i in range (len (ndarray )):
36+ if isinstance (space , Discrete ):
37+ assert (data >= space .start ) and (data <= space .n )
38+ else :
39+ assert (data >= 0 ).all ()
40+ assert all ([d < n for d , n in zip (data , space .nvec )])
41+ elif isinstance (data , Sequence ):
42+ for i in range (len (data )):
4443 try :
45- check_array_space (ndarray [i ], space [i ], name )
44+ check_array_space (data [i ], space [i ], name )
4645 except AssertionError as e :
4746 print ("The following error happens at {}-th index" .format (i ))
4847 raise e
49- elif isinstance (ndarray , dict ):
50- for k in ndarray .keys ():
48+ elif isinstance (data , dict ):
49+ for k in data .keys ():
5150 try :
52- check_array_space (ndarray [k ], space [k ], name )
51+ check_array_space (data [k ], space [k ], name )
5352 except AssertionError as e :
5453 print ("The following error happens at key {}" .format (k ))
5554 raise e
5655 else :
5756 raise TypeError (
58- "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}" .format (type (ndarray ))
57+ "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}" .format (type (data ))
5958 )
6059
6160
0 commit comments