forked from haarnoja/sac
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_traces.py
More file actions
70 lines (62 loc) · 2.79 KB
/
plot_traces.py
File metadata and controls
70 lines (62 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import numpy as np
import joblib
import tensorflow as tf
import os
from sac.misc import utils
from sac.policies.hierarchical_policy import FixedOptionPolicy
from sac.misc.sampler import rollouts
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool(token)``, so every
    non-empty string -- including ``"False"`` -- becomes ``True``.  This
    converter interprets the usual textual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"``.  The empty string is
            treated as ``False`` (which is also what ``bool("")`` returns).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised
            spelling of a boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def build_arg_parser():
    """Build the command-line parser for the trace-plotting script.

    Returns:
        An ``argparse.ArgumentParser`` exposing the snapshot path, rollout
        length, number of paths per skill, the two dimensions to plot, the
        ``--use_qpos`` / ``--use_action`` switches and the
        ``--deterministic`` / ``--no-deterministic`` pair (default: True).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='Path to the snapshot file.')
    parser.add_argument('--max-path-length', '-l', type=int, default=100)
    parser.add_argument('--n_paths', type=int, default=1)
    parser.add_argument('--dim_0', type=int, default=0)
    parser.add_argument('--dim_1', type=int, default=1)
    # nargs='?' + const=True keeps "--use_qpos True" working while also
    # allowing a bare "--use_qpos"; str2bool makes "--use_qpos False" mean
    # False (with type=bool it was silently parsed as True).
    parser.add_argument('--use_qpos', type=str2bool, nargs='?', const=True,
                        default=False)
    parser.add_argument('--use_action', type=str2bool, nargs='?', const=True,
                        default=False)
    parser.add_argument('--deterministic', '-d', dest='deterministic',
                        action='store_true')
    parser.add_argument('--no-deterministic', '-nd', dest='deterministic',
                        action='store_false')
    parser.set_defaults(deterministic=True)
    return parser


def collect_trace(env, policy, max_path_length, use_qpos=False,
                  use_action=False):
    """Roll out ``policy`` in ``env`` and record one vector per step.

    Args:
        env: Environment providing ``reset()`` and ``step(action)``.  When
            ``use_qpos`` is set it must also expose
            ``wrapped_env.env.model.data.qpos``.
        policy: Object whose ``get_action(obs)`` returns a pair; only the
            first element (the action) is used.
        max_path_length: Number of environment steps to take.
        use_qpos: Record ``qpos[:, 0]`` instead of the observation.  Takes
            precedence over ``use_action``.
        use_action: Record the action taken at each step.

    Returns:
        A NumPy array built from the recorded vectors.  In observation and
        qpos mode it has ``max_path_length + 1`` rows (the state after
        ``reset`` plus one per step); in action mode it has
        ``max_path_length`` rows, because no action exists before the first
        step.
    """
    obs = env.reset()
    if use_qpos:
        trace = [env.wrapped_env.env.model.data.qpos[:, 0]]
    elif use_action:
        # Do not seed an action trace with the initial observation: it is a
        # different kind of vector, and when its size differs from the
        # action's the stacked array becomes ragged and cannot be indexed
        # by column.
        trace = []
    else:
        trace = [obs]

    for _step in range(max_path_length):
        action, _unused = policy.get_action(obs)
        # The other three values returned by env.step are discarded, so a
        # rollout always runs for the full max_path_length steps.
        obs, _, _, _ = env.step(action)
        if use_qpos:
            trace.append(env.wrapped_env.env.model.data.qpos[:, 0])
        elif use_action:
            trace.append(action)
        else:
            trace.append(obs)
    return np.array(trace)


def main():
    """Plot 2-D traces for every skill of a saved policy and write a PNG.

    The image is written next to the snapshot as
    ``<snapshot-without-extension>_<dim_0>_<dim_1>_trace.png``; every skill
    index gets its own colour from an ``hls`` palette.
    """
    args = build_arg_parser().parse_args()
    filename = '{}_{}_{}_trace.png'.format(os.path.splitext(args.file)[0],
                                           args.dim_0, args.dim_1)

    with tf.Session():
        # NOTE(review): joblib.load unpickles arbitrary objects and can
        # execute code while doing so -- only open trusted snapshot files.
        data = joblib.load(args.file)
        policy = data['policy']
        env = data['env']
        # The policy's observation space is larger than the environment's;
        # the difference is used as the number of skills (presumably the
        # size of a skill encoding appended to the observation -- TODO
        # confirm against the policy implementation).
        num_skills = (policy.observation_space.flat_dim
                      - env.spec.observation_space.flat_dim)

        plt.figure(figsize=(6, 6))
        palette = sns.color_palette('hls', num_skills)

        with policy.deterministic(args.deterministic):
            for z in range(num_skills):
                fixed_z_policy = FixedOptionPolicy(policy, num_skills, z)
                for _path_index in range(args.n_paths):
                    trace = collect_trace(env, fixed_z_policy,
                                          args.max_path_length,
                                          use_qpos=args.use_qpos,
                                          use_action=args.use_action)
                    if trace.size == 0:
                        # Nothing was recorded (e.g. zero steps in action
                        # mode); a 2-D column lookup would fail.
                        continue
                    plt.plot(trace[:, args.dim_0], trace[:, args.dim_1],
                             c=palette[z])

        plt.savefig(filename)
        plt.close()


if __name__ == "__main__":
    main()