@@ -306,16 +306,16 @@ def _gym_to_torchrl_spec_transform(
306306 shape = torch .Size ([1 ])
307307 if dtype is None :
308308 dtype = numpy_to_torch_dtype_dict [spec .dtype ]
309- low = torch .tensor (spec .low , device = device , dtype = dtype )
310- high = torch .tensor (spec .high , device = device , dtype = dtype )
309+ low = torch .as_tensor (spec .low , device = device , dtype = dtype )
310+ high = torch .as_tensor (spec .high , device = device , dtype = dtype )
311311 is_unbounded = low .isinf ().all () and high .isinf ().all ()
312312
313313 minval , maxval = _minmax_dtype (dtype )
314314 minval = torch .as_tensor (minval ).to (low .device , dtype )
315315 maxval = torch .as_tensor (maxval ).to (low .device , dtype )
316316 is_unbounded = is_unbounded or (
317- torch .isclose (low , torch .tensor (minval , dtype = dtype )).all ()
318- and torch .isclose (high , torch .tensor (maxval , dtype = dtype )).all ()
317+ torch .isclose (low , torch .as_tensor (minval , dtype = dtype )).all ()
318+ and torch .isclose (high , torch .as_tensor (maxval , dtype = dtype )).all ()
319319 )
320320 return (
321321 UnboundedContinuousTensorSpec (shape , device = device , dtype = dtype )
@@ -1480,7 +1480,7 @@ def _read_obs(self, obs, key, tensor, index):
14801480 # Simplest case: there is one observation,
14811481 # presented as a np.ndarray. The key should be pixels or observation.
14821482 # We just write that value at its location in the tensor
1483- tensor [index ] = torch .tensor (obs , device = tensor .device )
1483+ tensor [index ] = torch .as_tensor (obs , device = tensor .device )
14841484 elif isinstance (obs , dict ):
14851485 if key not in obs :
14861486 raise KeyError (
@@ -1491,13 +1491,13 @@ def _read_obs(self, obs, key, tensor, index):
14911491 # if the obs is a dict, we expect that the key points also to
14921492 # a value in the obs. We retrieve this value and write it in the
14931493 # tensor
1494- tensor [index ] = torch .tensor (subobs , device = tensor .device )
1494+ tensor [index ] = torch .as_tensor (subobs , device = tensor .device )
14951495
14961496 elif isinstance (obs , (list , tuple )):
14971497 # tuples are stacked along the first dimension when passing gym spaces
14981498 # to torchrl specs. As such, we can simply stack the tuple and set it
14991499 # at the relevant index (assuming stacking can be achieved)
1500- tensor [index ] = torch .tensor (obs , device = tensor .device )
1500+ tensor [index ] = torch .as_tensor (obs , device = tensor .device )
15011501 else :
15021502 raise NotImplementedError (
15031503 f"Observations of type { type (obs )} are not supported yet."
0 commit comments