Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 55 additions & 55 deletions baselines/ray_exp/red_gym_env_ray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import sys
import uuid
import uuid
import os
from math import floor, sqrt
import json
Expand Down Expand Up @@ -116,14 +116,14 @@ def reset(self, *, seed=None, options=None):
self.init_knn()

self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8)

self.recent_frames = np.zeros(
(self.frame_stacks, self.output_shape[0],
(self.frame_stacks, self.output_shape[0],
self.output_shape[1], self.output_shape[2]),
dtype=np.uint8)

self.agent_stats = []

if self.save_video:
base_dir = self.s_path / Path('rollouts')
base_dir.mkdir(exist_ok=True)
Expand All @@ -133,7 +133,7 @@ def reset(self, *, seed=None, options=None):
self.full_frame_writer.__enter__()
self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60)
self.model_frame_writer.__enter__()

self.levels_satisfied = False
self.base_explore = 0
self.max_opponent_level = 0
Expand All @@ -147,7 +147,7 @@ def reset(self, *, seed=None, options=None):
self.total_reward = sum([val for _, val in self.progress_reward.items()])
self.reset_count += 1
return self.render(add_memory=False), {}

def init_knn(self):
# Declaring index
self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip
Expand All @@ -163,19 +163,19 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True):
self.recent_frames[0] = game_pixels_render
if add_memory:
pad = np.zeros(
shape=(self.mem_padding, self.output_shape[1], 3),
shape=(self.mem_padding, self.output_shape[1], 3),
dtype=np.uint8)
game_pixels_render = np.concatenate(
(
self.create_exploration_memory(),
self.create_exploration_memory(),
pad,
self.create_recent_memory(),
pad,
rearrange(self.recent_frames, 'f h w c -> (f h) w c')
),
axis=0)
return game_pixels_render

def step(self, action):

self.run_action_on_emulator(action)
Expand All @@ -191,11 +191,11 @@ def step(self, action):
obs_flat = obs_memory.flatten().astype(np.float32)

self.update_frame_knn_index(obs_flat)

self.update_heal_reward()

new_reward, new_prog = self.update_reward()

self.last_health = self.read_hp_fraction()

# shift over short term reward memory
Expand All @@ -222,7 +222,7 @@ def run_action_on_emulator(self, action):
# release arrow
self.pyboy.send_input(self.release_arrow[action])
if action > 3 and action < 6:
# release button
# release button
self.pyboy.send_input(self.release_button[action - 4])
if action == WindowEvent.PRESS_BUTTON_START:
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
Expand All @@ -231,18 +231,18 @@ def run_action_on_emulator(self, action):
self.pyboy.tick()
if self.save_video and self.fast_video:
self.add_video_frame()

def add_video_frame(self):
    # Append the current emulator frame to both rollout recordings:
    # the full-resolution capture first, then the downsampled model view.
    full_frame = self.render(reduce_res=False, update_mem=False)
    self.full_frame_writer.add_image(full_frame)
    model_frame = self.render(reduce_res=True, update_mem=False)
    self.model_frame_writer.add_image(model_frame)

def append_agent_stats(self):
x_pos = self.read_m(0xD362)
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
self.agent_stats.append({
'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
'pcount': self.read_m(0xD163), 'levels': levels, 'ptypes': self.read_party(),
'hp': self.read_hp_fraction(),
'frames': self.knn_index.get_current_count(),
Expand All @@ -251,7 +251,7 @@ def append_agent_stats(self):
})

def update_frame_knn_index(self, frame_vec):

