Skip to content

Commit 82467c0

Browse files
OliverZimKakadus
authored andcommitted
fix: check and set dtype of box correctly
1 parent 41240ff commit 82467c0

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _create_space(self, spaceDesc):
6161
return space
6262

6363
def _create_data(self, dataContainerPb):
64+
import pdb; pdb.set_trace()
6465
if dataContainerPb.type == pb.Discrete:
6566
discreteContainerPb = pb.DiscreteDataContainer()
6667
dataContainerPb.data.Unpack(discreteContainerPb)
@@ -73,16 +74,15 @@ def _create_data(self, dataContainerPb):
7374
# print(boxContainerPb.shape, boxContainerPb.dtype, boxContainerPb.uintData)
7475

7576
if boxContainerPb.dtype == pb.INT:
76-
data = boxContainerPb.intData
77+
data = np.array(boxContainerPb.intData, dtype=int)
7778
elif boxContainerPb.dtype == pb.UINT:
78-
data = boxContainerPb.uintData
79+
data = np.array(boxContainerPb.uintData, dtype=np.uint)
7980
elif boxContainerPb.dtype == pb.DOUBLE:
80-
data = boxContainerPb.doubleData
81+
data = np.array(boxContainerPb.doubleData, dtype=np.float64)
8182
else:
82-
data = boxContainerPb.floatData
83+
data = np.array(boxContainerPb.floatData, dtype=np.float32)
8384

8485
# TODO: reshape using shape info
85-
data = np.array(data)
8686
return data
8787

8888
elif dataContainerPb.type == pb.Tuple:
@@ -207,11 +207,11 @@ def _pack_data(cls, actions, spaceDesc):
207207
boxContainerPb.dtype = pb.UINT
208208
boxContainerPb.uintData.extend(actions)
209209

210-
elif spaceDesc.dtype in ['float', 'float32', 'float64']:
210+
elif spaceDesc.dtype.name in ["float", "float32"]:
211211
boxContainerPb.dtype = pb.FLOAT
212212
boxContainerPb.floatData.extend(actions)
213213

214-
elif spaceDesc.dtype in ['double']:
214+
elif spaceDesc.dtype.name in ["double", "float64"]:
215215
boxContainerPb.dtype = pb.DOUBLE
216216
boxContainerPb.doubleData.extend(actions)
217217

0 commit comments

Comments
 (0)