
Commit 33ede27

Add gradient penalty
1 parent 679aa11 commit 33ede27

5 files changed: +54 -12 lines changed


alf/algorithms/algorithm.py

Lines changed: 4 additions & 1 deletion
@@ -2132,7 +2132,10 @@ def _hybrid_update(self, experience, batch_info, offline_experience,
         else:
             loss_info = offline_loss_info

-        params = self._backward_and_gradient_update(loss_info.loss * weight)
+        params, gns = self._backward_and_gradient_update(
+            loss_info.loss * weight)
+
+        loss_info = loss_info._replace(gns=gns)

         if self._RL_train:
             # for now, there is no need to do a hybrid after update
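
For context on this hunk: ``_backward_and_gradient_update`` now returns a second value, ``gns`` (presumably a gradient statistic such as a gradient-noise-scale estimate), which is attached to the loss info via ``namedtuple._replace``. The sketch below reproduces just that pattern with hypothetical stand-ins; ``LossInfo``, ``backward_and_update``, and the choice of statistic are illustrative assumptions, not ALF's actual implementation.

    from collections import namedtuple

    import torch

    # Hypothetical stand-ins for ALF's LossInfo and backward/update helper.
    LossInfo = namedtuple("LossInfo", ["loss", "gns"], defaults=[()])


    def backward_and_update(loss, params, optimizer):
        """Backprop `loss`, step `optimizer`, and return (params, grad statistic)."""
        optimizer.zero_grad()
        loss.backward()
        # Illustrative statistic: the global gradient norm.  The real helper
        # presumably computes something richer (e.g. a gradient-noise-scale estimate).
        gns = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))
        optimizer.step()
        return params, gns


    # Toy usage: fit a 3-d parameter toward ones with one SGD step.
    w = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.SGD([w], lr=0.1)
    loss_info = LossInfo(loss=(w - torch.ones(3)).pow(2).sum())
    params, gns = backward_and_update(loss_info.loss, [w], opt)
    # _replace returns a new namedtuple with the statistic attached, as in the hunk above.
    loss_info = loss_info._replace(gns=gns)
    print(float(loss_info.gns))

Note that ``_replace`` returns a new namedtuple rather than mutating in place, so the reassignment of ``loss_info`` in the hunk is required.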

alf/algorithms/smodice_algorithm.py

Lines changed: 39 additions & 9 deletions
@@ -46,7 +46,8 @@
 SmoCriticInfo = namedtuple("SmoCriticInfo",
                            ["values", "initial_v_values", "is_first"])

-SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())
+SmoLossInfo = namedtuple(
+    "SmoLossInfo", ["actor", "grad_penalty"], default_value=())


 @alf.configurable
@@ -77,7 +78,8 @@ def __init__(self,
                  value_optimizer=None,
                  discriminator_optimizer=None,
                  gamma: float = 0.99,
-                 f="chi",
+                 f: str = "chi",
+                 gradient_penalty_weight: float = 1,
                  env=None,
                  config: TrainerConfig = None,
                  checkpoint=None,
@@ -104,7 +106,8 @@ def __init__(self,
             value_optimizer (torch.optim.optimizer): The optimizer for value network.
             discriminator_optimizer (torch.optim.optimizer): The optimizer for discriminator.
             gamma (float): the discount factor.
-            f (str): the function form for f-divergence. Currently support 'chi' and 'kl'
+            f: the function form for f-divergence. Currently support 'chi' and 'kl'
+            gradient_penalty_weight: the weight for discriminator gradient penalty
             env (Environment): The environment to interact with. ``env`` is a
                 batched environment, which means that it runs multiple simulations
                 simultateously. ``env` only needs to be provided to the root
@@ -155,6 +158,7 @@ def __init__(self,
         self._actor_network = actor_network
         self._value_network = value_network
         self._discriminator_net = discriminator_net
+        self._gradient_penalty_weight = gradient_penalty_weight

         assert actor_optimizer is not None
         if actor_optimizer is not None and actor_network is not None:
@@ -236,18 +240,44 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
         """
         observation = inputs.observation
         action = rollout_info.action
-        expert_logits, _ = self._discriminator_net((observation, action),
-                                                   state)
+
+        discriminator_inputs = (observation, action)

         if is_expert:
+            # turn on input gradient for gradient penalty in the case of expert data
+            for e in discriminator_inputs:
+                e.requires_grad = True
+
+        expert_logits, _ = self._discriminator_net(discriminator_inputs, state)
+
+        if is_expert:
+            grads = torch.autograd.grad(
+                outputs=expert_logits,
+                inputs=discriminator_inputs,
+                grad_outputs=torch.ones_like(expert_logits),
+                create_graph=True,
+                retain_graph=True,
+                only_inputs=True)
+
+            grad_pen = 0
+            for g in grads:
+                grad_pen += self._gradient_penalty_weight * (
+                    g.norm(2, dim=1) - 1).pow(2)
+
             label = torch.ones(expert_logits.size())
+            # turn on input gradient for gradient penalty in the case of expert data
+            for e in discriminator_inputs:
+                e.requires_grad = True
         else:
             label = torch.zeros(expert_logits.size())
+            grad_pen = ()

         expert_loss = F.binary_cross_entropy_with_logits(
             expert_logits, label, reduction='none')

-        return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
+        return LossInfo(
+            loss=expert_loss if grad_pen == () else expert_loss + grad_pen,
+            extra=SmoLossInfo(actor=expert_loss, grad_penalty=grad_pen))

     def value_train_step(self, inputs: TimeStep, state, rollout_info):
         observation = inputs.observation
@@ -285,7 +315,7 @@ def train_step(self,
                 alf.summary.scalar("imitation_loss_online",
                                    actor_loss.loss.mean())
                 alf.summary.scalar("discriminator_loss_online",
-                                   expert_disc_loss.loss.mean())
+                                   expert_disc_loss.extra.actor.mean())

         # use predicted reward
         reward = self.predict_reward(inputs, rollout_info)
@@ -305,7 +335,6 @@ def train_step_offline(self,
                            state,
                            rollout_info,
                            pre_train=False):
-
         action_dist, new_state = self._predict_action(
             inputs.observation, state=state.actor)

@@ -324,7 +353,8 @@ def train_step_offline(self,
                                    actor_loss.loss.mean())
             alf.summary.scalar("discriminator_loss_offline",
                                expert_disc_loss.loss.mean())
-
+            alf.summary.scalar("grad_penalty",
+                               expert_disc_loss.extra.grad_penalty.mean())
         # use predicted reward
         reward = self.predict_reward(inputs, rollout_info)

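
The substantive change here is a gradient penalty on the discriminator's expert inputs: gradients of the expert logits with respect to ``(observation, action)`` are taken with ``torch.autograd.grad(..., create_graph=True)`` and their per-sample L2 norm is pushed toward 1, WGAN-GP style. Below is a stripped-down, self-contained sketch of that computation; the toy discriminator, tensor shapes, and the ``0.1`` weight are illustrative assumptions, not ALF code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    gradient_penalty_weight = 0.1  # same value the example configs use

    # Toy discriminator over concatenated (observation, action); shapes are made up.
    disc = nn.Sequential(nn.Linear(4 + 2, 64), nn.ReLU(), nn.Linear(64, 1))

    observation = torch.randn(8, 4, requires_grad=True)
    action = torch.randn(8, 2, requires_grad=True)
    discriminator_inputs = (observation, action)

    expert_logits = disc(torch.cat(discriminator_inputs, dim=-1))

    # d(logits)/d(inputs); create_graph=True keeps the penalty differentiable.
    grads = torch.autograd.grad(
        outputs=expert_logits,
        inputs=discriminator_inputs,
        grad_outputs=torch.ones_like(expert_logits),
        create_graph=True,
        retain_graph=True)

    # Push the per-sample L2 norm of each input gradient toward 1.
    grad_pen = 0
    for g in grads:
        grad_pen = grad_pen + gradient_penalty_weight * (g.norm(2, dim=1) - 1).pow(2)

    label = torch.ones_like(expert_logits)
    expert_loss = F.binary_cross_entropy_with_logits(
        expert_logits, label, reduction='none')

    # Per-sample total: align shapes before adding the (B,) penalty to the (B, 1) loss.
    total_loss = expert_loss.squeeze(-1) + grad_pen
    print(float(total_loss.mean()))

Because ``create_graph=True`` keeps the graph of the gradient computation, the penalty term itself is differentiable, so its gradient flows back into the discriminator parameters when the combined loss is optimized.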

alf/examples/data_collection_carla_conf.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 # This is an example config file for data collection in CARLA.

 # the desired replay buffer size for collection
-# 100 is just an example. Should set it to he actual desired size.
+# 100 is just an example. Should set it to the actual desired size.
 replay_buffer_length = 100

 # the desired environment for data collection

alf/examples/smodice_bipedal_walker_conf.py

Lines changed: 8 additions & 1 deletion
@@ -30,7 +30,7 @@

 offline_buffer_length = None
 offline_buffer_dir = [
-    "./hybrid_rl/replay_buffer_data/pendulum_replay_buffer_from_sac_10k"
+    "/home/haichaozhang/data/DATA/pytorch_alf/Hobot_exp/go1wx/sac_bipedal_baseline/train/algorithm/ckpt-80000-replay_buffer"
 ]

 alf.config('Agent', rl_algorithm_cls=SmodiceAlgorithm, optimizer=None)
@@ -67,4 +67,11 @@
     # add weight decay to the v_net following smodice paper
     value_optimizer=alf.optimizers.Adam(lr=lr, weight_decay=1e-4),
     discriminator_optimizer=alf.optimizers.Adam(lr=lr),
+    gradient_penalty_weight=0.1,
 )
+
+# training config
+alf.config(
+    "TrainerConfig",
+    offline_buffer_dir=offline_buffer_dir,
+    offline_buffer_length=offline_buffer_length)
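
Both example configs set ``gradient_penalty_weight=0.1``. Reading the new ``_discriminator_train_step`` above, the penalized discriminator objective is, schematically (this is a summary of the code, not a formula stated in the repo), with $h$ the discriminator logits, $\sigma$ the sigmoid, and $\lambda$ the penalty weight:

$$
\mathcal{L}_D = \mathbb{E}_{(s,a)\sim\text{expert}}\big[-\log \sigma(h(s,a))\big]
+ \mathbb{E}_{(s,a)\sim\text{non-expert}}\big[-\log\big(1-\sigma(h(s,a))\big)\big]
+ \lambda\,\mathbb{E}_{(s,a)\sim\text{expert}}\Big[\big(\lVert \nabla_{(s,a)} h(s,a) \rVert_2 - 1\big)^2\Big]
$$

The penalty term enters only on expert batches (the ``is_expert`` branch); non-expert batches contribute just the cross-entropy term.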

alf/examples/smodice_pendulum_conf.py

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,7 @@
     # add weight decay to the v_net following smodice paper
     value_optimizer=alf.optimizers.Adam(lr=lr, weight_decay=1e-4),
     discriminator_optimizer=alf.optimizers.Adam(lr=lr),
+    gradient_penalty_weight=0.1,
 )

 num_iterations = 1000000
@@ -91,6 +92,7 @@
     rl_train_after_update_steps=0,  # joint training
     mini_batch_size=256,
     mini_batch_length=2,
+    unroll_length=1,
     offline_buffer_dir=offline_buffer_dir,
     offline_buffer_length=offline_buffer_length,
     num_checkpoints=1,
