55import math
66import warnings
77from dataclasses import dataclass
8+ from functools import wraps
89from numbers import Number
910from typing import Dict , Optional , Tuple , Union
1011
4344 FUNCTORCH_ERROR = err
4445
4546
47+ def _delezify (func ):
48+ @wraps (func )
49+ def new_func (self , * args , ** kwargs ):
50+ self .target_entropy
51+ return func (self , * args , ** kwargs )
52+
53+ return new_func
54+
55+
4656class SACLoss (LossModule ):
4757 """TorchRL implementation of the SAC loss.
4858
@@ -371,7 +381,6 @@ def __init__(
371381
372382 self ._target_entropy = target_entropy
373383 self ._action_spec = action_spec
374- self .target_entropy_buffer = None
375384 if self ._version == 1 :
376385 self .actor_critic = ActorCriticWrapper (
377386 self .actor_network , self .value_network
@@ -384,48 +393,54 @@ def __init__(
384393 if self ._version == 1 :
385394 self ._vmap_qnetwork00 = vmap (qvalue_network )
386395
@property
def target_entropy_buffer(self):
    """Alias of :attr:`target_entropy`.

    Returns whatever ``self.target_entropy`` returns; no separate state is
    kept under this name.
    """
    return self.target_entropy
399+
@property
def target_entropy(self):
    """Return the target entropy, registering it as a buffer on first access.

    If a ``"_target_entropy"`` entry is already present in ``self._buffers``
    it is returned directly. Otherwise the value stored in
    ``self._target_entropy`` is resolved:

    * when it equals ``"auto"``, the target entropy is set to minus the
      number of elements of the action's shape, excluding the leading
      dimensions of the spec that contains the action entry. The spec is
      taken from ``self._action_spec`` or, failing that, from the ``spec``
      attribute of ``self.actor_network``;
    * otherwise the stored value is used as is.

    The plain attribute is then deleted and the resolved value is registered
    as the ``"_target_entropy"`` buffer, created on the device of the
    module's first parameter.

    Returns:
        The registered ``_target_entropy`` buffer.

    Raises:
        RuntimeError: if the target entropy is ``"auto"`` and no action spec
            can be found.
    """
    cached = self._buffers.get("_target_entropy", None)
    if cached is not None:
        return cached

    target_entropy = self._target_entropy
    action_spec = self._action_spec
    actor_network = self.actor_network
    device = next(self.parameters()).device
    if target_entropy == "auto":
        action_spec = (
            action_spec
            if action_spec is not None
            else getattr(actor_network, "spec", None)
        )
        if action_spec is None:
            raise RuntimeError(
                "Cannot infer the dimensionality of the action. Consider providing "
                "the target entropy explicitely or provide the spec of the "
                "action tensor in the actor network."
            )
        if not isinstance(action_spec, CompositeSpec):
            action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
        if (
            isinstance(self.tensor_keys.action, tuple)
            and len(self.tensor_keys.action) > 1
        ):
            # Nested action key: the container is the sub-spec addressed by
            # every key component except the last one.
            action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
        else:
            action_container_shape = action_spec.shape
        # Only the dimensions past the container's own shape are counted.
        target_entropy = -float(
            action_spec[self.tensor_keys.action]
            .shape[len(action_container_shape) :]
            .numel()
        )
    # Drop the plain attribute before registering a buffer under the same
    # name, then hand back the freshly registered buffer.
    delattr(self, "_target_entropy")
    self.register_buffer(
        "_target_entropy", torch.tensor(target_entropy, device=device)
    )
    return self._target_entropy
441+
# Both (de)serialization entry points are wrapped so that ``target_entropy``
# is accessed first; that access registers the ``_target_entropy`` buffer if
# it has not been created yet.
# NOTE(review): presumably this keeps the buffer in sync with saved/loaded
# state dicts -- confirm against ``LossModule``'s implementation.
state_dict = _delezify(LossModule.state_dict)
load_state_dict = _delezify(LossModule.load_state_dict)
429444
430445 def _forward_value_estimator_keys (self , ** kwargs ) -> None :
431446 if self ._value_estimator is not None :
0 commit comments