if self.get_levels_sum() >= 22 and not self.levels_satisfied:
self.levels_satisfied = True
self.base_explore = self.knn_index.get_current_count()
Expand All @@ -263,7 +263,7 @@ def update_frame_knn_index(self, frame_vec):
frame_vec, np.array([self.knn_index.get_current_count()])
)
else:
# check for nearest frame and add if current
# check for nearest frame and add if current
labels, distances = self.knn_index.knn_query(frame_vec, k = 1)
if distances[0] > self.similar_frame_dist:
self.knn_index.add_items(
Expand All @@ -280,25 +280,25 @@ def update_reward(self):
if new_step < 0 and self.read_hp_fraction() > 0:
#print(f'\n\nreward went down! {self.progress_reward}\n\n')
self.save_screenshot('neg_reward')

self.total_reward = new_total
return (new_step,
(new_prog[0]-old_prog[0],
new_prog[1]-old_prog[1],
return (new_step,
(new_prog[0]-old_prog[0],
new_prog[1]-old_prog[1],
new_prog[2]-old_prog[2])
)

def group_rewards(self):
    """Return the (level, hp, explore) scalar triple painted into the
    exploration-memory channels (these values are only used by memory)."""
    rewards = self.progress_reward
    level_component = rewards['level'] * 100
    hp_component = self.read_hp_fraction() * 2000
    explore_component = rewards['explore'] * 160
    return (level_component, hp_component, explore_component)

def create_exploration_memory(self):
w = self.output_shape[1]
h = self.memory_height

def make_reward_channel(r_val):
col_steps = self.col_steps
row = floor(r_val / (h * col_steps))
Expand All @@ -308,26 +308,26 @@ def make_reward_channel(r_val):
col = floor((r_val - row_covered) / col_steps)
memory[:col, row] = 255
col_covered = col * col_steps
last_pixel = floor(r_val - row_covered - col_covered)
last_pixel = floor(r_val - row_covered - col_covered)
memory[col, row] = last_pixel * (255 // col_steps)
return memory

level, hp, explore = self.group_rewards()
full_memory = np.stack((
make_reward_channel(level),
make_reward_channel(hp),
make_reward_channel(explore)
), axis=-1)

if self.get_badges() > 0:
full_memory[:, -1, :] = 255

return full_memory

def create_recent_memory(self):
    # Reshape the flat short-term reward memory into an image-shaped
    # (height, width, channel) array so it can be stacked into the
    # rendered observation.
    flat_memory = self.recent_memory
    return rearrange(flat_memory, '(w h) c -> h w c', h=self.memory_height)

def check_if_done(self):
Expand All @@ -347,10 +347,10 @@ def save_and_print_info(self, done, obs_memory):
prog_string += f' {key}: {val:5.2f}'
prog_string += f' sum: {self.total_reward:5.2f}'
print(f'\r{prog_string}', end='', flush=True)

if self.step_count % 50 == 0:
plt.imsave(
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.render(reduce_res=False))

if self.print_rewards and done:
Expand All @@ -359,10 +359,10 @@ def save_and_print_info(self, done, obs_memory):
fs_path = self.s_path / Path('final_states')
fs_path.mkdir(exist_ok=True)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
obs_memory)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
self.render(reduce_res=False))

if self.save_video and done:
Expand All @@ -375,18 +375,18 @@ def save_and_print_info(self, done, obs_memory):
json.dump(self.all_runs, f)
pd.DataFrame(self.agent_stats).to_csv(
self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a')

def read_m(self, addr):
    """Read one byte of emulator memory at *addr*."""
    emulator = self.pyboy
    return emulator.get_memory_value(addr)

def read_bit(self, addr, bit: int) -> bool:
    """Return True if *bit* (0 = least significant) of the byte at *addr* is set.

    Replaces the previous string-slicing trick on ``bin(256 + value)``,
    which relied on a padding digit and was only meaningful for bits 0-7;
    a direct shift-and-mask is clearer and correct for any bit index.
    """
    return bool((self.read_m(addr) >> bit) & 1)

def get_levels_sum(self):
    """Sum of party levels, offset so the starting party scores 0."""
    level_addrs = (0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268)
    total = 0
    for addr in level_addrs:
        total += max(self.read_m(addr) - 2, 0)
    # subtract starting pokemon level
    return max(total - 4, 0)

def get_levels_reward(self):
explore_thresh = 22
scale_factor = 4
Expand All @@ -397,21 +397,21 @@ def get_levels_reward(self):
scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh
self.max_level_rew = max(self.max_level_rew, scaled)
return self.max_level_rew

def get_knn_reward(self):
    """Exploration reward derived from the number of novel frames stored
    in the knn index.

    Before the level threshold is hit, every stored frame earns the low
    rate; once ``levels_satisfied`` is set, the pre-threshold count is
    frozen at ``base_explore`` (low rate) and the live count earns the
    higher rate.
    """
    low_rate = 0.004
    high_rate = 0.01
    frame_count = self.knn_index.get_current_count()
    if self.levels_satisfied:
        return self.base_explore * low_rate + frame_count * high_rate
    return frame_count * low_rate

def get_badges(self):
    """Popcount of the byte at 0xD356 — per the method name, the set of
    earned gym badges."""
    badge_byte = self.read_m(0xD356)
    return self.bit_count(badge_byte)

def read_party(self):
    """Bytes of the six party slots at 0xD164-0xD169 (presumably species
    ids — confirm against the RAM map)."""
    first_slot = 0xD164
    return [self.read_m(first_slot + offset) for offset in range(6)]

def update_heal_reward(self):
cur_health = self.read_hp_fraction()
if cur_health > self.last_health:
Expand All @@ -423,10 +423,10 @@ def update_heal_reward(self):
self.total_healing_rew += heal_amount * 4
else:
self.died_count += 1

def get_all_events_reward(self):
    """Number of set bits across the event-flag RAM range, minus 13
    (presumably flags already set in the initial save state — TODO
    confirm), floored at zero."""
    flags_start = 0xD747
    flags_end = 0xD886
    base_flags = 13
    flag_bits = sum(
        self.bit_count(self.read_m(addr))
        for addr in range(flags_start, flags_end)
    )
    return max(flag_bits - base_flags, 0)

def get_game_state_reward(self, print_stats=False):
# addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
# https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm
Expand All @@ -436,13 +436,13 @@ def get_game_state_reward(self, print_stats=False):
#money = self.read_money() - 975 # subtract starting money
seen_poke_count = sum([self.bit_count(self.read_m(i)) for i in range(0xD30A, 0xD31D)])
all_events_score = sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)])
oak_parcel = self.read_bit(0xD74E, 1)
oak_parcel = self.read_bit(0xD74E, 1)
oak_pokedex = self.read_bit(0xD74B, 5)
opponent_level = self.read_m(0xCFF3)
self.max_opponent_level = max(self.max_opponent_level, opponent_level)
enemy_poke_count = self.read_m(0xD89C)
self.max_opponent_poke = max(self.max_opponent_poke, enemy_poke_count)

