 SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())


-# -> algorithm
-class Discriminator_SA(Algorithm):
-    def __init__(self, observation_spec, action_spec):
-        super().__init__(observation_spec=observation_spec)
-
-        disc_net = CriticNetwork((observation_spec, action_spec))
-        self._disc_net = disc_net
-
-    def forward(self, inputs, state=()):
-        return self._disc_net(inputs, state)
-
-    def compute_grad_pen(self, expert_state, offline_state, lambda_=10):
-        alpha = torch.rand(expert_state.size(0), 1)
-        expert_data = expert_state
-        offline_data = offline_state
-
-        alpha = alpha.expand_as(expert_data).to(expert_data.device)
-
-        mixup_data = alpha * expert_data + (1 - alpha) * offline_data
-        mixup_data.requires_grad = True
-
-        disc = self(mixup_data)
-        ones = torch.ones(disc.size()).to(disc.device)
-        grad = autograd.grad(
-            outputs=disc,
-            inputs=mixup_data,
-            grad_outputs=ones,
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True)[0]
-
-        grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
-        return grad_pen
-
-    def update(self, expert_loader, offline_loader):
-        self.train()
-
-        loss = 0
-        n = 0
-        for expert_state, offline_state in zip(expert_loader, offline_loader):
-
-            expert_state = expert_state[0].to(self.device)
-            offline_state = offline_state[0][:expert_state.shape[0]].to(
-                self.device)
-
-            policy_d = self(offline_state)
-            expert_d = self(expert_state)
-
-            expert_loss = F.binary_cross_entropy_with_logits(
-                expert_d,
-                torch.ones(expert_d.size()).to(self.device))
-            policy_loss = F.binary_cross_entropy_with_logits(
-                policy_d,
-                torch.zeros(policy_d.size()).to(self.device))
-
-            gail_loss = expert_loss + policy_loss
-            grad_pen = self.compute_grad_pen(expert_state, offline_state)
-
-            loss += (gail_loss + grad_pen).item()
-            n += 1
-
-            self.optimizer.zero_grad()
-            (gail_loss + grad_pen).backward()
-            self.optimizer.step()
-        return loss / n
-
-    def predict_reward(self, state):
-        with torch.no_grad():
-            self.eval()
-            d = self(state)
-            s = torch.sigmoid(d)
-            # log(d^E/d^O)
-            # reward = - (1/s-1).log()
-            reward = s.log() - (1 - s).log()
-            return reward
-
-
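Note: the Discriminator_SA class removed above is a GAIL-style state-action discriminator, trained with binary cross-entropy against expert (label 1) and offline (label 0) batches plus a WGAN-GP style gradient penalty on mixup points. Its predict_reward returns log(d^E/d^O) = log(s) - log(1 - s) with s = sigmoid(d), which is algebraically just the raw logit d. A minimal, self-contained check of that identity in plain PyTorch (independent of ALF):

import torch

d = torch.randn(5)                            # raw discriminator logits
s = torch.sigmoid(d)                          # estimated P(expert | input)
reward = s.log() - (1 - s).log()              # log(d^E / d^O), as in predict_reward
print(torch.allclose(reward, d, atol=1e-5))   # True: the reward equals the logit itself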
 @alf.configurable
 class SmodiceAlgorithm(OffPolicyAlgorithm):
     r"""SMODICE algorithm.
@@ -143,27 +66,24 @@ class SmodiceAlgorithm(OffPolicyAlgorithm):
     ICML 2022.
     """

-    def __init__(
-            self,
-            observation_spec,
-            action_spec: BoundedTensorSpec,
-            reward_spec=TensorSpec(()),
-            actor_network_cls=ActorNetwork,
-            v_network_cls=ValueNetwork,
-            discriminator_network_cls=None,
-            actor_optimizer=None,
-            value_optimizer=None,
-            discriminator_optimizer=None,
-            #=====new params
-            gamma: float = 0.99,
-            v_l2_reg: float = 0.001,
-            env=None,
-            config: TrainerConfig = None,
-            checkpoint=None,
-            debug_summaries=False,
-            epsilon_greedy=None,
-            f="chi",
-            name="SmodiceAlgorithm"):
+    def __init__(self,
+                 observation_spec,
+                 action_spec: BoundedTensorSpec,
+                 reward_spec=TensorSpec(()),
+                 actor_network_cls=ActorNetwork,
+                 v_network_cls=ValueNetwork,
+                 discriminator_network_cls=None,
+                 actor_optimizer=None,
+                 value_optimizer=None,
+                 discriminator_optimizer=None,
+                 gamma: float = 0.99,
+                 f="chi",
+                 env=None,
+                 config: TrainerConfig = None,
+                 checkpoint=None,
+                 debug_summaries=False,
+                 epsilon_greedy=None,
+                 name="SmodiceAlgorithm"):
         """
         Args:
             observation_spec (nested TensorSpec): representing the observations.
@@ -178,7 +98,13 @@ def __init__(
17898 actor_network_cls (Callable): is used to construct the actor network.
17999 The constructed actor network is a determinstic network and
180100 will be used to generate continuous actions.
101+ v_network_cls (Callable): is used to construct the value network.
102+ discriminator_network_cls (Callable): is used to construct the discriminatr.
181103 actor_optimizer (torch.optim.optimizer): The optimizer for actor.
104+ value_optimizer (torch.optim.optimizer): The optimizer for value network.
105+ discriminator_optimizer (torch.optim.optimizer): The optimizer for discriminator.
106+ gamma (float): the discount factor.
107+ f (str): the function form for f-divergence. Currently support 'chi' and 'kl'
182108 env (Environment): The environment to interact with. ``env`` is a
183109 batched environment, which means that it runs multiple simulations
184110 simultateously. ``env` only needs to be provided to the root
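For orientation, a hypothetical construction of the algorithm with the arguments documented above; the spec shapes, learning rates, and the use of alf.optimizers.Adam are illustrative assumptions, not taken from this commit:

import alf
from alf.tensor_specs import TensorSpec, BoundedTensorSpec

# Hypothetical shapes and hyperparameters, for illustration only.
obs_spec = TensorSpec((11, ))
action_spec = BoundedTensorSpec((3, ), minimum=-1.0, maximum=1.0)

algo = SmodiceAlgorithm(
    observation_spec=obs_spec,
    action_spec=action_spec,
    actor_optimizer=alf.optimizers.Adam(lr=3e-4),
    value_optimizer=alf.optimizers.Adam(lr=3e-4),
    gamma=0.99,
    f="chi")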
@@ -242,12 +168,9 @@ def __init__(
         if discriminator_optimizer is not None and discriminator_net is not None:
             self.add_optimizer(discriminator_optimizer, [discriminator_net])

-        self._actor_optimizer = actor_optimizer
-        self._value_optimizer = value_optimizer
-        self._v_l2_reg = v_l2_reg
         self._gamma = gamma
         self._f = f
-        assert f == "chi", "only support chi form"
+        assert f in ["chi", "kl"], "only support chi or kl form"

         # f-divergence functions
         if self._f == 'chi':
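The bodies of the 'chi' / 'kl' branches fall outside this hunk. As a reference, a sketch of the generator f, its convex conjugate f*, and the conjugate's derivative for the two supported cases, following the forms commonly used in DICE-style methods (the exact lambdas defined in this file are not shown here):

import torch

def make_f_functions(f="chi"):
    # Sketch only: returns (f, f*, (f*)') for the chosen f-divergence.
    if f == "chi":
        f_fn = lambda x: 0.5 * (x - 1)**2           # chi^2 generator
        f_star = lambda y: 0.5 * y**2 + y           # convex conjugate over x >= 0
        f_star_prime = lambda y: torch.relu(y + 1)  # (f*)'(y), clipped at zero
    elif f == "kl":
        f_fn = lambda x: x * torch.log(x + 1e-10)   # KL generator x*log(x)
        f_star = lambda y: torch.exp(y - 1)         # conjugate of x*log(x)
        f_star_prime = lambda y: torch.exp(y - 1)
    else:
        raise ValueError("only support chi or kl form")
    return f_fn, f_star, f_star_prime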
@@ -327,14 +250,8 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
         return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))

     def value_train_step(self, inputs: TimeStep, state, rollout_info):
-        # initial_v_values, e_v, result={}
         observation = inputs.observation
-
-        # extract initial observation from batch, or prepare a batch
         initial_observation = observation
-
-        # Shared network values
-        # mini_batch_length
         initial_v_values, _ = self._value_network(initial_observation)

         # mini-batch len
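The remainder of value_train_step is cut off by the diff context. For orientation, a hedged, self-contained sketch of the SMODICE value objective it presumably evaluates, (1 - gamma) * E[V(s_0)] + E[f*(R + gamma * V(s') - V(s))], using random placeholder tensors in place of the real networks, the discriminator reward, and discount/terminal handling:

import torch
from torch import nn

gamma = 0.99
f_star = lambda y: 0.5 * y**2 + y      # chi^2 conjugate, as in the sketch above

v_net = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))
obs, next_obs, initial_obs = (torch.randn(32, 3) for _ in range(3))
reward = torch.randn(32, 1)            # stands in for log(d^E / d^O)

e_v = reward + gamma * v_net(next_obs) - v_net(obs)
value_loss = (1 - gamma) * v_net(initial_obs).mean() + f_star(e_v).mean()
value_loss.backward()                  # gradients flow into v_net's parameters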