
Commit 381ce9e

Merge branch 'dev'
2 parents: 387b04e + bf3a0ec

5 files changed: +25 -27 lines

cherry/_version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '0.0.9'
+__version__ = '0.1.0'
Binary file (-16 KB) not shown.

examples/pybullet/delayed_tsac_pybullet.py

Lines changed: 20 additions & 19 deletions

@@ -4,8 +4,6 @@
 An implementation of Soft Actor-Critic.
 """
 
-from OpenGL import GLU
-import ppt
 import copy
 import random
 import numpy as np
@@ -167,12 +165,6 @@ def update(replay,
     env.log("QF2 Loss: ", critic_qf2_loss.item())
     env.log("Average Rewards: ", batch.reward().mean().item())
 
-    # Plotting via PPT
-    '''
-    if random.random() < 0.05:
-        ppt.plot(replay[-1000:].reward().mean().item(), 'cherry true rewards - TSAC1 delayed')
-    '''
-
     # Update Critic Networks
     critic_qf1_optimizer.zero_grad()
     critic_qf1_loss.backward()
@@ -186,8 +178,8 @@ def update(replay,
     if STEP % DELAY == 0:
 
         # Policy loss
-        q_values = th.min( critic_qf1(batch.state(), actions),
-                           critic_qf2(batch.state(), actions) )
+        q_values = th.min(critic_qf1(batch.state(), actions),
+                          critic_qf2(batch.state(), actions))
         policy_loss = sac.policy_loss(log_probs, q_values, alpha)
 
         env.log("Policy Loss: ", policy_loss.item())
@@ -197,14 +189,13 @@ def update(replay,
         policy_optimizer.step()
 
         # Move target approximator parameters towards critic parameters per [3]
-        ch.models.polyak_average( source=target_qf1,
-                                  target=critic_qf1,
-                                  alpha=VF_TARGET_TAU )
-
-        ch.models.polyak_average( source=target_qf2,
-                                  target=critic_qf2,
-                                  alpha=VF_TARGET_TAU )
+        ch.models.polyak_average(source=target_qf1,
+                                 target=critic_qf1,
+                                 alpha=VF_TARGET_TAU)
 
+        ch.models.polyak_average(source=target_qf2,
+                                 target=critic_qf2,
+                                 alpha=VF_TARGET_TAU)
 
 
 if __name__ == '__main__':
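
Per the comment in the hunk above, the ch.models.polyak_average calls move the target networks toward the critics by a factor VF_TARGET_TAU. A minimal sketch of that soft (Polyak) target update in plain PyTorch, independent of cherry's exact signature; the function and variable values here are illustrative:

import torch as th
import torch.nn as nn

def soft_update(target_net, source_net, tau):
    # Move each target parameter a fraction tau toward the corresponding source parameter.
    with th.no_grad():
        for t_param, s_param in zip(target_net.parameters(), source_net.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)

# Illustrative usage with tiny stand-in networks.
critic_qf1 = nn.Linear(4, 1)
target_qf1 = nn.Linear(4, 1)
target_qf1.load_state_dict(critic_qf1.state_dict())  # start from identical weights
soft_update(target_qf1, critic_qf1, tau=0.01)        # tau plays the role of VF_TARGET_TAU
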
@@ -251,5 +242,15 @@ def update(replay,
         replay += ep_replay
         replay = replay[-REPLAY_SIZE:]
         if len(replay) > MIN_REPLAY:
-            update(replay, policy, critic_qf1, critic_qf2, target_qf1, target_qf2, log_alpha, policy_opt,
-                   qf1_opt, qf2_opt, alpha_opt, target_entropy)
+            update(replay,
+                   policy,
+                   critic_qf1,
+                   critic_qf2,
+                   target_qf1,
+                   target_qf2,
+                   log_alpha,
+                   policy_opt,
+                   qf1_opt,
+                   qf2_opt,
+                   alpha_opt,
+                   target_entropy)
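
The if STEP % DELAY == 0: guard seen in the hunks above is a delayed policy update in the TD3 style: the critics are trained every step, while the policy and target networks are refreshed only every DELAY steps. A schematic sketch of that control flow with placeholder update functions (not the ones defined in this example):

DELAY = 2  # illustrative value

def update_critics(batch):
    pass  # placeholder: one gradient step on both Q-networks

def update_policy_and_targets(batch):
    pass  # placeholder: policy gradient step followed by a soft target update

step = 0
for batch in [None] * 10:  # stand-in for minibatches sampled from the replay buffer
    step += 1
    update_critics(batch)
    if step % DELAY == 0:
        # The policy and targets deliberately lag behind the critics.
        update_policy_and_targets(batch)
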

examples/pybullet/ppo_pybullet.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@ def main(env='MinitaurTrottingEnv-v0'):
 if __name__ == '__main__':
     env_name = 'CartPoleBulletEnv-v0'
     env_name = 'AntBulletEnv-v0'
+    env_name = 'HalfCheetahBulletEnv-v0'
     # env_name = 'RoboschoolAnt-v1'
     # env_name = 'MinitaurTrottingEnv-v0'
     main(env_name)

examples/pybullet/sac_pybullet.py

Lines changed: 3 additions & 7 deletions

@@ -4,14 +4,12 @@
 An implementation of Soft Actor-Critic.
 """
 
-from OpenGL import GLU
-import ppt
+#from OpenGL import GLU
 import copy
 import random
 import numpy as np
 import gym
 import pybullet_envs
-import roboschool
 
 import torch as th
 import torch.nn as nn
@@ -155,8 +153,6 @@ def update(replay,
     env.log("VF Loss: ", vf_loss.item())
     env.log("Policy Loss: ", policy_loss.item())
     env.log("Average Rewards: ", batch.reward().mean().item())
-    if random.random() < 0.05:
-        ppt.plot(replay[-1000:].reward().mean().item(), 'cherry true rewards')
 
     # Update
     qf_opt.zero_grad()
@@ -181,9 +177,9 @@ def update(replay,
     np.random.seed(SEED)
     th.manual_seed(SEED)
     env_name = 'HalfCheetahBulletEnv-v0'
-    env_name = 'RoboschoolAnt-v1'
+    # env_name = 'AntBulletEnv-v0'
     env = gym.make(env_name)
-    env = envs.Logger(env, interval=1000)
+    env = envs.VisdomLogger(env, interval=1000)
     env = envs.ActionSpaceScaler(env)
     env = envs.Torch(env)
     env = envs.Runner(env)