Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ __pycache__

*.pth
*live.ipynb
*.log

/output
/output/**


2 changes: 2 additions & 0 deletions configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@
update_post_train=1, # how often to resample the context when collecting data during training (in trajectories)
num_exp_traj_eval=1, # how many exploration trajs to collect before beginning posterior sampling at test time
recurrent=False, # recurrent or permutation-invariant encoder
use_traj_context=False, # use traj or tran as the context
dump_eval_paths=False, # whether to save evaluation trajectories
),
util_params=dict(
base_log_dir='output',
use_gpu=True,
gpu_id=0,
seed=0,
debug=False, # debugging triggers printing and writes logs to debug directory
docker=False, # TODO docker is not yet supported
)
Expand Down
20 changes: 16 additions & 4 deletions launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rlkit.torch.networks import FlattenMlp, MlpEncoder, RecurrentEncoder
from rlkit.torch.sac.sac import PEARLSoftActorCritic
from rlkit.torch.sac.agent import PEARLAgent
from rlkit.launchers.launcher_util import setup_logger
from rlkit.launchers.launcher_util import setup_logger, set_seed
import rlkit.torch.pytorch_util as ptu
from configs.default import default_config

Expand Down Expand Up @@ -123,17 +123,29 @@ def deep_update_dict(fr, to):

@click.command()
@click.argument('config', default=None)
@click.option('--gpu_id', default=0)
@click.option('--docker', is_flag=True, default=False)
@click.option('--debug', is_flag=True, default=False)
@click.option('--seed', default=0)
@click.option('--rnn/--mlp', default=False, help='choose the encoder network, RNN or MLP')
@click.option('--traj/--tran', default=False, help='use traj or tran context')
@click.option('--srb', is_flag=True, help="save_replay_buffer")
def main(config, gpu_id, seed, srb, rnn, traj, docker, debug):
    """Entry point: seed all RNGs, build the experiment variant from the
    default config overlaid with the JSON config file and CLI flags, then
    launch the experiment.
    """
    set_seed(seed)  # seeds python/numpy/torch so runs are reproducible

    variant = default_config
    if config:
        # overlay the experiment-specific JSON parameters onto the defaults
        with open(os.path.join(config)) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    variant['util_params']['gpu_id'] = gpu_id
    variant['util_params']['seed'] = seed
    variant['algo_params']['save_replay_buffer'] = srb
    variant['algo_params']['recurrent'] = rnn
    variant['algo_params']['use_traj_context'] = traj
    # with rnn/mlp and traj/tran, the four encoder configurations are:
    # rnn-tran, rnn-traj, mlp-tran (default), mlp-traj

    experiment(variant)

Expand Down
9 changes: 6 additions & 3 deletions rlkit/core/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def collect_data(self, num_samples, resample_z_rate, update_posterior_rate, add_
gt.stamp('sample')

def _try_to_eval(self, epoch):
logger.save_extra_data(self.get_extra_data_to_save(epoch))
if (epoch+1) % 50 == 0: ## save RB every 50 epochs
logger.save_extra_data(self.get_extra_data_to_save(epoch))
if self._can_evaluate():
self.evaluate(epoch)

Expand Down Expand Up @@ -417,7 +418,9 @@ def evaluate(self, epoch):
for idx in indices:
self.task_idx = idx
self.env.reset_task(idx)
self.agent.clear_z() ## I add this which make rnn encoder works!
paths = []

for _ in range(self.num_steps_per_eval // self.max_path_length):
context = self.sample_context(idx)
self.agent.infer_posterior(context)
Expand Down Expand Up @@ -458,8 +461,8 @@ def evaluate(self, epoch):
self.eval_statistics['AverageTrainReturn_all_train_tasks'] = train_returns
self.eval_statistics['AverageReturn_all_train_tasks'] = avg_train_return
self.eval_statistics['AverageReturn_all_test_tasks'] = avg_test_return
logger.save_extra_data(avg_train_online_return, path='online-train-epoch{}'.format(epoch))
logger.save_extra_data(avg_test_online_return, path='online-test-epoch{}'.format(epoch))
# logger.save_extra_data(avg_train_online_return, path='online-train-epoch{}'.format(epoch))
# logger.save_extra_data(avg_test_online_return, path='online-test-epoch{}'.format(epoch))

for key, value in self.eval_statistics.items():
logger.record_tabular(key, value)
Expand Down
12 changes: 6 additions & 6 deletions rlkit/data_management/simple_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ def __init__(
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
self._observations = np.zeros((max_replay_buffer_size, observation_dim), dtype=np.float32)
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._actions = np.zeros((max_replay_buffer_size, action_dim))
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim), dtype=np.float32)
self._actions = np.zeros((max_replay_buffer_size, action_dim), dtype=np.float32)
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
self._rewards = np.zeros((max_replay_buffer_size, 1))
self._sparse_rewards = np.zeros((max_replay_buffer_size, 1))
self._rewards = np.zeros((max_replay_buffer_size, 1), dtype=np.float32)
self._sparse_rewards = np.zeros((max_replay_buffer_size, 1), dtype=np.float32)
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
self.clear()
Expand Down Expand Up @@ -76,7 +76,7 @@ def random_sequence(self, batch_size):
indices = []
while len(indices) < batch_size:
# TODO hack to not deal with wrapping episodes, just don't take the last one
start = np.random.choice(self.episode_starts[:-1])
start = np.random.choice(self._episode_starts[:-1]) ## previous one is self.episode_starts which should be a typo
pos_idx = self._episode_starts.index(start)
indices += list(range(start, self._episode_starts[pos_idx + 1]))
i += 1
Expand Down
41 changes: 37 additions & 4 deletions rlkit/launchers/launcher_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import datetime
import dateutil.tz
import numpy as np
import torch

from rlkit.core import logger
from rlkit.launchers import config
Expand Down Expand Up @@ -197,8 +198,37 @@ def create_simple_exp_name():
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
return timestamp

def create_method_exp_name(variant):
    """
    Create a semi-unique experiment name for the log directory.

    The name is a timestamp tagged with the encoder structure
    (rnn or mlp), the context type (traj or tran), and the seed,
    e.g. '2024_01_01_12_00_00-rnn-tran-sd0'.

    :param variant: experiment config dict; reads
        variant['util_params']['seed'],
        variant['algo_params']['recurrent'], and
        variant['algo_params']['use_traj_context'].
    :return: the experiment-name string.
    """
    # .astimezone() gives the local time, matching the
    # dateutil.tz.tzlocal() used elsewhere, without the extra dependency
    now = datetime.datetime.now().astimezone()
    exp_name = now.strftime('%Y_%m_%d_%H_%M_%S')

    seed = variant['util_params']['seed']
    recurrent = variant['algo_params']['recurrent']
    traj = variant['algo_params']['use_traj_context']

    exp_name += '-rnn' if recurrent else '-mlp'
    exp_name += '-traj' if traj else '-tran'
    exp_name += f'-sd{seed}'

    return exp_name


def create_log_dir(exp_prefix, exp_id=None, variant=None, base_log_dir=None):
"""
Creates and returns a unique log directory.

Expand All @@ -210,7 +240,8 @@ def create_log_dir(exp_prefix, exp_id=None, seed=0, base_log_dir=None):
base_log_dir = config.LOCAL_LOG_DIR
exp_name = exp_id
if exp_name is None:
exp_name = create_simple_exp_name()
# exp_name = create_simple_exp_name()
exp_name = create_method_exp_name(variant)
log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name)
os.makedirs(log_dir, exist_ok=True)
return log_dir
Expand All @@ -219,7 +250,7 @@ def create_log_dir(exp_prefix, exp_id=None, seed=0, base_log_dir=None):
def setup_logger(
exp_prefix="default",
exp_id=0,
seed=0,
# seed=0,
variant=None,
base_log_dir=None,
text_log_file="debug.log",
Expand Down Expand Up @@ -261,7 +292,7 @@ def setup_logger(
"""
first_time = log_dir is None
if first_time:
log_dir = create_log_dir(exp_prefix, exp_id=exp_id, seed=seed,
log_dir = create_log_dir(exp_prefix, exp_id=exp_id, variant=variant,
base_log_dir=base_log_dir)

if variant is not None:
Expand Down Expand Up @@ -343,6 +374,8 @@ def set_seed(seed):
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def reset_execution_environment():
Expand Down
4 changes: 3 additions & 1 deletion rlkit/torch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
policy_pre_activation_weight=0.,
optimizer_class=optim.Adam,
recurrent=False,
use_traj_context=False,
use_information_bottleneck=True,
use_next_obs_in_context=False,
sparse_rewards=False,
Expand All @@ -54,6 +55,7 @@ def __init__(
self.render_eval_paths = render_eval_paths

self.recurrent = recurrent
self.use_traj_context = use_traj_context
self.latent_dim = latent_dim
self.qf_criterion = nn.MSELoss()
self.vf_criterion = nn.MSELoss()
Expand Down Expand Up @@ -133,7 +135,7 @@ def sample_context(self, indices):
# make method work given a single task index
if not hasattr(indices, '__iter__'):
indices = [indices]
batches = [ptu.np_to_pytorch_batch(self.enc_replay_buffer.random_batch(idx, batch_size=self.embedding_batch_size, sequence=self.recurrent)) for idx in indices]
batches = [ptu.np_to_pytorch_batch(self.enc_replay_buffer.random_batch(idx, batch_size=self.embedding_batch_size, sequence=self.use_traj_context)) for idx in indices]
context = [self.unpack_batch(batch, sparse_reward=self.sparse_rewards) for batch in batches]
# group like elements together
context = [[x[i] for x in context] for i in range(len(context[0]))]
Expand Down
14 changes: 14 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Background launchers for PEARL cheetah-dir experiments.
# Usage: sh run.sh  (each job detaches via nohup; stdout/stderr go to the named .log file)

# rnn encoder + transition-level context; seed 0 line kept commented for reference
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --tran --gpu_id=0 --seed=0 > pearl-cheetah-dir-rnn-tran-sd0.log 2>&1 &
nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --tran --gpu_id=0 --seed=1 > pearl-cheetah-dir-rnn-tran-sd1.log 2>&1 &
nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --tran --gpu_id=0 --seed=10 > pearl-cheetah-dir-rnn-tran-sd10.log 2>&1 &


# rnn encoder + trajectory-level context
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --traj --gpu_id=1 > pearl-cheetah-dir-rnn-traj.log 2>&1 &
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --traj --gpu_id=1 --seed=1 > pearl-cheetah-dir-rnn-traj-sd1.log 2>&1 &
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --rnn --traj --gpu_id=1 --seed=10 > pearl-cheetah-dir-rnn-traj-sd10.log 2>&1 &

# mlp encoder + transition-level context (the original PEARL default)
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --mlp --tran --gpu_id=0 > pearl-cheetah-dir-mlp-tran.log 2>&1 &

# mlp encoder + trajectory-level context
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --mlp --traj --gpu_id=0 > pearl-cheetah-dir-mlp-traj.log 2>&1 &
# nohup python -u launch_experiment.py ./configs/cheetah-dir.json --mlp --traj --gpu_id=0 > pearl-cheetah-dir-mlp-traj.log 2>&1 &