Skip to content

Commit 5fbfc73

Browse files
OliverZimKakadus
authored andcommitted
fix: check and set dtype of box correctly
1 parent 41240ff commit 5fbfc73

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,15 @@ def _create_data(self, dataContainerPb):
7373
# print(boxContainerPb.shape, boxContainerPb.dtype, boxContainerPb.uintData)
7474

7575
if boxContainerPb.dtype == pb.INT:
76-
data = boxContainerPb.intData
76+
data = np.array(boxContainerPb.doubleData, dtype=int)
7777
elif boxContainerPb.dtype == pb.UINT:
78-
data = boxContainerPb.uintData
78+
data = np.array(boxContainerPb.doubleData, dtype=np.uint)
7979
elif boxContainerPb.dtype == pb.DOUBLE:
80-
data = boxContainerPb.doubleData
80+
data = np.array(boxContainerPb.doubleData, dtype=np.float64)
8181
else:
82-
data = boxContainerPb.floatData
82+
data = np.array(boxContainerPb.floatData, dtype=np.float32)
8383

8484
# TODO: reshape using shape info
85-
data = np.array(data)
8685
return data
8786

8887
elif dataContainerPb.type == pb.Tuple:
@@ -207,11 +206,11 @@ def _pack_data(cls, actions, spaceDesc):
207206
boxContainerPb.dtype = pb.UINT
208207
boxContainerPb.uintData.extend(actions)
209208

210-
elif spaceDesc.dtype in ['float', 'float32', 'float64']:
209+
elif spaceDesc.dtype.name in ["float", "float32"]:
211210
boxContainerPb.dtype = pb.FLOAT
212211
boxContainerPb.floatData.extend(actions)
213212

214-
elif spaceDesc.dtype in ['double']:
213+
elif spaceDesc.dtype.name in ["double", "float64"]:
215214
boxContainerPb.dtype = pb.DOUBLE
216215
boxContainerPb.doubleData.extend(actions)
217216

0 commit comments

Comments
 (0)