diff --git a/chainerrl_visualizer/launcher.py b/chainerrl_visualizer/launcher.py
index cec6b9d..b0ef8c8 100644
--- a/chainerrl_visualizer/launcher.py
+++ b/chainerrl_visualizer/launcher.py
@@ -34,6 +34,7 @@ def launch_visualizer(agent, gymlike_env, action_meanings, log_dir='log_space',
     if isinstance(gymlike_env, gym.Env):
         modify_gym_env_render(gymlike_env)
 
+    compensate_agent_lacked_method(agent)
     profile = inspect_agent(agent, gymlike_env, contains_rnn)
 
     job_queue = Queue()
@@ -99,8 +100,67 @@ def prepare_log_directory(log_dir):  # log_dir is assumed to be full path
     return True
 
 
+def compensate_agent_lacked_method(agent):
+    if not hasattr(agent, 'batch_states'):
+        agent.batch_states = chainerrl.misc.batch_states
+
+
+def validate_agent_profile(profile):
+    if profile['distribution_type'] is None and profile['action_value_type'] is None:
+        raise Exception('Outputs of model do not contain ActionValue nor DistributionType')
+
+    if profile['action_value_type'] is not None \
+            and profile['action_value_type'] not in SUPPORTED_ACTION_VALUES:
+        raise Exception('ActionValue type {} is not supported for now'.format(
+            profile['action_value_type']))
+
+    if profile['distribution_type'] is not None \
+            and profile['distribution_type'] not in SUPPORTED_DISTRIBUTIONS:
+        raise Exception('Distribution type {} is not supported for now'.format(
+            profile['distribution_type']))
+
+
+# workaround for agents whose `model` cannot be called for its outputs (see inspect_agent)
+def inspect_exceptional_agent(agent, gymlike_env, contains_rnn):
+    profile = {
+        'contains_recurrent_model': contains_rnn,
+        'state_value_returned': True,
+        'distribution_type': None,
+        'action_value_type': None,
+    }
+
+    obs = gymlike_env.reset()
+    policy = agent.policy
+
+    # workaround: not every agent exposes the `xp` attribute
+    if hasattr(agent, 'xp'):
+        xp = agent.xp
+    else:
+        xp = np
+
+    if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
+        with policy.state_kept():
+            dist = policy(agent.batch_states([obs], xp, agent.phi))
+    else:
+        dist = policy(agent.batch_states([obs], xp, agent.phi))
+
+    profile['distribution_type'] = type(dist).__name__
+
+    validate_agent_profile(profile)
+
+    return profile
+
+
 # Create and return dict contains agent profile
 def inspect_agent(agent, gymlike_env, contains_rnn):
+    # workaround
+    # These three agents are exceptional: every other agent has a `model`
+    # attribute whose `__call__()` returns the outputs of the model.
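+    # TRPO exposes `policy` and `vf`, and DDPG/PGT expose `policy` and
+    # `q_function`, so these agents are profiled by `inspect_exceptional_agent`.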
+    if type(agent).__name__ in ['TRPO', 'DDPG', 'PGT']:
+        return inspect_exceptional_agent(agent, gymlike_env, contains_rnn)
+
     profile = {
         'contains_recurrent_model': contains_rnn,
         'state_value_returned': False,
@@ -142,18 +202,6 @@ def inspect_agent(agent, gymlike_env, contains_rnn):
         raise Exception(
             'Model output type of {} is not supported for now'.format(type(output).__name__))
 
-    # Validations
-    if profile['distribution_type'] is None and profile['action_value_type'] is None:
-        raise Exception('Outputs of model do not contain ActionValue nor DistributionType')
-
-    if profile['action_value_type'] is not None \
-            and profile['action_value_type'] not in SUPPORTED_ACTION_VALUES:
-        raise Exception('ActionValue type {} is not supported for now'.format(
-            profile['action_value_type']))
-
-    if profile['distribution_type'] is not None \
-            and profile['distribution_type'] not in SUPPORTED_DISTRIBUTIONS:
-        raise Exception('Distribution type {} is not supported for now'.format(
-            profile['distribution_type']))
+    validate_agent_profile(profile)
 
     return profile
diff --git a/chainerrl_visualizer/worker_jobs/rollout_job.py b/chainerrl_visualizer/worker_jobs/rollout_job.py
index a7131f4..65d034a 100644
--- a/chainerrl_visualizer/worker_jobs/rollout_job.py
+++ b/chainerrl_visualizer/worker_jobs/rollout_job.py
@@ -27,10 +27,8 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
     render_img_list[:] = []  # Clear the shared render images list
 
     # workaround
-    if hasattr(agent, 'xp'):
-        xp = agent.xp
-    else:
-        xp = np
+    if not hasattr(agent, 'xp'):
+        agent.xp = np
 
     log_fp = open(os.path.join(rollout_dir, ROLLOUT_LOG_FILE_NAME), 'a')
     writer = jsonlines.Writer(log_fp)
@@ -50,17 +48,15 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
         obs_list.append(obs)
         render_img_list.append(rendered)
 
-        if isinstance(agent, chainerrl.recurrent.RecurrentChainMixin):
-            with agent.model.state_kept():
-                outputs = agent.model(agent.batch_states([obs], xp, agent.phi))
+        # workaround
+        # These three agents are exceptional: every other agent has a `model`
+        # attribute whose `__call__()` returns the outputs of the model.
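+        # `_step_exceptional_agent` reads `agent.policy` and `agent.q_function`
+        # (DDPG, PGT) or `agent.vf` (TRPO) instead of `agent.model`.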
+        if type(agent).__name__ in ['TRPO', 'DDPG', 'PGT']:
+            obs, r, done, action, outputs = _step_exceptional_agent(agent, gymlike_env, obs)
         else:
-            outputs = agent.model(agent.batch_states([obs], xp, agent.phi))
-
-        if not isinstance(outputs, tuple):
-            outputs = tuple((outputs,))
-
-        action = agent.act(obs)
-        obs, r, done, info = gymlike_env.step(action)
+            obs, r, done, action, outputs = _step_agent(agent, gymlike_env, obs)
 
         log_entries = dict()
         log_entries['step'] = t
@@ -133,6 +129,79 @@ def rollout(agent, gymlike_env, rollout_dir, step_count, obs_list, render_img_li
         raise Exception(error_msg)
 
 
+def _step_agent(agent, gymlike_env, obs):
+    if isinstance(agent, chainerrl.recurrent.RecurrentChainMixin):
+        with agent.model.state_kept():
+            outputs = agent.model(agent.batch_states([obs], agent.xp, agent.phi))
+    else:
+        outputs = agent.model(agent.batch_states([obs], agent.xp, agent.phi))
+
+    if not isinstance(outputs, tuple):
+        outputs = (outputs,)
+
+    action = agent.act(obs)
+    obs, r, done, _ = gymlike_env.step(action)
+
+    return obs, r, done, action, outputs
+
+
+def _step_exceptional_agent(agent, gymlike_env, obs):
+    policy = agent.policy
+    agent_type = type(agent).__name__
+    b_state = agent.batch_states([obs], agent.xp, agent.phi)
+
+    if agent_type in ['DDPG', 'PGT']:
+        if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
+            with policy.state_kept():
+                action_dist = policy(b_state)
+        else:
+            action_dist = policy(b_state)
+
+        # workaround
+        # Calling `agent.act()` while `agent.q_function` has an LSTM would
+        # change the internal state of the model, so `action` is taken directly
+        # from `action_dist`; it is needed as an argument to `q_function()`.
+        if agent_type == 'DDPG':
+            action = action_dist.sample()
+        else:  # PGT
+            if agent.act_deterministically:
+                action = action_dist.most_probable
+            else:
+                action = action_dist.sample()
+
+        q_function = agent.q_function
+        if isinstance(q_function, chainerrl.recurrent.RecurrentChainMixin):
+            with q_function.state_kept():
+                q_value = q_function(b_state, action)
+        else:
+            q_value = q_function(b_state, action)
+
+        outputs = (action_dist, q_value)
+
+    elif agent_type == 'TRPO':
+        if isinstance(policy, chainerrl.recurrent.RecurrentChainMixin):
+            with policy.state_kept():
+                action_dist = policy(b_state)
+        else:
+            action_dist = policy(b_state)
+
+        value_function = agent.vf
+        if isinstance(value_function, chainerrl.recurrent.RecurrentChainMixin):
+            with value_function.state_kept():
+                state_value = value_function(b_state)
+        else:
+            state_value = value_function(b_state)
+
+        outputs = (action_dist, state_value)
+    else:
+        raise Exception('{} is not one of the exceptional agent types'.format(agent_type))
+
+    action = agent.act(obs)
+    obs, r, done, _ = gymlike_env.step(action)
+
+    return obs, r, done, action, outputs
+
+
 def _save_env_render(rendered, rollout_dir):
     image = Image.fromarray(rendered)
     image_path = os.path.join(rollout_dir, 'images', generate_random_string(11) + '.png')
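
Note: below is a minimal sketch (not part of the patch) of the behavior
`compensate_agent_lacked_method()` adds; `StubAgent` and its `phi` are
hypothetical stand-ins for an agent constructed without `batch_states`:

    import numpy as np
    import chainerrl

    class StubAgent(object):
        """Hypothetical agent that predates the `batch_states` attribute."""

    agent = StubAgent()
    agent.phi = lambda obs: np.asarray(obs, dtype=np.float32)

    # The same fallback launch_visualizer() now installs before profiling
    if not hasattr(agent, 'batch_states'):
        agent.batch_states = chainerrl.misc.batch_states

    batch = agent.batch_states([np.zeros(4)], np, agent.phi)
    print(batch.shape)  # -> (1, 4)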