1414import re
1515import warnings
1616from enum import Enum
17- from typing import Any , Dict , List
17+ from typing import Any
1818
1919import torch
2020
@@ -329,9 +329,9 @@ def step_mdp(
329329 exclude_reward : bool = True ,
330330 exclude_done : bool = False ,
331331 exclude_action : bool = True ,
332- reward_keys : NestedKey | List [NestedKey ] = "reward" ,
333- done_keys : NestedKey | List [NestedKey ] = "done" ,
334- action_keys : NestedKey | List [NestedKey ] = "action" ,
332+ reward_keys : NestedKey | list [NestedKey ] = "reward" ,
333+ done_keys : NestedKey | list [NestedKey ] = "done" ,
334+ action_keys : NestedKey | list [NestedKey ] = "action" ,
335335) -> TensorDictBase :
336336 """Creates a new tensordict that reflects a step in time of the input tensordict.
337337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680680
681681
682682def check_env_specs (
683- env ,
684- return_contiguous = True ,
683+ env : torchrl . envs . EnvBase , # noqa
684+ return_contiguous : bool | None = None ,
685685 check_dtype = True ,
686686 seed : int | None = None ,
687687 tensordict : TensorDictBase | None = None ,
@@ -699,7 +699,7 @@ def check_env_specs(
699699 env (EnvBase): the env for which the specs have to be checked against data.
700700 return_contiguous (bool, optional): if ``True``, the random rollout will be called with
701701 return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
702- of inputs/outputs). Defaults to True .
702+ of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs) .
703703 check_dtype (bool, optional): if False, dtype checks will be skipped.
704704 Defaults to True.
705705 seed (int, optional): for reproducibility, a seed can be set.
@@ -715,6 +715,8 @@ def check_env_specs(
715715 of an experiment and as such should be kept out of training scripts.
716716
717717 """
718+ if return_contiguous is None :
719+ return_contiguous = not env ._has_dynamic_specs
718720 if seed is not None :
719721 device = (
720722 env .device if env .device is not None and env .device .type == "cuda" else None
@@ -726,7 +728,7 @@ def check_env_specs(
726728 )
727729
728730 fake_tensordict = env .fake_tensordict ()
729- if not env ._batch_locked and tensordict is not None :
731+ if not env .batch_locked and tensordict is not None :
730732 shape = torch .broadcast_shapes (fake_tensordict .shape , tensordict .shape )
731733 fake_tensordict = fake_tensordict .expand (shape )
732734 tensordict = tensordict .expand (shape )
@@ -765,10 +767,13 @@ def check_env_specs(
765767 - List of keys present in fake but not in real: { fake_tensordict_keys - real_tensordict_keys } .
766768"""
767769 )
768- zeroing_err_msg = (
769- "zeroing the two tensordicts did not make them identical. "
770- f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
771- )
770+
771+ def zeroing_err_msg ():
772+ return (
773+ "zeroing the two tensordicts did not make them identical. "
774+ f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
775+ )
776+
772777 from torchrl .envs .common import _has_dynamic_specs
773778
774779 if _has_dynamic_specs (env .specs ):
@@ -778,7 +783,7 @@ def check_env_specs(
778783 ):
779784 fake = fake .apply (lambda x , y : x .expand_as (y ), real )
780785 if (torch .zeros_like (real ) != torch .zeros_like (fake )).any ():
781- raise AssertionError (zeroing_err_msg )
786+ raise AssertionError (zeroing_err_msg () )
782787
783788 # Checks shapes and eventually dtypes of keys at all nesting levels
784789 _per_level_env_check (fake , real , check_dtype = check_dtype )
@@ -788,7 +793,7 @@ def check_env_specs(
788793 torch .zeros_like (fake_tensordict_select )
789794 != torch .zeros_like (real_tensordict_select )
790795 ).any ():
791- raise AssertionError (zeroing_err_msg )
796+ raise AssertionError (zeroing_err_msg () )
792797
793798 # Checks shapes and eventually dtypes of keys at all nesting levels
794799 _per_level_env_check (
@@ -1009,14 +1014,14 @@ class MarlGroupMapType(Enum):
10091014 ALL_IN_ONE_GROUP = 1
10101015 ONE_GROUP_PER_AGENT = 2
10111016
1012- def get_group_map (self , agent_names : List [str ]):
1017+ def get_group_map (self , agent_names : list [str ]):
10131018 if self == MarlGroupMapType .ALL_IN_ONE_GROUP :
10141019 return {"agents" : agent_names }
10151020 elif self == MarlGroupMapType .ONE_GROUP_PER_AGENT :
10161021 return {agent_name : [agent_name ] for agent_name in agent_names }
10171022
10181023
1019- def check_marl_grouping (group_map : Dict [str , List [str ]], agent_names : List [str ]):
1024+ def check_marl_grouping (group_map : dict [str , list [str ]], agent_names : list [str ]):
10201025 """Check MARL group map.
10211026
10221027 Performs checks on the group map of a marl environment to assess its validity.
@@ -1360,7 +1365,7 @@ def skim_through(td, reset=reset):
13601365def _update_during_reset (
13611366 tensordict_reset : TensorDictBase ,
13621367 tensordict : TensorDictBase ,
1363- reset_keys : List [NestedKey ],
1368+ reset_keys : list [NestedKey ],
13641369):
13651370 """Updates the input tensordict with the reset data, based on the reset keys."""
13661371 if not reset_keys :
0 commit comments