diff --git a/code/run_policy.py b/code/run_policy.py index 55bc231d..539f08b3 100644 --- a/code/run_policy.py +++ b/code/run_policy.py @@ -1,6 +1,6 @@ """ -Code to load a policy and generate rollout data. Adapted from https://github.com/berkeleydeeprlcourse. +Code to load a policy and generate rollout data. Adapted from https://github.com/berkeleydeeprlcourse. Example usage: python run_policy.py ../trained_policies/Humanoid-v1/policy_reward_11600/lin_policy_plus.npz Humanoid-v1 --render \ --num_rollouts 20 @@ -8,6 +8,7 @@ import numpy as np import gym + def main(): import argparse parser = argparse.ArgumentParser() @@ -20,13 +21,13 @@ def main(): print('loading and building expert policy') lin_policy = np.load(args.expert_policy_file) - lin_policy = lin_policy.items()[0][1] - + lin_policy = lin_policy[lin_policy.files[0]] + M = lin_policy[0] - # mean and std of state vectors estimated online by ARS. + # mean and std of state vectors estimated online by ARS. mean = lin_policy[1] std = lin_policy[2] - + env = gym.make(args.envname) returns = [] @@ -42,8 +43,8 @@ def main(): action = np.dot(M, (obs - mean)/std) observations.append(obs) actions.append(action) - - + + obs, r, done, _ = env.step(action) totalr += r steps += 1 @@ -57,6 +58,6 @@ def main(): print('returns', returns) print('mean return', np.mean(returns)) print('std of return', np.std(returns)) - + if __name__ == '__main__': main()