@@ -258,9 +258,11 @@ def __init__(
258258 ):
259259 super ().__init__ ()
260260 self .action_spec = action_spec
261- self .version_number = torch .nn .Parameter (torch .Tensor ([2.0 ]))
261+ self .version_number = torch .nn .Parameter (
262+ torch .Tensor ([2.0 ]), requires_grad = False
263+ )
262264 self .is_continuous_int_deprecated = torch .nn .Parameter (
263- torch .Tensor ([int (self .action_spec .is_continuous ())])
265+ torch .Tensor ([int (self .action_spec .is_continuous ())]), requires_grad = False
264266 )
265267 self .continuous_act_size_vector = torch .nn .Parameter (
266268 torch .Tensor ([int (self .action_spec .continuous_size )]), requires_grad = False
@@ -283,6 +285,9 @@ def __init__(
283285 self .encoding_size = network_settings .memory .memory_size // 2
284286 else :
285287 self .encoding_size = network_settings .hidden_units
288+ self .memory_size_vector = torch .nn .Parameter (
289+ torch .Tensor ([int (self .network_body .memory_size )]), requires_grad = False
290+ )
286291
287292 self .action_model = ActionModel (
288293 self .encoding_size ,
@@ -335,10 +340,7 @@ def forward(
335340 disc_action_out ,
336341 action_out_deprecated ,
337342 ) = self .action_model .get_action_out (encoding , masks )
338- export_out = [
339- self .version_number ,
340- torch .Tensor ([self .network_body .memory_size ]),
341- ]
343+ export_out = [self .version_number , self .memory_size_vector ]
342344 if self .action_spec .continuous_size > 0 :
343345 export_out += [cont_action_out , self .continuous_act_size_vector ]
344346 if self .action_spec .discrete_size > 0 :
0 commit comments