if print_stats:
print(f'num_poke : {num_poke}')
print(f'poke_levels : {poke_levels}')
Expand All @@ -451,11 +451,11 @@ def get_game_state_reward(self, print_stats=False):
print(f'seen_poke_count : {seen_poke_count}')
print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}')
'''

state_scores = {
'event': self.update_max_event_rew(),
'event': self.update_max_event_rew(),
#'party_xp': 0.1*sum(poke_xps),
'level': self.get_levels_reward(),
'level': self.get_levels_reward(),
'heal': self.total_healing_rew,
'op_lvl': self.update_max_op_level(),
'dead': -0.1*self.died_count,
Expand All @@ -465,24 +465,24 @@ def get_game_state_reward(self, print_stats=False):
#'seen_poke': seen_poke_count * 400,
'explore': self.get_knn_reward()
}

return state_scores

def save_screenshot(self, name):
    """Write a full-resolution render to the session's screenshots folder,
    tagging the file with instance id, total reward, reset count and *name*."""
    out_dir = self.s_path / Path('screenshots')
    out_dir.mkdir(exist_ok=True)
    file_name = f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'
    frame = self.render(reduce_res=False)
    plt.imsave(out_dir / Path(file_name), frame)

def update_max_op_level(self):
    """Track the highest enemy-party level seen so far and return the
    corresponding reward term (0.2 per level)."""
    level_addrs = (0xD8C5, 0xD8F1, 0xD91D, 0xD949, 0xD975, 0xD9A1)
    # -5 offset kept from the original; presumably discounts the baseline
    # starting level — TODO confirm
    current_best = max(self.read_m(a) for a in level_addrs) - 5
    if current_best > self.max_opponent_level:
        self.max_opponent_level = current_best
    return self.max_opponent_level * 0.2

def update_max_event_rew(self):
cur_rew = self.get_all_events_reward()
self.max_event_rew = max(cur_rew, self.max_event_rew)
Expand All @@ -502,11 +502,11 @@ def bit_count(self, bits):

def read_triple(self, start_add):
    """Read a big-endian 24-bit integer from the three bytes starting at
    *start_add*."""
    value = 0
    for offset in range(3):
        value = value * 256 + self.read_m(start_add + offset)
    return value

def read_bcd(self, num):
    """Decode one packed binary-coded-decimal byte (two digits) to an int."""
    high_digit = (num >> 4) & 0x0F
    low_digit = num & 0x0F
    return high_digit * 10 + low_digit

def read_money(self):
    """Decode the three packed-BCD bytes at 0xD347-0xD349 (most
    significant pair first) — presumably the player's money counter."""
    total = 0
    for addr in (0xD347, 0xD348, 0xD349):
        total = total * 100 + self.read_bcd(self.read_m(addr))
    return total
2 changes: 1 addition & 1 deletion baselines/ray_exp/train_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

env_config = {
'headless': True, 'save_final_state': True, 'early_stop': False,
'action_freq': 24, 'init_state': '../../has_pokedex_nballs.state', 'max_steps': ep_length,
'action_freq': 24, 'init_state': '../../has_pokedex_nballs.state', 'max_steps': ep_length,
'print_rewards': False, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
'gb_path': '../../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 500_000.0
}
Expand Down
Loading