
Commit 82c59d1

update docstr
1 parent 944496e commit 82c59d1

File tree

1 file changed: +25 −108 lines changed


alf/algorithms/smodice_algorithm.py

Lines changed: 25 additions & 108 deletions
@@ -49,83 +49,6 @@
 SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())
 
 
-# -> algorithm
-class Discriminator_SA(Algorithm):
-    def __init__(self, observation_spec, action_spec):
-        super().__init__(observation_spec=observation_spec)
-
-        disc_net = CriticNetwork((observation_spec, action_spec))
-        self._disc_net = disc_net
-
-    def forward(self, inputs, state=()):
-        return self._disc_net(inputs, state)
-
-    def compute_grad_pen(self, expert_state, offline_state, lambda_=10):
-        alpha = torch.rand(expert_state.size(0), 1)
-        expert_data = expert_state
-        offline_data = offline_state
-
-        alpha = alpha.expand_as(expert_data).to(expert_data.device)
-
-        mixup_data = alpha * expert_data + (1 - alpha) * offline_data
-        mixup_data.requires_grad = True
-
-        disc = self(mixup_data)
-        ones = torch.ones(disc.size()).to(disc.device)
-        grad = autograd.grad(
-            outputs=disc,
-            inputs=mixup_data,
-            grad_outputs=ones,
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True)[0]
-
-        grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
-        return grad_pen
-
-    def update(self, expert_loader, offline_loader):
-        self.train()
-
-        loss = 0
-        n = 0
-        for expert_state, offline_state in zip(expert_loader, offline_loader):
-
-            expert_state = expert_state[0].to(self.device)
-            offline_state = offline_state[0][:expert_state.shape[0]].to(
-                self.device)
-
-            policy_d = self(offline_state)
-            expert_d = self(expert_state)
-
-            expert_loss = F.binary_cross_entropy_with_logits(
-                expert_d,
-                torch.ones(expert_d.size()).to(self.device))
-            policy_loss = F.binary_cross_entropy_with_logits(
-                policy_d,
-                torch.zeros(policy_d.size()).to(self.device))
-
-            gail_loss = expert_loss + policy_loss
-            grad_pen = self.compute_grad_pen(expert_state, offline_state)
-
-            loss += (gail_loss + grad_pen).item()
-            n += 1
-
-            self.optimizer.zero_grad()
-            (gail_loss + grad_pen).backward()
-            self.optimizer.step()
-        return loss / n
-
-    def predict_reward(self, state):
-        with torch.no_grad():
-            self.eval()
-            d = self(state)
-            s = torch.sigmoid(d)
-            # log(d^E/d^O)
-            # reward = - (1/s-1).log()
-            reward = s.log() - (1 - s).log()
-            return reward
-
-
 @alf.configurable
 class SmodiceAlgorithm(OffPolicyAlgorithm):
     r"""SMODICE algorithm.
@@ -143,27 +66,24 @@ class SmodiceAlgorithm(OffPolicyAlgorithm):
     ICML 2022.
     """
 
-    def __init__(
-            self,
-            observation_spec,
-            action_spec: BoundedTensorSpec,
-            reward_spec=TensorSpec(()),
-            actor_network_cls=ActorNetwork,
-            v_network_cls=ValueNetwork,
-            discriminator_network_cls=None,
-            actor_optimizer=None,
-            value_optimizer=None,
-            discriminator_optimizer=None,
-            #=====new params
-            gamma: float = 0.99,
-            v_l2_reg: float = 0.001,
-            env=None,
-            config: TrainerConfig = None,
-            checkpoint=None,
-            debug_summaries=False,
-            epsilon_greedy=None,
-            f="chi",
-            name="SmodiceAlgorithm"):
+    def __init__(self,
+                 observation_spec,
+                 action_spec: BoundedTensorSpec,
+                 reward_spec=TensorSpec(()),
+                 actor_network_cls=ActorNetwork,
+                 v_network_cls=ValueNetwork,
+                 discriminator_network_cls=None,
+                 actor_optimizer=None,
+                 value_optimizer=None,
+                 discriminator_optimizer=None,
+                 gamma: float = 0.99,
+                 f="chi",
+                 env=None,
+                 config: TrainerConfig = None,
+                 checkpoint=None,
+                 debug_summaries=False,
+                 epsilon_greedy=None,
+                 name="SmodiceAlgorithm"):
         """
         Args:
             observation_spec (nested TensorSpec): representing the observations.
@@ -178,7 +98,13 @@ def __init__(
             actor_network_cls (Callable): is used to construct the actor network.
                 The constructed actor network is a determinstic network and
                 will be used to generate continuous actions.
+            v_network_cls (Callable): is used to construct the value network.
+            discriminator_network_cls (Callable): is used to construct the discriminator.
             actor_optimizer (torch.optim.optimizer): The optimizer for actor.
+            value_optimizer (torch.optim.optimizer): The optimizer for value network.
+            discriminator_optimizer (torch.optim.optimizer): The optimizer for discriminator.
+            gamma (float): the discount factor.
+            f (str): the function form for f-divergence. Currently supports 'chi' and 'kl'.
             env (Environment): The environment to interact with. ``env`` is a
                 batched environment, which means that it runs multiple simulations
                 simultateously. ``env` only needs to be provided to the root
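
With the docstring above describing the new constructor arguments, a minimal, hypothetical construction might look like the sketch below. The spec shapes, learning rates, and the use of alf.optimizers.Adam are illustrative assumptions, not taken from this commit:

import alf
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.algorithms.smodice_algorithm import SmodiceAlgorithm

# Hypothetical specs for a small continuous-control task (shapes made up).
observation_spec = TensorSpec((11, ))
action_spec = BoundedTensorSpec((3, ), minimum=-1.0, maximum=1.0)

algorithm = SmodiceAlgorithm(
    observation_spec=observation_spec,
    action_spec=action_spec,
    actor_optimizer=alf.optimizers.Adam(lr=3e-4),
    value_optimizer=alf.optimizers.Adam(lr=3e-4),
    gamma=0.99,
    f="chi")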
@@ -242,12 +168,9 @@ def __init__(
         if discriminator_optimizer is not None and discriminator_net is not None:
             self.add_optimizer(discriminator_optimizer, [discriminator_net])
 
-        self._actor_optimizer = actor_optimizer
-        self._value_optimizer = value_optimizer
-        self._v_l2_reg = v_l2_reg
         self._gamma = gamma
         self._f = f
-        assert f == "chi", "only support chi form"
+        assert f in ["chi", "kl"], "only support chi or kl form"
 
         # f-divergence functions
         if self._f == 'chi':
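
The hunk ends just before the f-divergence definitions. For reference, here is a sketch of the standard 'chi' (chi-square) and 'kl' choices and their convex conjugates as used in the SMODICE paper; the exact expressions defined in this file are not shown in the diff and may differ:

import torch

# chi-square: f(x) = 0.5 * (x - 1)**2; conjugate f*(y) = 0.5 * y**2 + y,
# whose derivative becomes relu(y + 1) once the x >= 0 constraint is enforced.
chi_f = lambda x: 0.5 * (x - 1)**2
chi_f_star = lambda y: 0.5 * y**2 + y
chi_f_star_prime = lambda y: torch.relu(y + 1)

# KL: f(x) = x * log(x); conjugate f*(y) = exp(y - 1), which is also its own derivative.
kl_f = lambda x: x * torch.log(x + 1e-10)
kl_f_star = lambda y: torch.exp(y - 1)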
@@ -327,14 +250,8 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
         return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
 
     def value_train_step(self, inputs: TimeStep, state, rollout_info):
-        # initial_v_values, e_v, result={}
         observation = inputs.observation
-
-        # extract initial observation from batch, or prepare a batch
         initial_observation = observation
-
-        # Shared network values
-        # mini_batch_length
         initial_v_values, _ = self._value_network(initial_observation)
 
         # mini-batch len
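
For orientation around value_train_step: SMODICE's value objective in the paper is roughly (1 - gamma) * E[V(s0)] + E[f*(R(s) + gamma * V(s') - V(s))], where R is the discriminator-based log-ratio reward and f* is the conjugate selected by the f argument. A hedged sketch with hypothetical helper names, not the exact loss computed in this file:

import torch

def smodice_value_loss(v_fn, f_star, initial_obs, obs, next_obs, reward, gamma):
    # v_fn maps a batch of observations to a batch of scalar values.
    initial_term = (1 - gamma) * v_fn(initial_obs).mean()    # (1 - gamma) * E[V(s0)]
    residual = reward + gamma * v_fn(next_obs) - v_fn(obs)   # R(s) + gamma * V(s') - V(s)
    return initial_term + f_star(residual).mean()            # + E[f*(residual)]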
