@@ -89,8 +89,13 @@ class _AcceptedKeys:
89
89
pixels : NestedKey = "pixels"
90
90
reco_pixels : NestedKey = "reco_pixels"
91
91
92
+ tensor_keys : _AcceptedKeys
92
93
default_keys = _AcceptedKeys ()
93
94
95
+ decoder : TensorDictModule
96
+ reward_model : TensorDictModule
97
+ world_mdel : TensorDictModule
98
+
94
99
def __init__ (
95
100
self ,
96
101
world_model : TensorDictModule ,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
238
243
done : NestedKey = "done"
239
244
terminated : NestedKey = "terminated"
240
245
246
+ tensor_keys : _AcceptedKeys
241
247
default_keys = _AcceptedKeys ()
242
248
default_value_estimator = ValueEstimators .TDLambda
243
249
250
+ value_model : TensorDictModule
251
+ actor_model : TensorDictModule
252
+
244
253
def __init__ (
245
254
self ,
246
255
actor_model : TensorDictModule ,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
392
401
393
402
value : NestedKey = "state_value"
394
403
404
+ tensor_keys : _AcceptedKeys
395
405
default_keys = _AcceptedKeys ()
396
406
407
+ value_model : TensorDictModule
408
+
397
409
def __init__ (
398
410
self ,
399
411
value_model : TensorDictModule ,
0 commit comments