 SmoCriticInfo = namedtuple("SmoCriticInfo",
                            ["values", "initial_v_values", "is_first"])

-SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())
+SmoLossInfo = namedtuple(
+    "SmoLossInfo", ["actor", "grad_penalty"], default_value=())


 @alf.configurable
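For reference, ALF's `namedtuple` fills unspecified fields with `default_value`, which is why the non-expert branch further down can leave `grad_penalty` as `()`. A minimal sketch, assuming `alf.data_structures.namedtuple` is the constructor in scope here:

```python
from alf.data_structures import namedtuple

SmoLossInfo = namedtuple(
    "SmoLossInfo", ["actor", "grad_penalty"], default_value=())

info = SmoLossInfo(actor=0.5)
assert info.grad_penalty == ()  # unset fields fall back to the () default
```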
@@ -77,7 +78,8 @@ def __init__(self,
                  value_optimizer=None,
                  discriminator_optimizer=None,
                  gamma: float = 0.99,
-                 f="chi",
+                 f: str = "chi",
+                 gradient_penalty_weight: float = 1,
                  env=None,
                  config: TrainerConfig = None,
                  checkpoint=None,
@@ -104,7 +106,8 @@ def __init__(self,
             value_optimizer (torch.optim.optimizer): The optimizer for value network.
             discriminator_optimizer (torch.optim.optimizer): The optimizer for discriminator.
             gamma (float): the discount factor.
-            f (str): the function form for f-divergence. Currently support 'chi' and 'kl'
+            f: the function form of the f-divergence. Currently supports 'chi' and 'kl'.
+            gradient_penalty_weight: the weight for the discriminator gradient penalty.
             env (Environment): The environment to interact with. ``env`` is a
                 batched environment, which means that it runs multiple simulations
                 simultaneously. ``env`` only needs to be provided to the root
@@ -155,6 +158,7 @@ def __init__(self,
         self._actor_network = actor_network
         self._value_network = value_network
         self._discriminator_net = discriminator_net
+        self._gradient_penalty_weight = gradient_penalty_weight

         assert actor_optimizer is not None
         if actor_optimizer is not None and actor_network is not None:
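The new `gradient_penalty_weight` scales a standard input-gradient penalty on expert samples (WGAN-GP style). With discriminator logit `D(s, a)` and weight λ, the expert-side loss assembled in the hunk below is effectively

```math
L_{\text{expert}} = \mathrm{BCE}\big(D(s, a),\, 1\big)
  + \lambda \sum_{x \in \{s,\, a\}} \big(\lVert \nabla_{x} D(s, a) \rVert_2 - 1\big)^2,
```

i.e. the per-sample 2-norm of the gradient with respect to each input tensor is pushed toward 1.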
@@ -236,18 +240,44 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
236240 """
237241 observation = inputs .observation
238242 action = rollout_info .action
239- expert_logits , _ = self . _discriminator_net (( observation , action ),
240- state )
243+
244+ discriminator_inputs = ( observation , action )
241245
242246 if is_expert :
247+ # turn on input gradient for gradient penalty in the case of expert data
248+ for e in discriminator_inputs :
249+ e .requires_grad = True
250+
251+ expert_logits , _ = self ._discriminator_net (discriminator_inputs , state )
252+
253+ if is_expert :
254+ grads = torch .autograd .grad (
255+ outputs = expert_logits ,
256+ inputs = discriminator_inputs ,
257+ grad_outputs = torch .ones_like (expert_logits ),
258+ create_graph = True ,
259+ retain_graph = True ,
260+ only_inputs = True )
261+
262+ grad_pen = 0
263+ for g in grads :
264+ grad_pen += self ._gradient_penalty_weight * (
265+ g .norm (2 , dim = 1 ) - 1 ).pow (2 )
266+
             label = torch.ones(expert_logits.size())
         else:
             label = torch.zeros(expert_logits.size())
+            grad_pen = ()
 
         expert_loss = F.binary_cross_entropy_with_logits(
             expert_logits, label, reduction='none')

-        return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
+        return LossInfo(
+            loss=expert_loss if grad_pen == () else expert_loss + grad_pen,
+            extra=SmoLossInfo(actor=expert_loss, grad_penalty=grad_pen))

     def value_train_step(self, inputs: TimeStep, state, rollout_info):
         observation = inputs.observation
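For reference, the penalty computed in `_discriminator_train_step` can be reproduced in isolation. The sketch below is illustrative only: the toy linear `disc`, the shapes, and names like `obs_dim`/`act_dim` are assumptions standing in for `self._discriminator_net` and the real batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, obs_dim, act_dim = 8, 4, 2
disc = nn.Linear(obs_dim + act_dim, 1)  # stand-in for the discriminator network

observation = torch.randn(batch, obs_dim, requires_grad=True)
action = torch.randn(batch, act_dim, requires_grad=True)
logits = disc(torch.cat([observation, action], dim=-1)).squeeze(-1)

# Gradient of the logits w.r.t. each input tensor, kept differentiable so the
# penalty itself contributes gradients to the discriminator parameters.
grads = torch.autograd.grad(
    outputs=logits,
    inputs=(observation, action),
    grad_outputs=torch.ones_like(logits),
    create_graph=True,
    retain_graph=True)

gradient_penalty_weight = 1.0
grad_pen = 0
for g in grads:
    # Penalize deviation of the per-sample input-gradient norm from 1.
    grad_pen = grad_pen + gradient_penalty_weight * (g.norm(2, dim=1) - 1).pow(2)

expert_loss = F.binary_cross_entropy_with_logits(
    logits, torch.ones_like(logits), reduction='none')
(expert_loss + grad_pen).mean().backward()
```

`create_graph=True` is what lets the penalty backpropagate into the discriminator parameters; without it the regularizer would be constant with respect to the weights.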
@@ -285,7 +315,7 @@ def train_step(self,
                 alf.summary.scalar("imitation_loss_online",
                                    actor_loss.loss.mean())
                 alf.summary.scalar("discriminator_loss_online",
-                                   expert_disc_loss.loss.mean())
+                                   expert_disc_loss.extra.actor.mean())

         # use predicted reward
         reward = self.predict_reward(inputs, rollout_info)
@@ -305,7 +335,6 @@ def train_step_offline(self,
                            state,
                            rollout_info,
                            pre_train=False):
-
         action_dist, new_state = self._predict_action(
             inputs.observation, state=state.actor)
 
@@ -324,7 +353,8 @@ def train_step_offline(self,
                                    actor_loss.loss.mean())
                 alf.summary.scalar("discriminator_loss_offline",
                                    expert_disc_loss.loss.mean())
-
+                alf.summary.scalar("grad_penalty",
+                                   expert_disc_loss.extra.grad_penalty.mean())
         # use predicted reward
         reward = self.predict_reward(inputs, rollout_info)
 