Skip to content

Commit 4e92de5

Browse files
committed
fix(nyz): fix env check multi-discrete bug (#852)
1 parent f5157c7 commit 4e92de5

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

ding/envs/env/env_implementation_check.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,37 @@ def check_space_dtype(env: 'BaseEnv') -> None:
2424

2525

2626
# Util function
27-
def check_array_space(ndarray: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
28-
if isinstance(ndarray, np.ndarray):
27+
def check_array_space(data: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
28+
if isinstance(data, np.ndarray):
2929
# print("{}'s type should be np.ndarray".format(name))
30-
assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
31-
name, ndarray.dtype, space.dtype
32-
)
33-
assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
34-
name, ndarray.shape, space.shape
35-
)
30+
assert data.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(name, data.dtype, space.dtype)
31+
assert data.shape == space.shape, "{}'s shape is {}, but requires {}".format(name, data.shape, space.shape)
3632
if isinstance(space, Box):
37-
assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
38-
), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
33+
assert (space.low <= data).all() and (data <= space.high).all(
34+
), "{}'s value is {}, but requires in range ({},{})".format(name, data, space.low, space.high)
3935
elif isinstance(space, (Discrete, MultiDiscrete, MultiBinary)):
40-
print(space.start, space.n)
41-
assert (ndarray >= space.start) and (ndarray <= space.n)
42-
elif isinstance(ndarray, Sequence):
43-
for i in range(len(ndarray)):
36+
if isinstance(space, Discrete):
37+
assert (data >= space.start) and (data <= space.n)
38+
else:
39+
assert (data >= 0).all()
40+
assert all([d < n for d, n in zip(data, space.nvec)])
41+
elif isinstance(data, Sequence):
42+
for i in range(len(data)):
4443
try:
45-
check_array_space(ndarray[i], space[i], name)
44+
check_array_space(data[i], space[i], name)
4645
except AssertionError as e:
4746
print("The following error happens at {}-th index".format(i))
4847
raise e
49-
elif isinstance(ndarray, dict):
50-
for k in ndarray.keys():
48+
elif isinstance(data, dict):
49+
for k in data.keys():
5150
try:
52-
check_array_space(ndarray[k], space[k], name)
51+
check_array_space(data[k], space[k], name)
5352
except AssertionError as e:
5453
print("The following error happens at key {}".format(k))
5554
raise e
5655
else:
5756
raise TypeError(
58-
"Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
57+
"Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(data))
5958
)
6059

6160

ding/envs/env/tests/test_env_implementation_check.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def test_check_array_space():
2222
discrete_array = np.array(11, dtype=np.int64)
2323
with pytest.raises(AssertionError):
2424
check_array_space(discrete_array, discrete_space, 'test_discrete')
25+
26+
multi_discrete_space = gym.spaces.MultiDiscrete([2, 3])
27+
multi_discrete_array = np.array([1, 2], dtype=np.int64)
28+
check_array_space(multi_discrete_array, multi_discrete_space, 'test_multi_discrete')
29+
2530
seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32))
2631
seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)]
2732
with pytest.raises(AssertionError):

0 commit comments

Comments
 (0)