diff --git a/baselines/ray_exp/red_gym_env_ray.py b/baselines/ray_exp/red_gym_env_ray.py index 8b2dd49f4..89d11dd03 100644 --- a/baselines/ray_exp/red_gym_env_ray.py +++ b/baselines/ray_exp/red_gym_env_ray.py @@ -1,6 +1,6 @@ import sys -import uuid +import uuid import os from math import floor, sqrt import json @@ -116,14 +116,14 @@ def reset(self, *, seed=None, options=None): self.init_knn() self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8) - + self.recent_frames = np.zeros( - (self.frame_stacks, self.output_shape[0], + (self.frame_stacks, self.output_shape[0], self.output_shape[1], self.output_shape[2]), dtype=np.uint8) self.agent_stats = [] - + if self.save_video: base_dir = self.s_path / Path('rollouts') base_dir.mkdir(exist_ok=True) @@ -133,7 +133,7 @@ def reset(self, *, seed=None, options=None): self.full_frame_writer.__enter__() self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) self.model_frame_writer.__enter__() - + self.levels_satisfied = False self.base_explore = 0 self.max_opponent_level = 0 @@ -147,7 +147,7 @@ def reset(self, *, seed=None, options=None): self.total_reward = sum([val for _, val in self.progress_reward.items()]) self.reset_count += 1 return self.render(add_memory=False), {} - + def init_knn(self): # Declaring index self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip @@ -163,11 +163,11 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): self.recent_frames[0] = game_pixels_render if add_memory: pad = np.zeros( - shape=(self.mem_padding, self.output_shape[1], 3), + shape=(self.mem_padding, self.output_shape[1], 3), dtype=np.uint8) game_pixels_render = np.concatenate( ( - self.create_exploration_memory(), + self.create_exploration_memory(), pad, self.create_recent_memory(), pad, @@ -175,7 +175,7 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): ), axis=0) return game_pixels_render - + def step(self, action): self.run_action_on_emulator(action) @@ -191,11 +191,11 @@ def step(self, action): obs_flat = obs_memory.flatten().astype(np.float32) self.update_frame_knn_index(obs_flat) - + self.update_heal_reward() new_reward, new_prog = self.update_reward() - + self.last_health = self.read_hp_fraction() # shift over short term reward memory @@ -222,7 +222,7 @@ def run_action_on_emulator(self, action): # release arrow self.pyboy.send_input(self.release_arrow[action]) if action > 3 and action < 6: - # release button + # release button self.pyboy.send_input(self.release_button[action - 4]) if action == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) @@ -231,18 +231,18 @@ def run_action_on_emulator(self, action): self.pyboy.tick() if self.save_video and self.fast_video: self.add_video_frame() - + def add_video_frame(self): self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) - + def append_agent_stats(self): x_pos = self.read_m(0xD362) y_pos = self.read_m(0xD361) map_n = self.read_m(0xD35E) levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]] self.agent_stats.append({ - 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, + 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, 'pcount': self.read_m(0xD163), 'levels': levels, 'ptypes': self.read_party(), 'hp': self.read_hp_fraction(), 'frames': self.knn_index.get_current_count(), @@ -251,7 +251,7 @@ def append_agent_stats(self): }) def update_frame_knn_index(self, frame_vec): - + if self.get_levels_sum() >= 22 and not self.levels_satisfied: self.levels_satisfied = True self.base_explore = self.knn_index.get_current_count() @@ -263,7 +263,7 @@ def update_frame_knn_index(self, frame_vec): frame_vec, np.array([self.knn_index.get_current_count()]) ) else: - # check for nearest frame and add if current + # check for nearest frame and add if current labels, distances = self.knn_index.knn_query(frame_vec, k = 1) if distances[0] > self.similar_frame_dist: self.knn_index.add_items( @@ -280,25 +280,25 @@ def update_reward(self): if new_step < 0 and self.read_hp_fraction() > 0: #print(f'\n\nreward went down! {self.progress_reward}\n\n') self.save_screenshot('neg_reward') - + self.total_reward = new_total - return (new_step, - (new_prog[0]-old_prog[0], - new_prog[1]-old_prog[1], + return (new_step, + (new_prog[0]-old_prog[0], + new_prog[1]-old_prog[1], new_prog[2]-old_prog[2]) ) - + def group_rewards(self): prog = self.progress_reward # these values are only used by memory - return (prog['level'] * 100, self.read_hp_fraction()*2000, prog['explore'] * 160)#(prog['events'], - # prog['levels'] + prog['party_xp'], + return (prog['level'] * 100, self.read_hp_fraction()*2000, prog['explore'] * 160)#(prog['events'], + # prog['levels'] + prog['party_xp'], # prog['explore']) def create_exploration_memory(self): w = self.output_shape[1] h = self.memory_height - + def make_reward_channel(r_val): col_steps = self.col_steps row = floor(r_val / (h * col_steps)) @@ -308,17 +308,17 @@ def make_reward_channel(r_val): col = floor((r_val - row_covered) / col_steps) memory[:col, row] = 255 col_covered = col * col_steps - last_pixel = floor(r_val - row_covered - col_covered) + last_pixel = floor(r_val - row_covered - col_covered) memory[col, row] = last_pixel * (255 // col_steps) return memory - + level, hp, explore = self.group_rewards() full_memory = np.stack(( make_reward_channel(level), make_reward_channel(hp), make_reward_channel(explore) ), axis=-1) - + if self.get_badges() > 0: full_memory[:, -1, :] = 255 @@ -326,8 +326,8 @@ def make_reward_channel(r_val): def create_recent_memory(self): return rearrange( - self.recent_memory, - '(w h) c -> h w c', + self.recent_memory, + '(w h) c -> h w c', h=self.memory_height) def check_if_done(self): @@ -347,10 +347,10 @@ def save_and_print_info(self, done, obs_memory): prog_string += f' {key}: {val:5.2f}' prog_string += f' sum: {self.total_reward:5.2f}' print(f'\r{prog_string}', end='', flush=True) - + if self.step_count % 50 == 0: plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), self.render(reduce_res=False)) if self.print_rewards and done: @@ -359,10 +359,10 @@ def save_and_print_info(self, done, obs_memory): fs_path = self.s_path / Path('final_states') fs_path.mkdir(exist_ok=True) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), obs_memory) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), self.render(reduce_res=False)) if self.save_video and done: @@ -375,18 +375,18 @@ def save_and_print_info(self, done, obs_memory): json.dump(self.all_runs, f) pd.DataFrame(self.agent_stats).to_csv( self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - + def read_m(self, addr): return self.pyboy.get_memory_value(addr) def read_bit(self, addr, bit: int) -> bool: # add padding so zero will read '0b100000000' instead of '0b0' return bin(256 + self.read_m(addr))[-bit-1] == '1' - + def get_levels_sum(self): poke_levels = [max(self.read_m(a) - 2, 0) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]] return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level - + def get_levels_reward(self): explore_thresh = 22 scale_factor = 4 @@ -397,7 +397,7 @@ def get_levels_reward(self): scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh self.max_level_rew = max(self.max_level_rew, scaled) return self.max_level_rew - + def get_knn_reward(self): pre_rew = 0.004 post_rew = 0.01 @@ -405,13 +405,13 @@ def get_knn_reward(self): base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew post = (cur_size if self.levels_satisfied else 0) * post_rew return base + post - + def get_badges(self): return self.bit_count(self.read_m(0xD356)) def read_party(self): return [self.read_m(addr) for addr in [0xD164, 0xD165, 0xD166, 0xD167, 0xD168, 0xD169]] - + def update_heal_reward(self): cur_health = self.read_hp_fraction() if cur_health > self.last_health: @@ -423,10 +423,10 @@ def update_heal_reward(self): self.total_healing_rew += heal_amount * 4 else: self.died_count += 1 - + def get_all_events_reward(self): return max(sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - 13, 0) - + def get_game_state_reward(self, print_stats=False): # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm @@ -436,13 +436,13 @@ def get_game_state_reward(self, print_stats=False): #money = self.read_money() - 975 # subtract starting money seen_poke_count = sum([self.bit_count(self.read_m(i)) for i in range(0xD30A, 0xD31D)]) all_events_score = sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - oak_parcel = self.read_bit(0xD74E, 1) + oak_parcel = self.read_bit(0xD74E, 1) oak_pokedex = self.read_bit(0xD74B, 5) opponent_level = self.read_m(0xCFF3) self.max_opponent_level = max(self.max_opponent_level, opponent_level) enemy_poke_count = self.read_m(0xD89C) self.max_opponent_poke = max(self.max_opponent_poke, enemy_poke_count) - + if print_stats: print(f'num_poke : {num_poke}') print(f'poke_levels : {poke_levels}') @@ -451,11 +451,11 @@ def get_game_state_reward(self, print_stats=False): print(f'seen_poke_count : {seen_poke_count}') print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}') ''' - + state_scores = { - 'event': self.update_max_event_rew(), + 'event': self.update_max_event_rew(), #'party_xp': 0.1*sum(poke_xps), - 'level': self.get_levels_reward(), + 'level': self.get_levels_reward(), 'heal': self.total_healing_rew, 'op_lvl': self.update_max_op_level(), 'dead': -0.1*self.died_count, @@ -465,16 +465,16 @@ def get_game_state_reward(self, print_stats=False): #'seen_poke': seen_poke_count * 400, 'explore': self.get_knn_reward() } - + return state_scores - + def save_screenshot(self, name): ss_dir = self.s_path / Path('screenshots') ss_dir.mkdir(exist_ok=True) plt.imsave( - ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), + ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), self.render(reduce_res=False)) - + def update_max_op_level(self): #opponent_level = self.read_m(0xCFE8) - 5 # base level opponent_level = max([self.read_m(a) for a in [0xD8C5, 0xD8F1, 0xD91D, 0xD949, 0xD975, 0xD9A1]]) - 5 @@ -482,7 +482,7 @@ def update_max_op_level(self): # self.save_screenshot('highlevelop') self.max_opponent_level = max(self.max_opponent_level, opponent_level) return self.max_opponent_level * 0.2 - + def update_max_event_rew(self): cur_rew = self.get_all_events_reward() self.max_event_rew = max(cur_rew, self.max_event_rew) @@ -502,11 +502,11 @@ def bit_count(self, bits): def read_triple(self, start_add): return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) - + def read_bcd(self, num): return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) - + def read_money(self): - return (100 * 100 * self.read_bcd(self.read_m(0xD347)) + + return (100 * 100 * self.read_bcd(self.read_m(0xD347)) + 100 * self.read_bcd(self.read_m(0xD348)) + self.read_bcd(self.read_m(0xD349))) diff --git a/baselines/ray_exp/train_ray.py b/baselines/ray_exp/train_ray.py index 1221ef0f1..a2c2e8614 100644 --- a/baselines/ray_exp/train_ray.py +++ b/baselines/ray_exp/train_ray.py @@ -9,7 +9,7 @@ env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../../has_pokedex_nballs.state', 'max_steps': ep_length, + 'action_freq': 24, 'init_state': '../../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': False, 'save_video': False, 'fast_video': True, 'session_path': sess_path, 'gb_path': '../../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 500_000.0 } diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index 1b667ff9c..4f8910dd2 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -1,6 +1,6 @@ import sys -import uuid +import uuid import os from math import floor, sqrt import json @@ -63,7 +63,7 @@ def __init__( WindowEvent.PRESS_BUTTON_A, WindowEvent.PRESS_BUTTON_B, ] - + if self.extra_buttons: self.valid_actions.extend([ WindowEvent.PRESS_BUTTON_START, @@ -111,7 +111,7 @@ def __init__( if not config['headless']: self.pyboy.set_emulation_speed(6) - + self.reset() def reset(self, seed=None): @@ -119,21 +119,21 @@ def reset(self, seed=None): # restart game, skipping credits with open(self.init_state, "rb") as f: self.pyboy.load_state(f) - + if self.use_screen_explore: self.init_knn() else: self.init_map_mem() self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8) - + self.recent_frames = np.zeros( - (self.frame_stacks, self.output_shape[0], + (self.frame_stacks, self.output_shape[0], self.output_shape[1], self.output_shape[2]), dtype=np.uint8) self.agent_stats = [] - + if self.save_video: base_dir = self.s_path / Path('rollouts') base_dir.mkdir(exist_ok=True) @@ -143,7 +143,7 @@ def reset(self, seed=None): self.full_frame_writer.__enter__() self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) self.model_frame_writer.__enter__() - + self.levels_satisfied = False self.base_explore = 0 self.max_opponent_level = 0 @@ -158,14 +158,14 @@ def reset(self, seed=None): self.total_reward = sum([val for _, val in self.progress_reward.items()]) self.reset_count += 1 return self.render(), {} - + def init_knn(self): # Declaring index self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip # Initing index - the maximum number of elements should be known beforehand self.knn_index.init_index( max_elements=self.num_elements, ef_construction=100, M=16) - + def init_map_mem(self): self.seen_coords = {} @@ -177,11 +177,11 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): self.recent_frames[0] = game_pixels_render if add_memory: pad = np.zeros( - shape=(self.mem_padding, self.output_shape[1], 3), + shape=(self.mem_padding, self.output_shape[1], 3), dtype=np.uint8) game_pixels_render = np.concatenate( ( - self.create_exploration_memory(), + self.create_exploration_memory(), pad, self.create_recent_memory(), pad, @@ -189,7 +189,7 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): ), axis=0) return game_pixels_render - + def step(self, action): self.run_action_on_emulator(action) @@ -207,12 +207,12 @@ def step(self, action): self.update_frame_knn_index(obs_flat) else: self.update_seen_coords() - + self.update_heal_reward() self.party_size = self.read_m(0xD163) new_reward, new_prog = self.update_reward() - + self.last_health = self.read_hp_fraction() # shift over short term reward memory @@ -242,7 +242,7 @@ def run_action_on_emulator(self, action): # release arrow self.pyboy.send_input(self.release_arrow[action]) if action > 3 and action < 6: - # release button + # release button self.pyboy.send_input(self.release_button[action - 4]) if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) @@ -253,11 +253,11 @@ def run_action_on_emulator(self, action): self.pyboy.tick() if self.save_video and self.fast_video: self.add_video_frame() - + def add_video_frame(self): self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) - + def append_agent_stats(self, action): x_pos = self.read_m(0xD362) y_pos = self.read_m(0xD361) @@ -271,8 +271,8 @@ def append_agent_stats(self, action): 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, 'map_location': self.get_map_location(map_n), 'last_action': action, - 'pcount': self.read_m(0xD163), - 'levels': levels, + 'pcount': self.read_m(0xD163), + 'levels': levels, 'levels_sum': sum(levels), 'ptypes': self.read_party(), 'hp': self.read_hp_fraction(), @@ -282,7 +282,7 @@ def append_agent_stats(self, action): }) def update_frame_knn_index(self, frame_vec): - + if self.get_levels_sum() >= 22 and not self.levels_satisfied: self.levels_satisfied = True self.base_explore = self.knn_index.get_current_count() @@ -294,14 +294,14 @@ def update_frame_knn_index(self, frame_vec): frame_vec, np.array([self.knn_index.get_current_count()]) ) else: - # check for nearest frame and add if current + # check for nearest frame and add if current labels, distances = self.knn_index.knn_query(frame_vec, k = 1) if distances[0][0] > self.similar_frame_dist: # print(f"distances[0][0] : {distances[0][0]} similar_frame_dist : {self.similar_frame_dist}") self.knn_index.add_items( frame_vec, np.array([self.knn_index.get_current_count()]) ) - + def update_seen_coords(self): x_pos = self.read_m(0xD362) y_pos = self.read_m(0xD361) @@ -311,7 +311,7 @@ def update_seen_coords(self): self.levels_satisfied = True self.base_explore = len(self.seen_coords) self.seen_coords = {} - + self.seen_coords[coord_string] = self.step_count def update_reward(self): @@ -324,28 +324,28 @@ def update_reward(self): if new_step < 0 and self.read_hp_fraction() > 0: #print(f'\n\nreward went down! {self.progress_reward}\n\n') self.save_screenshot('neg_reward') - + self.total_reward = new_total - return (new_step, - (new_prog[0]-old_prog[0], - new_prog[1]-old_prog[1], + return (new_step, + (new_prog[0]-old_prog[0], + new_prog[1]-old_prog[1], new_prog[2]-old_prog[2]) ) - + def group_rewards(self): prog = self.progress_reward # these values are only used by memory - return (prog['level'] * 100 / self.reward_scale, - self.read_hp_fraction()*2000, + return (prog['level'] * 100 / self.reward_scale, + self.read_hp_fraction()*2000, prog['explore'] * 150 / (self.explore_weight * self.reward_scale)) - #(prog['events'], - # prog['levels'] + prog['party_xp'], + #(prog['events'], + # prog['levels'] + prog['party_xp'], # prog['explore']) def create_exploration_memory(self): w = self.output_shape[1] h = self.memory_height - + def make_reward_channel(r_val): col_steps = self.col_steps max_r_val = (w-1) * h * col_steps @@ -359,17 +359,17 @@ def make_reward_channel(r_val): col = floor((r_val - row_covered) / col_steps) memory[:col, row] = 255 col_covered = col * col_steps - last_pixel = floor(r_val - row_covered - col_covered) + last_pixel = floor(r_val - row_covered - col_covered) memory[col, row] = last_pixel * (255 // col_steps) return memory - + level, hp, explore = self.group_rewards() full_memory = np.stack(( make_reward_channel(level), make_reward_channel(hp), make_reward_channel(explore) ), axis=-1) - + if self.get_badges() > 0: full_memory[:, -1, :] = 255 @@ -377,8 +377,8 @@ def make_reward_channel(r_val): def create_recent_memory(self): return rearrange( - self.recent_memory, - '(w h) c -> h w c', + self.recent_memory, + '(w h) c -> h w c', h=self.memory_height) def check_if_done(self): @@ -398,10 +398,10 @@ def save_and_print_info(self, done, obs_memory): prog_string += f' {key}: {val:5.2f}' prog_string += f' sum: {self.total_reward:5.2f}' print(f'\r{prog_string}', end='', flush=True) - + if self.step_count % 50 == 0: plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), self.render(reduce_res=False)) if self.print_rewards and done: @@ -410,10 +410,10 @@ def save_and_print_info(self, done, obs_memory): fs_path = self.s_path / Path('final_states') fs_path.mkdir(exist_ok=True) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), obs_memory) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), self.render(reduce_res=False)) if self.save_video and done: @@ -426,18 +426,18 @@ def save_and_print_info(self, done, obs_memory): json.dump(self.all_runs, f) pd.DataFrame(self.agent_stats).to_csv( self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - + def read_m(self, addr): return self.pyboy.get_memory_value(addr) def read_bit(self, addr, bit: int) -> bool: # add padding so zero will read '0b100000000' instead of '0b0' return bin(256 + self.read_m(addr))[-bit-1] == '1' - + def get_levels_sum(self): poke_levels = [max(self.read_m(a) - 2, 0) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]] return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level - + def get_levels_reward(self): explore_thresh = 22 scale_factor = 4 @@ -448,22 +448,22 @@ def get_levels_reward(self): scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh self.max_level_rew = max(self.max_level_rew, scaled) return self.max_level_rew - + def get_knn_reward(self): - + pre_rew = self.explore_weight * 0.005 post_rew = self.explore_weight * 0.01 cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew post = (cur_size if self.levels_satisfied else 0) * post_rew return base + post - + def get_badges(self): return self.bit_count(self.read_m(0xD356)) def read_party(self): return [self.read_m(addr) for addr in [0xD164, 0xD165, 0xD166, 0xD167, 0xD168, 0xD169]] - + def update_heal_reward(self): cur_health = self.read_hp_fraction() # if health increased and party size did not change @@ -477,7 +477,7 @@ def update_heal_reward(self): self.total_healing_rew += heal_amount * 4 else: self.died_count += 1 - + def get_all_events_reward(self): # adds up all event flags, exclude museum ticket event_flags_start = 0xD747 @@ -505,13 +505,13 @@ def get_game_state_reward(self, print_stats=False): #money = self.read_money() - 975 # subtract starting money seen_poke_count = sum([self.bit_count(self.read_m(i)) for i in range(0xD30A, 0xD31D)]) all_events_score = sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - oak_parcel = self.read_bit(0xD74E, 1) + oak_parcel = self.read_bit(0xD74E, 1) oak_pokedex = self.read_bit(0xD74B, 5) opponent_level = self.read_m(0xCFF3) self.max_opponent_level = max(self.max_opponent_level, opponent_level) enemy_poke_count = self.read_m(0xD89C) self.max_opponent_poke = max(self.max_opponent_poke, enemy_poke_count) - + if print_stats: print(f'num_poke : {num_poke}') print(f'poke_levels : {poke_levels}') @@ -520,11 +520,11 @@ def get_game_state_reward(self, print_stats=False): print(f'seen_poke_count : {seen_poke_count}') print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}') ''' - + state_scores = { - 'event': self.reward_scale*self.update_max_event_rew(), + 'event': self.reward_scale*self.update_max_event_rew(), #'party_xp': self.reward_scale*0.1*sum(poke_xps), - 'level': self.reward_scale*self.get_levels_reward(), + 'level': self.reward_scale*self.get_levels_reward(), 'heal': self.reward_scale*self.total_healing_rew, 'op_lvl': self.reward_scale*self.update_max_op_level(), 'dead': self.reward_scale*-0.1*self.died_count, @@ -534,16 +534,16 @@ def get_game_state_reward(self, print_stats=False): #'seen_poke': self.reward_scale * seen_poke_count * 400, 'explore': self.reward_scale * self.get_knn_reward() } - + return state_scores - + def save_screenshot(self, name): ss_dir = self.s_path / Path('screenshots') ss_dir.mkdir(exist_ok=True) plt.imsave( - ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), + ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), self.render(reduce_res=False)) - + def update_max_op_level(self): #opponent_level = self.read_m(0xCFE8) - 5 # base level opponent_level = max([self.read_m(a) for a in [0xD8C5, 0xD8F1, 0xD91D, 0xD949, 0xD975, 0xD9A1]]) - 5 @@ -551,7 +551,7 @@ def update_max_op_level(self): # self.save_screenshot('highlevelop') self.max_opponent_level = max(self.max_opponent_level, opponent_level) return self.max_opponent_level * 0.2 - + def update_max_event_rew(self): cur_rew = self.get_all_events_reward() self.max_event_rew = max(cur_rew, self.max_event_rew) @@ -572,12 +572,12 @@ def bit_count(self, bits): def read_triple(self, start_add): return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) - + def read_bcd(self, num): return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) - + def read_money(self): - return (100 * 100 * self.read_bcd(self.read_m(0xD347)) + + return (100 * 100 * self.read_bcd(self.read_m(0xD347)) + 100 * self.read_bcd(self.read_m(0xD348)) + self.read_bcd(self.read_m(0xD349))) @@ -621,4 +621,3 @@ def get_map_location(self, map_idx): return map_locations[map_idx] else: return "Unknown Location" - diff --git a/baselines/render_all_needed_grids.py b/baselines/render_all_needed_grids.py index 269fc56e2..83e76c28a 100644 --- a/baselines/render_all_needed_grids.py +++ b/baselines/render_all_needed_grids.py @@ -30,7 +30,7 @@ def run_save(save): sess_path = f'grid_renders/session_{save.stem}' env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, + 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': True, 'save_video': True, 'fast_video': False, 'session_path': sess_path, 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0 } @@ -65,10 +65,10 @@ def run_save(save): if __name__ == '__main__': run_save(sys.argv[1]) - + # all_saves = list(Path('session_4da05e87').glob('*.zip')) # selected_saves = [Path('session_4da05e87/init')] + all_saves[:10] + all_saves[10:120:5] + all_saves[120:420:10] # len(selected_saves) - + # for idx, save in enumerate(selected_saves): - + diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index f4423a3a5..df271dc8a 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -31,22 +31,22 @@ def _init(): env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, + 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path, - 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, + 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, 'use_screen_explore': True, 'extra_buttons': False } - - + + num_cpu = 44 #64 #46 # Also sets the number of episodes per training iteration env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - + checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke') #env_checker.check_env(env) learn_steps = 40 file_name = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps' - + #'session_bfdca25a/poke_42532864_steps' #'session_d3033abb/poke_47579136_steps' #'session_a17cc1f5/poke_33546240_steps' #'session_e4bdca71/poke_8945664_steps' #'session_eb21989e/poke_40255488_steps' #'session_80f70ab4/poke_58982400_steps' if exists(file_name + '.zip'): print('\nloading checkpoint') @@ -58,6 +58,6 @@ def _init(): model.rollout_buffer.reset() else: model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length, batch_size=512, n_epochs=1, gamma=0.999) - + for i in range(learn_steps): model.learn(total_timesteps=(ep_length)*num_cpu*1000, callback=checkpoint_callback) diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py index de82739d9..55a4ae76b 100644 --- a/baselines/run_baseline_parallel_fast.py +++ b/baselines/run_baseline_parallel_fast.py @@ -39,9 +39,9 @@ def _init(): 'use_screen_explore': True, 'reward_scale': 4, 'extra_buttons': False, 'explore_weight': 3 # 2.5 } - + print(env_config) - + num_cpu = 16 # Also sets the number of episodes per training iteration env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) @@ -57,8 +57,8 @@ def _init(): project="pokemon-train", id=sess_id, config=env_config, - sync_tensorboard=True, - monitor_gym=True, + sync_tensorboard=True, + monitor_gym=True, save_code=True, ) callbacks.append(WandbCallback()) @@ -66,8 +66,8 @@ def _init(): #env_checker.check_env(env) learn_steps = 40 # put a checkpoint here you want to start from - file_name = 'session_e41c9eff/poke_38207488_steps' - + file_name = 'session_e41c9eff/poke_38207488_steps' + if exists(file_name + '.zip'): print('\nloading checkpoint') model = PPO.load(file_name, env=env) diff --git a/baselines/run_pretrained_interactive.py b/baselines/run_pretrained_interactive.py index 2c2671332..21e5693f8 100644 --- a/baselines/run_pretrained_interactive.py +++ b/baselines/run_pretrained_interactive.py @@ -30,20 +30,20 @@ def _init(): env_config = { 'headless': False, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, + 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path, 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, 'extra_buttons': True } - + num_cpu = 1 #64 #46 # Also sets the number of episodes per training iteration env = make_env(0, env_config)() #SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - + #env_checker.check_env(env) file_name = 'session_4da05e87_main_good/poke_439746560_steps' - + print('\nloading checkpoint') model = PPO.load(file_name, env=env, custom_objects={'lr_schedule': 0, 'clip_range': 0}) - + #keyboard.on_press_key("M", toggle_agent) obs, info = env.reset() while True: diff --git a/baselines/run_recorded_actions.py b/baselines/run_recorded_actions.py index 31d2d033b..b222c4383 100644 --- a/baselines/run_recorded_actions.py +++ b/baselines/run_recorded_actions.py @@ -6,14 +6,14 @@ def run_recorded_actions_on_emulator_and_save_video(sess_id, instance_id, run_index): sess_path = Path(f'session_{sess_id}') tdf = pd.read_csv(f"session_{sess_id}/agent_stats_{instance_id}.csv.gz", compression='gzip') - tdf = tdf[tdf['map'] != 'map'] # remove unused + tdf = tdf[tdf['map'] != 'map'] # remove unused action_arrays = np.array_split(tdf, np.array((tdf["step"].astype(int) == 0).sum())) action_list = [int(x) for x in list(action_arrays[run_index]["last_action"])] max_steps = len(action_list) - 1 env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': max_steps, #ep_length, + 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': max_steps, #ep_length, 'print_rewards': False, 'save_video': True, 'fast_video': False, 'session_path': sess_path, 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, 'instance_id': f'{instance_id}_recorded' } diff --git a/baselines/tensorboard_callback.py b/baselines/tensorboard_callback.py index 1371e2eec..fb6487ec2 100644 --- a/baselines/tensorboard_callback.py +++ b/baselines/tensorboard_callback.py @@ -9,7 +9,7 @@ def merge_dicts_by_mean(dicts): for d in dicts: for k, v in d.items(): - if isinstance(v, (int, float)): + if isinstance(v, (int, float)): sum_dict[k] = sum_dict.get(k, 0) + v count_dict[k] = count_dict.get(k, 0) + 1 @@ -25,14 +25,14 @@ def __init__(self, verbose=0): super().__init__(verbose) def _on_step(self) -> bool: - + if self.training_env.env_method("check_if_done", indices=[0])[0]: all_infos = self.training_env.get_attr("agent_stats") all_final_infos = [stats[-1] for stats in all_infos] mean_infos = merge_dicts_by_mean(all_final_infos) for key,val in mean_infos.items(): self.logger.record(f"env_stats/{key}", val) - + images = self.training_env.env_method("render") # use reduce_res=False for full res screens images_arr = np.array(images) images_row = rearrange(images_arr, "b h w c -> h (b w) c") diff --git a/baselines/tile_vids_to_grid.py b/baselines/tile_vids_to_grid.py index 531a700c9..33f7b7531 100644 --- a/baselines/tile_vids_to_grid.py +++ b/baselines/tile_vids_to_grid.py @@ -44,9 +44,9 @@ def run_ffmpeg_grid(out_path, files, screen_res_str, full_res_string, gx, gy, sh cmd.append("-t") cmd.append("10") cmd.append(str(out_path.resolve())) - + #-f matroska - - + #proc = subprocess.Popen(cmd) ''' while True: @@ -54,21 +54,21 @@ def run_ffmpeg_grid(out_path, files, screen_res_str, full_res_string, gx, gy, sh if not line: break print(line) ''' - + return ' '.join(cmd) - + def make_script(path): sess_dir = path print(f"generating grid script for {sess_dir.name}") rollout_dir = sess_dir / "rollouts" all_files = list(rollout_dir.glob("full_reset_1*.mp4")) return run_ffmpeg_grid( - (sess_dir / sess_dir.name).with_suffix('.mp4'), all_files, + (sess_dir / sess_dir.name).with_suffix('.mp4'), all_files, "160x144", "1280x720", 8, 5, short_test=False) def make_outer_script(out_file, paths): return run_ffmpeg_grid( - out_file, paths, + out_file, paths, "1280x720", "10240x5760", 8, 8, short_test=False) def write_file(out_file, script): diff --git a/visualization/BetterMapVis_script_version.py b/visualization/BetterMapVis_script_version.py index adfc1b19a..06fcd7eeb 100644 --- a/visualization/BetterMapVis_script_version.py +++ b/visualization/BetterMapVis_script_version.py @@ -28,7 +28,7 @@ def get_sprite_by_coords(img, x, y): def game_coord_to_pixel_coord( x, y, map_idx, base_y): - + global_offset = np.array([1056-16*12, 331]) #np.array([790, -29]) map_offsets = { # https://bulbapedia.bulbagarden.net/wiki/List_of_locations_by_index_number_(Generation_I) @@ -87,7 +87,7 @@ def add_sprite(overlay_map, sprite, coord): else: intermediate[mask] = sprite[mask] overlay_map[coord[1]:coord[1]+16, coord[0]:coord[0]+16, :] = intermediate - + def blend_overlay(background, over): al = over[...,3].reshape(over.shape[0], over.shape[1], 1) ba = (255-al)/255 @@ -102,7 +102,7 @@ def render_video(fname, all_coords, walks, bg, inter_steps=4, add_start=True): errors = [] sprites_rendered = 0 with media.VideoWriter( - f'{fname}.mov', split(bg).shape[:2], codec='prores_ks', + f'{fname}.mov', split(bg).shape[:2], codec='prores_ks', encoded_format='yuva444p', input_format='rgba', fps=60 ) as wr: step_count = len(all_coords) @@ -178,12 +178,12 @@ def test_render(name, dat, walks, bg): ) if __name__ == '__main__': - + run_dir = Path('baselines/session_4da05e87') # Path('baselines/session_ebdfe818') # original session_e41c9eff, main session_4da05e87, extra session_e1b6d2dc - + coords_save_pth = Path('base_coords.npz') - + if coords_save_pth.is_file(): print(f'{coords_save_pth} found, loading from file') base_coords = np.load(coords_save_pth)['arr_0'] @@ -197,14 +197,14 @@ def test_render(name, dat, walks, bg): base_coords = make_all_coords_arrays(dfs) print(f'saving {coords_save_pth}') np.savez_compressed(coords_save_pth, base_coords) - + print(f'initial data shape: {base_coords.shape}') main_map = np.array(Image.open('poke_map/pokemap_full_calibrated_CROPPED_1.png')) chars_img = np.array(Image.open('poke_map/characters.png')) alpha_val = get_sprite_by_coords(chars_img, 1, 0)[0,0] walks = [get_sprite_by_coords(chars_img, x, 0) for x in [1, 4, 6, 8]] - + start_bg = main_map.copy() procs = 16 @@ -215,7 +215,7 @@ def test_render(name, dat, walks, bg): runs = base_data.shape[0] #base_data.shape[1] chunk_size = runs // procs all_render_errors = p.starmap( - test_render, + test_render, #[(f'test_run_p{i}', base_data[:, chunk_size*i:chunk_size*(i+1)], walks, start_bg) for i in range(procs)]) [(f'vids_run1/test_run_p{i}', base_data[chunk_size*i:chunk_size*(i+1)+5], walks, start_bg) for i in range(procs)]) - + diff --git a/visualization/BetterMapVis_script_version_FLOW.py b/visualization/BetterMapVis_script_version_FLOW.py index 274c63834..6ff061eb9 100644 --- a/visualization/BetterMapVis_script_version_FLOW.py +++ b/visualization/BetterMapVis_script_version_FLOW.py @@ -26,7 +26,7 @@ def get_sprite_by_coords(img, x, y): def game_coord_to_global_coord( x, y, map_idx): - + global_offset = np.array([1056-16*12, 331]) #np.array([790, -29]) map_offsets = { # https://bulbapedia.bulbagarden.net/wiki/List_of_locations_by_index_number_(Generation_I) @@ -85,7 +85,7 @@ def add_sprite(overlay_map, sprite, coord): else: intermediate[mask] = sprite[mask] overlay_map[coord[1]:coord[1]+16, coord[0]:coord[0]+16, :] = intermediate - + def blend_overlay(background, over): al = over[...,3].reshape(over.shape[0], over.shape[1], 1) ba = (255-al)/255 @@ -165,7 +165,7 @@ def compute_flow(all_coords, inter_steps=1, add_start=True): else: sprites_rendered += 1 pbar.set_description(f"draws: {sprites_rendered}") - + return all_flows def render_arrows(fname, all_flows, arrow_sprite): @@ -176,10 +176,10 @@ def render_arrows(fname, all_flows, arrow_sprite): max_y = max([k[1] for k in all_flows.keys()]) grid_dims = (max_x - min_x, max_y - min_y) cell_dim = arrow_sprite.size[0] # use x only, assuming square - + #colmap = matplotlib.cm.get_cmap('husl') colmap = seaborn.husl_palette(h=0.1, s=0.95, l=0.75, as_cmap=True) - + full_img = np.zeros( ((grid_dims[0]+1) * cell_dim, (grid_dims[1]+1) * cell_dim, 4 ), dtype=np.uint8) for coord, total_move in tqdm(all_flows.items()): angle = math.atan2(-total_move[0], total_move[1]) @@ -190,13 +190,13 @@ def render_arrows(fname, all_flows, arrow_sprite): #color = hsv2rgb(np.array([0.5*angle/math.pi+0.5, 1.0, 1.0])) color = colmap(0.5*angle/math.pi+0.5) full_img[ - nx * cell_dim : (nx + 1) * cell_dim, + nx * cell_dim : (nx + 1) * cell_dim, ny * cell_dim : (ny + 1) * cell_dim ] = np.array(rotated_arrow) * np.array([color[0], color[1], color[2], 1.0]) print("Writing file") final_img = Image.fromarray(full_img) final_img.save(f"{fname}.png") - + ''' print("generating coords") fig, ax = plt.subplots(figsize=grid_dims) @@ -211,8 +211,8 @@ def render_arrows(fname, all_flows, arrow_sprite): cols.append(mag) print("rendering") ax.quiver( - [k[0] for k in all_flows.keys()], - [k[1] for k in all_flows.keys()], + [k[0] for k in all_flows.keys()], + [k[1] for k in all_flows.keys()], u, v, cols ) @@ -224,7 +224,7 @@ def render_arrows(fname, all_flows, arrow_sprite): print("saving") plt.savefig(f"{fname}.png") ''' - + def compute_flow_wrap(dat): print(f'processing chunk with shape {dat.shape}') return compute_flow( @@ -233,12 +233,12 @@ def compute_flow_wrap(dat): ) if __name__ == '__main__': - + run_dir = Path('baselines/session_4da05e87') # Path('baselines/session_ebdfe818') # original session_e41c9eff, main session_4da05e87, extra session_e1b6d2dc - + coords_save_pth = Path('base_coords.npz') - + if coords_save_pth.is_file(): print(f'{coords_save_pth} found, loading from file') base_coords = np.load(coords_save_pth)['arr_0'] @@ -252,7 +252,7 @@ def compute_flow_wrap(dat): base_coords = make_all_coords_arrays(dfs) print(f'saving {coords_save_pth}') np.savez_compressed(coords_save_pth, base_coords) - + print(f'initial data shape: {base_coords.shape}') main_map = np.array(Image.open('poke_map/pokemap_full_calibrated_CROPPED_1.png')) @@ -261,7 +261,7 @@ def compute_flow_wrap(dat): arrow_img = Image.open('poke_map/transparent_arrow.png').resize((arrow_size, arrow_size)) #alpha_val = get_sprite_by_coords(chars_img, 1, 0)[0,0] #walks = [get_sprite_by_coords(chars_img, x, 0) for x in [1, 4, 6, 8]] - + procs = 8 with Pool(procs) as p: run_steps = 16385 @@ -271,9 +271,9 @@ def compute_flow_wrap(dat): runs = base_data.shape[0] #base_data.shape[1] chunk_size = runs // procs batches_all_flows = p.map( - compute_flow_wrap, + compute_flow_wrap, [base_data[chunk_size*i:chunk_size*(i+1)+5] for i in range(procs)]) - + print(f"merging {len(batches_all_flows)} batches") merged_flows = {} for batch in tqdm(batches_all_flows): @@ -282,5 +282,5 @@ def compute_flow_wrap(dat): merged_flows[cell] += flow else: merged_flows[cell] = flow - + render_arrows("map_flow_run1/full_combined_1", merged_flows, arrow_img) \ No newline at end of file diff --git a/visualization/BetterMapVis_script_version_FLOW_edge.py b/visualization/BetterMapVis_script_version_FLOW_edge.py index b9e13ec78..0c1b01c3c 100644 --- a/visualization/BetterMapVis_script_version_FLOW_edge.py +++ b/visualization/BetterMapVis_script_version_FLOW_edge.py @@ -26,7 +26,7 @@ def get_sprite_by_coords(img, x, y): def game_coord_to_global_coord( x, y, map_idx): - + global_offset = np.array([1056-16*12, 331]) #np.array([790, -29]) map_offsets = { # https://bulbapedia.bulbagarden.net/wiki/List_of_locations_by_index_number_(Generation_I) @@ -85,7 +85,7 @@ def add_sprite(overlay_map, sprite, coord): else: intermediate[mask] = sprite[mask] overlay_map[coord[1]:coord[1]+16, coord[0]:coord[0]+16, :] = intermediate - + def blend_overlay(background, over): al = over[...,3].reshape(over.shape[0], over.shape[1], 1) ba = (255-al)/255 @@ -165,7 +165,7 @@ def compute_flow(all_coords, inter_steps=1, add_start=True): else: sprites_rendered += 1 pbar.set_description(f"draws: {sprites_rendered}") - + return all_flows def render_arrows(fname, all_flows, arrow_sprite): @@ -176,12 +176,12 @@ def render_arrows(fname, all_flows, arrow_sprite): max_y = max([k[1] for k in all_flows.keys()]) grid_dims = (max_x - min_x, max_y - min_y) cell_dim = arrow_sprite.size[0] # use x only, assuming square - + #colmap = matplotlib.cm.get_cmap('husl') colmap = seaborn.husl_palette(h=0.1, s=0.95, l=0.75, as_cmap=True) - + full_img = np.zeros( ((grid_dims[0]+1) * cell_dim, (grid_dims[1]+1) * cell_dim, 4 ), dtype=np.uint8) - + print("computing curl") dense_flow = np.zeros( ((grid_dims[0]+1), (grid_dims[1]+1), 2), dtype=np.float32) for coord, total_move in tqdm(all_flows.items()): @@ -191,9 +191,9 @@ def render_arrows(fname, all_flows, arrow_sprite): # Compute curl curl_F = dFy_dx - dFx_dy - + print(f"total curl: {curl_F.sum()}") - + with open('map_flow_run1/flow_dense.npy', 'wb') as f: np.save(f, dense_flow) @@ -201,16 +201,16 @@ def render_arrows(fname, all_flows, arrow_sprite): coords_to_remove = set() for coord, total_move in tqdm(all_flows.items()): if ( - (coord[0]-1, coord[1]) in all_flows.keys() and - (coord[0]+1, coord[1]) in all_flows.keys() and - (coord[0], coord[1]-1) in all_flows.keys() and + (coord[0]-1, coord[1]) in all_flows.keys() and + (coord[0]+1, coord[1]) in all_flows.keys() and + (coord[0], coord[1]-1) in all_flows.keys() and (coord[0], coord[1]+1) in all_flows.keys() ): coords_to_remove.add(coord) - + for cord in coords_to_remove: del all_flows[cord] - + for coord, total_move in tqdm(all_flows.items()): angle = math.atan2(-total_move[0], total_move[1]) #mag = math.sqrt(coord[0]**2 + coord[1]**2) @@ -220,13 +220,13 @@ def render_arrows(fname, all_flows, arrow_sprite): #color = hsv2rgb(np.array([0.5*angle/math.pi+0.5, 1.0, 1.0])) color = colmap(0.5*angle/math.pi+0.5) full_img[ - nx * cell_dim : (nx + 1) * cell_dim, + nx * cell_dim : (nx + 1) * cell_dim, ny * cell_dim : (ny + 1) * cell_dim ] = np.array(rotated_arrow) * np.array([color[0], color[1], color[2], 1.0]) print("Writing file") final_img = Image.fromarray(full_img) final_img.save(f"{fname}.png") - + ''' print("generating coords") fig, ax = plt.subplots(figsize=grid_dims) @@ -241,8 +241,8 @@ def render_arrows(fname, all_flows, arrow_sprite): cols.append(mag) print("rendering") ax.quiver( - [k[0] for k in all_flows.keys()], - [k[1] for k in all_flows.keys()], + [k[0] for k in all_flows.keys()], + [k[1] for k in all_flows.keys()], u, v, cols ) @@ -254,7 +254,7 @@ def render_arrows(fname, all_flows, arrow_sprite): print("saving") plt.savefig(f"{fname}.png") ''' - + def compute_flow_wrap(dat): print(f'processing chunk with shape {dat.shape}') return compute_flow( @@ -263,12 +263,12 @@ def compute_flow_wrap(dat): ) if __name__ == '__main__': - + run_dir = Path('baselines/session_4da05e87') # Path('baselines/session_ebdfe818') # original session_e41c9eff, main session_4da05e87, extra session_e1b6d2dc - + coords_save_pth = Path('base_coords.npz') - + if coords_save_pth.is_file(): print(f'{coords_save_pth} found, loading from file') base_coords = np.load(coords_save_pth)['arr_0'] @@ -282,7 +282,7 @@ def compute_flow_wrap(dat): base_coords = make_all_coords_arrays(dfs) print(f'saving {coords_save_pth}') np.savez_compressed(coords_save_pth, base_coords) - + print(f'initial data shape: {base_coords.shape}') main_map = np.array(Image.open('poke_map/pokemap_full_calibrated_CROPPED_1.png')) @@ -291,7 +291,7 @@ def compute_flow_wrap(dat): arrow_img = Image.open('poke_map/transparent_arrow.png').resize((arrow_size, arrow_size)) #alpha_val = get_sprite_by_coords(chars_img, 1, 0)[0,0] #walks = [get_sprite_by_coords(chars_img, x, 0) for x in [1, 4, 6, 8]] - + procs = 8 with Pool(procs) as p: run_steps = 16385 @@ -301,9 +301,9 @@ def compute_flow_wrap(dat): runs = base_data.shape[0] #base_data.shape[1] chunk_size = runs // procs batches_all_flows = p.map( - compute_flow_wrap, + compute_flow_wrap, [base_data[chunk_size*i:chunk_size*(i+1)+5] for i in range(procs)]) - + print(f"merging {len(batches_all_flows)} batches") merged_flows = {} for batch in tqdm(batches_all_flows): @@ -312,5 +312,5 @@ def compute_flow_wrap(dat): merged_flows[cell] += flow else: merged_flows[cell] = flow - + render_arrows("map_flow_run1/full_combined_test_inter", merged_flows, arrow_img) \ No newline at end of file diff --git a/visualization/BetterMapVis_script_version_PROG_COLOR.py b/visualization/BetterMapVis_script_version_PROG_COLOR.py index 20253bb5e..976d9ecaa 100644 --- a/visualization/BetterMapVis_script_version_PROG_COLOR.py +++ b/visualization/BetterMapVis_script_version_PROG_COLOR.py @@ -32,7 +32,7 @@ def get_sprite_by_coords(img, x, y): def game_coord_to_pixel_coord( x, y, map_idx, base_y): - + global_offset = np.array([1056-16*12, 331]) #np.array([790, -29]) map_offsets = { # https://bulbapedia.bulbagarden.net/wiki/List_of_locations_by_index_number_(Generation_I) @@ -91,7 +91,7 @@ def add_sprite(overlay_map, sprite, coord): else: intermediate[mask] = sprite[mask] overlay_map[coord[1]:coord[1]+16, coord[0]:coord[0]+16, :] = intermediate - + def blend_overlay(background, over): al = over[...,3].reshape(over.shape[0], over.shape[1], 1) ba = (255-al)/255 @@ -107,7 +107,7 @@ def render_video(fname, all_coords, walks, bg, inter_steps=4, add_start=True): sprites_rendered = 0 turbo_map = get_cmap("cet_isoluminant_cgo_80_c38")._resample(8) #mpl.colormaps['turbo']._resample(8) with media.VideoWriter( - f'{fname}.mov', split(bg).shape[:2], codec='prores_ks', + f'{fname}.mov', split(bg).shape[:2], codec='prores_ks', encoded_format='yuva444p', input_format='rgba', fps=60 ) as wr: step_count = len(all_coords) @@ -184,12 +184,12 @@ def test_render(name, dat, walks, bg): ) if __name__ == '__main__': - + run_dir = Path('baselines/session_4da05e87_main_full') # Path('baselines/session_ebdfe818') # original session_e41c9eff, main session_4da05e87, extra session_e1b6d2dc - + coords_save_pth = Path('base_coords.npz') - + if coords_save_pth.is_file(): print(f'{coords_save_pth} found, loading from file') base_coords = np.load(coords_save_pth)['arr_0'] @@ -203,14 +203,14 @@ def test_render(name, dat, walks, bg): base_coords = make_all_coords_arrays(dfs) print(f'saving {coords_save_pth}') np.savez_compressed(coords_save_pth, base_coords) - + print(f'initial data shape: {base_coords.shape}') main_map = np.array(Image.open('poke_map/pokemap_full_calibrated_CROPPED_1.png')) chars_img = np.array(Image.open('poke_map/characters.png')) alpha_val = get_sprite_by_coords(chars_img, 1, 0)[0,0] walks = [get_sprite_by_coords(chars_img, x, 0) for x in [1, 4, 6, 8]] - + start_bg = main_map.copy() procs = 8#16 @@ -225,6 +225,6 @@ def test_render(name, dat, walks, bg): # f'map_vis_final_state', base_data[(base_data.shape[0]-5):], walks, start_bg) #f'map_vis_initial_state', base_data[:], walks, start_bg all_render_errors = p.starmap( - test_render, + test_render, [(f'map_vis_color/map_vis_initial_state{i}', base_data[chunk_size*i:chunk_size*(i+1)+5], walks, start_bg) for i in range(procs)]) - +