|
| 1 | +import os |
| 2 | +import textwrap |
| 3 | +import argparse |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from tqdm import tqdm |
| 7 | +from termcolor import colored |
| 8 | + |
| 9 | +import jericho |
| 10 | + |
| 11 | + |
| 12 | +def parse_args(): |
| 13 | + parser = argparse.ArgumentParser() |
| 14 | + |
| 15 | + parser.add_argument("filenames", nargs="+", |
| 16 | + help="Path to a Z-Machine game(s).") |
| 17 | + parser.add_argument("--index", type=int, help="Index of the walkthrough command to investigate.") |
| 18 | + parser.add_argument("--interactive", action="store_true", |
| 19 | + help="Type the command.") |
| 20 | + parser.add_argument("--debug", action="store_true", |
| 21 | + help="Launch ipdb on FAIL.") |
| 22 | + parser.add_argument("-v", "--verbose", action="store_true", |
| 23 | + help="Print the last observation when not achieving max score.") |
| 24 | + parser.add_argument("-vv", "--very-verbose", action="store_true", |
| 25 | + help="Print the last observation when not achieving max score.") |
| 26 | + parser.add_argument("--check-state", action="store_true", |
| 27 | + help="Check if each command changes the state.") |
| 28 | + return parser.parse_args() |
| 29 | + |
| 30 | +args = parse_args() |
| 31 | + |
| 32 | + |
| 33 | +def get_zmp(env): |
| 34 | + zmp = env.get_state()[0] |
| 35 | + #start = zmp.view(">u2")[6] # ref: https://inform-fiction.org/zmachine/standards/z1point1/sect06.html#two |
| 36 | + #length = 240 * 2 # 240 2-byte global variables. |
| 37 | + #globals = zmp[start:start + length].view(">i2") |
| 38 | + return zmp |
| 39 | + |
| 40 | + |
| 41 | +def display_indices(indices): |
| 42 | + indices = sorted(indices) |
| 43 | + |
| 44 | + NB_COLS = 16 |
| 45 | + for row in range(int(np.ceil(len(indices) / NB_COLS))): |
| 46 | + for col in range(NB_COLS): |
| 47 | + idx = (row * NB_COLS) + col |
| 48 | + if idx >= len(indices): |
| 49 | + break |
| 50 | + |
| 51 | + text = f"{indices[idx]:6d}" |
| 52 | + print(text, end=",") |
| 53 | + |
| 54 | + print() |
| 55 | + |
| 56 | + |
| 57 | +def display_relevant(G, changed, relevant, noise): |
| 58 | + relevant = sorted(relevant) |
| 59 | + |
| 60 | + NB_COLS = 16 |
| 61 | + for row in range(int(np.ceil(len(relevant) / NB_COLS))): |
| 62 | + for col in range(NB_COLS): |
| 63 | + idx = (row * NB_COLS) + col |
| 64 | + if idx >= len(relevant): |
| 65 | + break |
| 66 | + |
| 67 | + color = 'white' |
| 68 | + # assert len(relevant & noise) == 0 |
| 69 | + #if idx in relevant: |
| 70 | + color = 'green' |
| 71 | + if idx in noise: |
| 72 | + color = 'red' |
| 73 | + |
| 74 | + bg_color = None |
| 75 | + attrs = [] |
| 76 | + if idx in changed: |
| 77 | + attrs.append("bold") |
| 78 | + bg_color = "on_grey" |
| 79 | + |
| 80 | + text = colored(f"{relevant[idx]:5x}", color, bg_color) |
| 81 | + print(text, end=",") |
| 82 | + |
| 83 | + print() |
| 84 | + |
| 85 | +def display_ram(G, changed, relevant, noise): |
| 86 | + NB_COLS = 150 |
| 87 | + for row in range(int(np.ceil(len(G) / NB_COLS))): |
| 88 | + for col in range(NB_COLS): |
| 89 | + idx = (row * NB_COLS) + col |
| 90 | + color = 'white' |
| 91 | + assert len(relevant & noise) == 0 |
| 92 | + if idx in relevant: |
| 93 | + color = 'green' |
| 94 | + if idx in noise: |
| 95 | + color = 'red' |
| 96 | + |
| 97 | + bg_color = None |
| 98 | + attrs = [] |
| 99 | + if idx in changed: |
| 100 | + attrs.append("bold") |
| 101 | + bg_color = "on_grey" |
| 102 | + |
| 103 | + text = colored(f".", color, bg_color) |
| 104 | + print(text, end="") |
| 105 | + |
| 106 | + print() |
| 107 | + |
| 108 | + |
| 109 | +def show_mem_diff(M0, M1, M2, start=24985, length=240*2): |
| 110 | + # TODO: start can be obtained from ZMP.view('>i2')[6] |
| 111 | + import numpy as np |
| 112 | + |
| 113 | + M1 = M1[start:start+length].view(">i2") |
| 114 | + M2 = M2[start:start+length].view(">i2") |
| 115 | + indices = np.nonzero(M1 != M2)[0] |
| 116 | + |
| 117 | + if M0 is not None: |
| 118 | + M0 = M0[start:start+length].view(">i2") |
| 119 | + to_print = sorted(set(indices) - set(np.nonzero(M0 != M1)[0])) |
| 120 | + to_print += [None] |
| 121 | + to_print += sorted(set(indices) & set(np.nonzero(M0 != M1)[0])) |
| 122 | + |
| 123 | + for i in to_print: |
| 124 | + if i is None: |
| 125 | + print("--") |
| 126 | + continue |
| 127 | + |
| 128 | + print(f"{i:3d} [{start+i*2:5d}]: {M1[i]:5d} -> {M2[i]:5d}") |
| 129 | + |
| 130 | + |
| 131 | +filename_max_length = max(map(len, args.filenames)) |
| 132 | +for filename in sorted(args.filenames): |
| 133 | + rom = os.path.basename(filename) |
| 134 | + print(filename.ljust(filename_max_length))#, end=" ") |
| 135 | + |
| 136 | + env = jericho.FrotzEnv(filename) |
| 137 | + if not env.is_fully_supported: |
| 138 | + print(colored("SKIP\tUnsupported game", 'yellow')) |
| 139 | + continue |
| 140 | + |
| 141 | + if "walkthrough" not in env.bindings: |
| 142 | + print(colored("SKIP\tMissing walkthrough", 'yellow')) |
| 143 | + continue |
| 144 | + |
| 145 | + env.reset() |
| 146 | + Z = get_zmp(env) |
| 147 | + |
| 148 | + history = [] |
| 149 | + history.append((0, 'reset', env.get_state(), Z)) |
| 150 | + |
| 151 | + walkthrough = env.get_walkthrough() |
| 152 | + cmd_no = 0 |
| 153 | + |
| 154 | + noise_indices = set() |
| 155 | + relevant_indices = set() |
| 156 | + changes_history = [] |
| 157 | + cpt = 0 |
| 158 | + while True: |
| 159 | + if cmd_no >= len(walkthrough): |
| 160 | + break |
| 161 | + |
| 162 | + cmd = walkthrough[cmd_no].lower() |
| 163 | + |
| 164 | + if args.interactive: |
| 165 | + manual_cmd = input(f"{cpt}. [{cmd}]> ") |
| 166 | + if manual_cmd.strip(): |
| 167 | + cmd = manual_cmd.lower() |
| 168 | + else: |
| 169 | + print(f"{cpt}. > {cmd}") |
| 170 | + |
| 171 | + if cmd == walkthrough[cmd_no].lower(): |
| 172 | + cmd_no += 1 |
| 173 | + |
| 174 | + last_env_objs = env.get_world_objects(clean=False) |
| 175 | + last_env_objs_cleaned = env.get_world_objects(clean=True) |
| 176 | + |
| 177 | + last_Z = Z |
| 178 | + obs, rew, done, info = env.step(cmd) |
| 179 | + cpt += 1 |
| 180 | + # print(">", cmd) |
| 181 | + print(obs) |
| 182 | + |
| 183 | + env_objs = env.get_world_objects(clean=False) |
| 184 | + env_objs_cleaned = env.get_world_objects(clean=True) |
| 185 | + |
| 186 | + Z = get_zmp(env) |
| 187 | + |
| 188 | + changes = set(np.nonzero(Z != last_Z)[0]) |
| 189 | + |
| 190 | + history.append((cmd, env.get_state())) |
| 191 | + changes_history.append(changes) |
| 192 | + |
| 193 | + # ans = "" |
| 194 | + # while ans not in {'y', 'n', 's'}: |
| 195 | + # ans = input("State has changed? [y/n/s]> ").lower() |
| 196 | + |
| 197 | + # if ans == "y": |
| 198 | + # relevant_indices |= changes - noise_indices |
| 199 | + # elif ans == "s": |
| 200 | + # pass |
| 201 | + # elif ans == "n": |
| 202 | + # noise_indices |= changes |
| 203 | + # relevant_indices -= noise_indices |
| 204 | + # else: |
| 205 | + # assert False |
| 206 | + |
| 207 | + #print(special_indices) |
| 208 | + |
| 209 | + # display_relevant(Z, changes, relevant_indices, noise_indices) |
| 210 | + |
| 211 | + def display_objs_diff(): |
| 212 | + print("Objects diff:") |
| 213 | + for o1, o2 in zip(last_env_objs, env_objs): |
| 214 | + if o1 != o2: |
| 215 | + print(colored(f"{o1}\n{o2}", "red")) |
| 216 | + |
| 217 | + print("Cleaned objects diff:") |
| 218 | + for o1, o2 in zip(last_env_objs_cleaned, env_objs_cleaned): |
| 219 | + if o1 != o2: |
| 220 | + print(colored(f"{o1}\n{o2}", "red")) |
| 221 | + |
| 222 | + # breakpoint() |
| 223 | + |
| 224 | + def search_unique_changes(idx): |
| 225 | + counter = {idx: 0 for idx in changes_history[idx]} |
| 226 | + matches = {idx: [] for idx in changes_history[idx]} |
| 227 | + for i, changes in enumerate(changes_history): |
| 228 | + for idx in changes: |
| 229 | + if idx in counter: |
| 230 | + counter[idx] += 1 |
| 231 | + matches[idx].append((i, history[i+1][0])) |
| 232 | + |
| 233 | + for idx, count in sorted(counter.items(), key=lambda e: e[::-1]): |
| 234 | + if matches[idx][0][0] == 0: |
| 235 | + continue |
| 236 | + |
| 237 | + print(f"{idx:6d}: {count:3d} : " + ", ".join(f"{i}.{cmd}" for i, cmd in matches[idx][:10])) |
| 238 | + |
| 239 | + #if len(changes_history) > : |
| 240 | + print(f"Ram changes unique to command: {args.index}. > {history[args.index+1][0]}") |
| 241 | + search_unique_changes(args.index) |
| 242 | + # display_indices(search_unique_changes(args.index)) |
| 243 | + if args.debug: |
| 244 | + breakpoint() |
0 commit comments