4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ import functools
7
8
import os
8
9
9
10
import pytest
12
13
import torchrl .modules
13
14
from tensordict import LazyStackedTensorDict , pad , TensorDict , unravel_key_list
14
15
from tensordict .nn import InteractionType , TensorDictModule , TensorDictSequential
16
+ from tensordict .utils import assert_close
15
17
from torch import nn
16
18
from torchrl .data .tensor_specs import Bounded , Composite , Unbounded
17
19
from torchrl .envs import (
@@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
938
940
@pytest .mark .parametrize ("python_based" , [True , False ])
939
941
@pytest .mark .parametrize ("parallel" , [True , False ])
940
942
@pytest .mark .parametrize ("heterogeneous" , [True , False ])
941
- def test_lstm_parallel_env (self , python_based , parallel , heterogeneous ):
943
+ @pytest .mark .parametrize ("within" , [False , True ])
944
+ def test_lstm_parallel_env (self , python_based , parallel , heterogeneous , within ):
942
945
from torchrl .envs import InitTracker , ParallelEnv , TransformedEnv
943
946
944
947
torch .manual_seed (0 )
948
+ num_envs = 3
945
949
device = "cuda" if torch .cuda .device_count () else "cpu"
946
950
# tests that hidden states are carried over with parallel envs
947
951
lstm_module = LSTMModule (
@@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
958
962
else :
959
963
cls = SerialEnv
960
964
961
- def create_transformed_env ():
962
- primer = lstm_module .make_tensordict_primer ()
963
- env = DiscreteActionVecMockEnv (
964
- categorical_action_encoding = True , device = device
965
+ if within :
966
+
967
+ def create_transformed_env ():
968
+ primer = lstm_module .make_tensordict_primer ()
969
+ env = DiscreteActionVecMockEnv (
970
+ categorical_action_encoding = True , device = device
971
+ )
972
+ env = TransformedEnv (env )
973
+ env .append_transform (InitTracker ())
974
+ env .append_transform (primer )
975
+ return env
976
+
977
+ else :
978
+ create_transformed_env = functools .partial (
979
+ DiscreteActionVecMockEnv ,
980
+ categorical_action_encoding = True ,
981
+ device = device ,
965
982
)
966
- env = TransformedEnv (env )
967
- env .append_transform (InitTracker ())
968
- env .append_transform (primer )
969
- return env
970
983
971
984
if heterogeneous :
972
985
create_transformed_env = [
973
- EnvCreator (create_transformed_env ),
974
- EnvCreator (create_transformed_env ),
986
+ EnvCreator (create_transformed_env ) for _ in range (num_envs )
975
987
]
976
988
env = cls (
977
989
create_env_fn = create_transformed_env ,
978
- num_workers = 2 ,
990
+ num_workers = num_envs ,
979
991
)
992
+ if not within :
993
+ env = env .append_transform (InitTracker ())
994
+ env .append_transform (lstm_module .make_tensordict_primer ())
980
995
981
996
mlp = TensorDictModule (
982
997
MLP (
@@ -1002,6 +1017,19 @@ def create_transformed_env():
1002
1017
data = env .rollout (10 , actor , break_when_any_done = break_when_any_done )
1003
1018
assert (data .get (("next" , "recurrent_state_c" )) != 0.0 ).all ()
1004
1019
assert (data .get ("recurrent_state_c" ) != 0.0 ).any ()
1020
+ return data
1021
+
1022
+ @pytest .mark .parametrize ("python_based" , [True , False ])
1023
+ @pytest .mark .parametrize ("parallel" , [True , False ])
1024
+ @pytest .mark .parametrize ("heterogeneous" , [True , False ])
1025
+ def test_lstm_parallel_within (self , python_based , parallel , heterogeneous ):
1026
+ out_within = self .test_lstm_parallel_env (
1027
+ python_based , parallel , heterogeneous , within = True
1028
+ )
1029
+ out_not_within = self .test_lstm_parallel_env (
1030
+ python_based , parallel , heterogeneous , within = False
1031
+ )
1032
+ assert_close (out_within , out_not_within )
1005
1033
1006
1034
@pytest .mark .skipif (
1007
1035
not _has_functorch , reason = "vmap can only be used with functorch"
@@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
1330
1358
@pytest .mark .parametrize ("python_based" , [True , False ])
1331
1359
@pytest .mark .parametrize ("parallel" , [True , False ])
1332
1360
@pytest .mark .parametrize ("heterogeneous" , [True , False ])
1333
- def test_gru_parallel_env (self , python_based , parallel , heterogeneous ):
1361
+ @pytest .mark .parametrize ("within" , [False , True ])
1362
+ def test_gru_parallel_env (self , python_based , parallel , heterogeneous , within ):
1334
1363
from torchrl .envs import InitTracker , ParallelEnv , TransformedEnv
1335
1364
1336
1365
torch .manual_seed (0 )
1366
+ num_workers = 3
1337
1367
1338
1368
device = "cuda" if torch .cuda .device_count () else "cpu"
1339
1369
# tests that hidden states are carried over with parallel envs
@@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
1347
1377
python_based = python_based ,
1348
1378
)
1349
1379
1350
- def create_transformed_env ():
1351
- primer = gru_module .make_tensordict_primer ()
1352
- env = DiscreteActionVecMockEnv (
1353
- categorical_action_encoding = True , device = device
1380
+ if within :
1381
+
1382
+ def create_transformed_env ():
1383
+ primer = gru_module .make_tensordict_primer ()
1384
+ env = DiscreteActionVecMockEnv (
1385
+ categorical_action_encoding = True , device = device
1386
+ )
1387
+ env = TransformedEnv (env )
1388
+ env .append_transform (InitTracker ())
1389
+ env .append_transform (primer )
1390
+ return env
1391
+
1392
+ else :
1393
+ create_transformed_env = functools .partial (
1394
+ DiscreteActionVecMockEnv ,
1395
+ categorical_action_encoding = True ,
1396
+ device = device ,
1354
1397
)
1355
- env = TransformedEnv (env )
1356
- env .append_transform (InitTracker ())
1357
- env .append_transform (primer )
1358
- return env
1359
1398
1360
1399
if parallel :
1361
1400
cls = ParallelEnv
1362
1401
else :
1363
1402
cls = SerialEnv
1364
1403
if heterogeneous :
1365
1404
create_transformed_env = [
1366
- EnvCreator (create_transformed_env ),
1367
- EnvCreator (create_transformed_env ),
1405
+ EnvCreator (create_transformed_env ) for _ in range (num_workers )
1368
1406
]
1369
1407
1370
- env = cls (
1408
+ env : ParallelEnv | SerialEnv = cls (
1371
1409
create_env_fn = create_transformed_env ,
1372
- num_workers = 2 ,
1410
+ num_workers = num_workers ,
1373
1411
)
1412
+ if not within :
1413
+ primer = gru_module .make_tensordict_primer ()
1414
+ env = env .append_transform (InitTracker ())
1415
+ env .append_transform (primer )
1374
1416
1375
1417
mlp = TensorDictModule (
1376
1418
MLP (
@@ -1396,6 +1438,19 @@ def create_transformed_env():
1396
1438
data = env .rollout (10 , actor , break_when_any_done = break_when_any_done )
1397
1439
assert (data .get ("recurrent_state" ) != 0.0 ).any ()
1398
1440
assert (data .get (("next" , "recurrent_state" )) != 0.0 ).all ()
1441
+ return data
1442
+
1443
+ @pytest .mark .parametrize ("python_based" , [True , False ])
1444
+ @pytest .mark .parametrize ("parallel" , [True , False ])
1445
+ @pytest .mark .parametrize ("heterogeneous" , [True , False ])
1446
+ def test_gru_parallel_within (self , python_based , parallel , heterogeneous ):
1447
+ out_within = self .test_gru_parallel_env (
1448
+ python_based , parallel , heterogeneous , within = True
1449
+ )
1450
+ out_not_within = self .test_gru_parallel_env (
1451
+ python_based , parallel , heterogeneous , within = False
1452
+ )
1453
+ assert_close (out_within , out_not_within )
1399
1454
1400
1455
@pytest .mark .skipif (
1401
1456
not _has_functorch , reason = "vmap can only be used with functorch"
0 commit comments