@@ -1862,13 +1862,12 @@ def type_check(
18621862 self ._specs [_key ].type_check (value [_key ], _key )
18631863
18641864 def is_in (self , val : Union [dict , TensorDictBase ]) -> bool :
1865- return all (
1866- [
1867- item .is_in (val .get (key ))
1868- for (key , item ) in self ._specs .items ()
1869- if item is not None
1870- ]
1871- )
1865+ for (key , item ) in self ._specs .items ():
1866+ if item is None :
1867+ continue
1868+ if not item .is_in (val .get (key )):
1869+ return False
1870+ return True
18721871
18731872 def project (self , val : TensorDictBase ) -> TensorDictBase :
18741873 for key , item in self .items ():
@@ -1894,22 +1893,29 @@ def rand(self, shape=None) -> TensorDictBase:
18941893 )
18951894
18961895 def keys (
1897- self , yield_nesting_keys : bool = False , nested_keys : bool = True
1896+ self ,
1897+ include_nested : bool = False ,
1898+ leaves_only : bool = False ,
18981899 ) -> KeysView :
18991900 """Keys of the CompositeSpec.
19001901
1902+ The keys argument reflect those of :class:`tensordict.TensorDict`.
1903+
19011904 Args:
1902- yield_nesting_keys (bool, optional): if :obj:`True`, the values returned
1903- will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
1904- will lead to the keys :obj:`["next", ("next", "obs")]`. Default is :obj:`False`, i.e.
1905- only nested keys will be returned.
1906- nested_keys (bool, optional): if :obj:`False`, the returned keys will not be nested. They will
1905+ include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
19071906 represent only the immediate children of the root, and not the whole nested sequence, i.e. a
19081907 :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
1909- :obj:`["next"]. Default is :obj:`True`, i.e. nested keys will be returned.
1908+ :obj:`["next"]. Default is ``False``, i.e. nested keys will not
1909+ be returned.
1910+ leaves_only (bool, optional): if :obj:`False`, the values returned
1911+ will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
1912+ will lead to the keys :obj:`["next", ("next", "obs")]`.
1913+ Default is ``False``.
19101914 """
19111915 return _CompositeSpecKeysView (
1912- self , _yield_nesting_keys = yield_nesting_keys , nested_keys = nested_keys
1916+ self ,
1917+ include_nested = include_nested ,
1918+ leaves_only = leaves_only ,
19131919 )
19141920
19151921 def items (self ) -> ItemsView :
@@ -2014,13 +2020,14 @@ def expand(self, *shape):
20142020
20152021
20162022def _keys_to_empty_composite_spec (keys ):
2023+ """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value."""
20172024 if not len (keys ):
20182025 return
20192026 c = CompositeSpec ()
20202027 for key in keys :
20212028 if isinstance (key , str ):
20222029 c [key ] = None
2023- elif key [0 ] in c .keys (yield_nesting_keys = True ):
2030+ elif key [0 ] in c .keys ():
20242031 if c [key [0 ]] is None :
20252032 # if the value is None we just replace it
20262033 c [key [0 ]] = _keys_to_empty_composite_spec ([key [1 :]])
@@ -2042,28 +2049,34 @@ class _CompositeSpecKeysView:
20422049 def __init__ (
20432050 self ,
20442051 composite : CompositeSpec ,
2045- nested_keys : bool = True ,
2046- _yield_nesting_keys : bool = False ,
2052+ include_nested ,
2053+ leaves_only ,
20472054 ):
20482055 self .composite = composite
2049- self ._yield_nesting_keys = _yield_nesting_keys
2050- self .nested_keys = nested_keys
2056+ self .leaves_only = leaves_only
2057+ self .include_nested = include_nested
20512058
20522059 def __iter__ (
20532060 self ,
20542061 ):
20552062 for key , item in self .composite .items ():
2056- if self .nested_keys and isinstance (item , CompositeSpec ):
2057- for subkey in item .keys ():
2058- yield (key , * subkey ) if isinstance (subkey , tuple ) else (key , subkey )
2059- if self ._yield_nesting_keys :
2060- yield key
2061- else :
2062- if not isinstance (item , CompositeSpec ) or len (item ):
2063+ if self .include_nested and isinstance (item , CompositeSpec ):
2064+ for subkey in item .keys (
2065+ include_nested = True , leaves_only = self .leaves_only
2066+ ):
2067+ if not isinstance (subkey , tuple ):
2068+ subkey = (subkey ,)
2069+ yield (key , * subkey )
2070+ if not self .leaves_only :
20632071 yield key
2072+ elif not isinstance (item , CompositeSpec ) or not self .leaves_only :
2073+ yield key
20642074
20652075 def __len__ (self ):
20662076 i = 0
20672077 for _ in self :
20682078 i += 1
20692079 return i
2080+
2081+ def __repr__ (self ):
2082+ return f"_CompositeSpecKeysView(keys={ list (self )} )"
0 commit comments