-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdreamer_tree_search_eval.py
More file actions
341 lines (294 loc) · 14.1 KB
/
dreamer_tree_search_eval.py
File metadata and controls
341 lines (294 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""Module to run tree search/planning version of evaluation.
"""
import os
import re
import pathlib
import importlib
from pathlib import Path
from copy import deepcopy
from datetime import datetime
from typing import Tuple, Dict, Union
import yaml
import sheeprl
from sheeprl.cli import check_configs_evaluation
from sheeprl.utils.utils import dotdict
import hydra
import pandas as pd
from omegaconf import DictConfig, OmegaConf, open_dict
from lightning import Fabric
from dreamer_vfd.utils.eval import gather_ckpt_data
def parse_ckpt_seed(ckpt_filename: str) -> Tuple[int, int]:
    """Extract the checkpoint number and seed number from a checkpoint path.

    This assumes a certain file format and directory structure:
    ``{exp-name}/{seed_X}/version_0/checkpoint/{ckpt_X_0.ckpt}``.

    Args:
        ckpt_filename (str): name of the full checkpoint path.

    Returns:
        Tuple[int, int]: the checkpoint number and seed number. Either entry
        is ``None`` when its pattern is not found in the path.
    """
    ckpt_hit = re.search(r"ckpt_(\d+)_\d+", ckpt_filename)
    seed_hit = re.search(r"seed_(\d+)", ckpt_filename)
    ckpt_num = None if ckpt_hit is None else int(ckpt_hit.group(1))
    seed_num = None if seed_hit is None else int(seed_hit.group(1))
    return ckpt_num, seed_num
def dotdict_to_dict(dot_dict: dotdict) -> dict:
    """Recursively convert a sheeprl ``dotdict`` back into a plain ``dict``.

    Nested ``dotdict`` values are converted first and written back into the
    mapping, so the input is mutated in place before the top level is cast.
    """
    for key in dot_dict:
        value = dot_dict[key]
        if isinstance(value, dotdict):
            dot_dict[key] = dotdict_to_dict(value)
    return dict(dot_dict)
def skip(
    baseline_df: pd.DataFrame,
    episode_df: pd.DataFrame,
    step_df: pd.DataFrame,
    ckpt: int,
    seed: int,
    planning_horizon: int,
    child_nodes: int,
    num_episodes: int,
    value_bootstrap: bool,
    task: str,
    exp_name: str,
) -> bool:
    """Decide whether evaluation for this configuration can be skipped.

    Returns ``True`` only when the baseline, per-episode, and per-step
    dataframes all already contain at least ``num_episodes`` worth of results
    for the given checkpoint/seed/search settings; otherwise the evaluation
    still needs to run.
    """
    # Nothing recorded at all -> must run.
    if baseline_df.empty or episode_df.empty or step_df.empty:
        return False

    # Baseline results are keyed only by checkpoint/seed/experiment.
    baseline_rows = baseline_df[
        (baseline_df["Checkpoint"] == ckpt)
        & (baseline_df["Seed"] == seed)
        & (baseline_df["Experiment"] == exp_name)
    ]
    if baseline_rows.empty or baseline_rows.shape[0] < num_episodes:
        return False

    def _select(df: pd.DataFrame) -> pd.DataFrame:
        # Tree-search results are additionally keyed by the search settings.
        return df[
            (df["Checkpoint"] == ckpt)
            & (df["Seed"] == seed)
            & (df["Planning Horizon"] == planning_horizon)
            & (df["Child Nodes"] == child_nodes)
            & (df["Value Bootstrap"] == value_bootstrap)
            & (df["Task"] == task)
            & (df["Experiment"] == exp_name)
        ]

    episode_rows = _select(episode_df)
    if episode_rows.empty or episode_rows.shape[0] < num_episodes:
        return False

    # Step rows repeat per episode, so count distinct episode indices.
    step_rows = _select(step_df)
    if step_rows.empty or step_rows["Episode Index"].unique().shape[0] < num_episodes:
        return False

    return True
def task_dict_to_str(task: Union[Dict[str, float], str]) -> str:
    """Build a filename-friendly string identifier from a task specification.

    Args:
        task: either a plain string (returned unchanged, e.g. ``"none"``) or a
            mapping whose values are joined with underscores, e.g.
            ``{"x": 1.0, "y": 2.0}`` -> ``"1.0_2.0"``.

    Returns:
        str: the task identifier.
    """
    # Fix: return ANY string as-is. The original only special-cased "none",
    # so other strings crashed with AttributeError on .values() even though
    # the signature advertises Union[Dict, str].
    if isinstance(task, str):
        return task
    return "_".join(str(value) for value in task.values())
if __name__ == "__main__":
    # TODO: get values from cmd line args
    CONFIG_PATH = "dreamer_vfd/configs"
    CONFIG_NAME = "dreamer_vfd_tree_search_config"
    # get our dreamer_vfd dir where our configs dir is
    sheeprl_dir = pathlib.Path(sheeprl.__file__).parent.resolve()
    # add os env var so sheeprl will look in our configs dir as well
    os.environ["SHEEPRL_SEARCH_PATH"] = str(os.path.join(sheeprl_dir, "configs"))

    # define decorated evaluation function
    @hydra.main(version_base="1.13", config_path=CONFIG_PATH, config_name=CONFIG_NAME)
    def evaluation(cfg: DictConfig):
        """Copy of sheeprl's cli.py evaluation function v0.4.5.dev0.

        We modify evaluation to find multiples seeds per experiment and find all the checkpoints for
        the experiment. After this, it runs evaluation for some number of episodes for each seed and
        each checkpoint. The metrics are saved to a file in evaluation where it can be plotted later.

        Args:
            cfg (DictConfig): Hydra config found by decorator.
        """
        # set timestamp (shared across all runs launched in this invocation)
        timestamp = datetime.now().strftime("%m%d%y-%H%M%S")
        # get possible data options: explicit checkpoint paths, or a pickled
        # evaluation dataframe to mine checkpoints from
        ckpt_files = cfg.checkpoint_paths
        eval_data_path = cfg.exp_eval_file
        assert (
            ckpt_files != "none" or eval_data_path != "none"
        ), "checkpoint_path and exp_eval_file cannot both be 'none'."
        if ckpt_files == "none":
            # get ckpt file from eval data
            assert os.path.exists(eval_data_path), f"The experiment path {eval_data_path} doesn't exist."
            # load in eval data
            eval_dataframe = pd.read_pickle(eval_data_path)
            # make sure exp names resolve when the dataframe holds several experiments
            exp_name = cfg.experiment_name
            available_exp_names = eval_dataframe["Experiment"].unique()
            if len(available_exp_names) > 1:
                assert (
                    exp_name in available_exp_names
                ), f"Experiment '{exp_name}' is not in eval_dataframe. Available experiments are: {available_exp_names}"
                # make eval data only contain that experiment
                eval_dataframe = eval_dataframe[eval_dataframe["Experiment"] == exp_name]
            ckpt_data = gather_ckpt_data(eval_dataframe, cfg.checkpoint_selection)
        else:
            # extract checkpoint and seed number for each checkpoint, list of tuples(ckpt_num, seed_num)
            ckpt_nums_seeds = [parse_ckpt_seed(ckpt_file) for ckpt_file in ckpt_files]
            # convert str to Path
            ckpt_files = [Path(ckpt_file) for ckpt_file in ckpt_files]
            # consolidate into ckpt_data; manually supplied paths are tagged "manual"
            ckpt_data = []
            for ckpt_file, (ckpt_num, ckpt_seed) in zip(ckpt_files, ckpt_nums_seeds):
                ckpt_data.append((ckpt_file, ckpt_num, ckpt_seed, "manual"))
        # load the first ckpt config to get the first algo name and env name
        ckpt_file = ckpt_data[0][0]
        # NOTE(review): these prefix rewrites appear to remap container paths to a
        # storage mount — confirm they match the current deployment layout
        ckpt_file = Path(ckpt_file.as_posix().replace("/opt/project/logs/runs", "/data/petabyte/dreamer_vfd"))
        ckpt_file = Path(ckpt_file.as_posix().replace("/opt/home", "/data/petabyte"))
        # hydra writes config.yaml three directories above the .ckpt file
        ckpt_cfg = OmegaConf.load(ckpt_file.parent.parent.parent / ".hydra" / "config.yaml")
        # configure and make output folder
        output_folder = cfg.get("output_folder", "./tree_search_eval")
        output_folder = Path(output_folder) / ckpt_cfg.algo.name / ckpt_cfg.env.id
        os.makedirs(output_folder, exist_ok=True)
        cfg.output_folder = output_folder.as_posix()
        # set up folder to write configs to
        config_output = output_folder / "configs"
        os.makedirs(config_output, exist_ok=True)
        ### ITERATE OVER CKPT_DATA and run eval
        for (ckpt_file, ckpt_num, seed, selection) in ckpt_data:
            # make sure ckpt exists
            # replace /opt/home with /data/petabyte
            ckpt_file = Path(ckpt_file.as_posix().replace("/opt/home", "/data/petabyte"))
            ckpt_file = Path(ckpt_file.as_posix().replace("/opt/project/logs/runs", "/data/petabyte/dreamer_vfd"))
            assert os.path.exists(ckpt_file), f"Checkpoint {ckpt_file}, does not exist."
            # Load the checkpoint configuration
            ckpt_cfg = OmegaConf.load(ckpt_file.parent.parent.parent / ".hydra" / "config.yaml")
            # temporarily unlock the struct config so new keys can be written
            with open_dict(cfg):
                cfg.checkpoint_path = ckpt_file.as_posix()
                capture_video = getattr(cfg.env, "capture_video", False)
                # evaluation always runs a single environment
                cfg.env = {"capture_video": capture_video, "num_envs": 1}
                cfg.exp = {}
                cfg.algo = {}
                # force a single-device fabric setup for evaluation
                cfg.fabric = {
                    "devices": 1,
                    "num_nodes": 1,
                    "strategy": "auto",
                    "accelerator": getattr(cfg.fabric, "accelerator", "auto"),
                }
            # Merge configs (eval cfg values override the checkpoint's config)
            ckpt_cfg.merge_with(cfg)
            # Update values after merge
            ckpt_cfg.ckpt_num = int(ckpt_num)
            ckpt_cfg.last_ckpt = selection == "last_ckpt"
            ckpt_cfg.selection = selection
            # Check the validity of the configuration and run the evaluation
            check_configs_evaluation(ckpt_cfg)
            ckpt_cfg = dotdict(OmegaConf.to_container(ckpt_cfg, resolve=True, throw_on_missing=True))
            # TODO: change the number of devices when FSDP will be supported
            accelerator = ckpt_cfg.fabric.get("accelerator", "auto")
            fabric: Fabric = hydra.utils.instantiate(
                ckpt_cfg.fabric,
                accelerator=accelerator,
                devices=1,
                num_nodes=1,
                _convert_="all",
            )
            # set the seed for the trial
            fabric.seed_everything(cfg.env_seed)
            # Load the checkpoint
            state = fabric.load(ckpt_cfg.checkpoint_path)
            # iterate over what is in horizon_node_dict
            # (the baseline only needs to run once per checkpoint/seed)
            run_baseline = True
            for planning_horizon in cfg.horizon_node_dict:
                for max_child_nodes in cfg.horizon_node_dict[planning_horizon]:
                    # set vars in config
                    ckpt_cfg.timestamp = timestamp
                    ckpt_cfg.planning_horizon = planning_horizon
                    ckpt_cfg.max_child_nodes = max_child_nodes
                    run_name = (
                        output_folder
                        / f"evaluation_ckpt_{ckpt_num}_seed_{seed}_ph_{planning_horizon}_nodes_{max_child_nodes}"
                    )
                    ckpt_cfg.run_name = run_name.as_posix()
                    ckpt_cfg.task_str = task_dict_to_str(cfg.task)
                    # set up df paths (baseline is shared; episode/step dfs are per-setting)
                    os.makedirs(output_folder / "dfs", exist_ok=True)
                    baseline_df_path = output_folder / "dfs" / "baseline_df.pkl"
                    ep_df_path = (
                        output_folder
                        / "dfs"
                        / f"edf_vb_{cfg.value_bootstrap}_hor_{planning_horizon}_nodes_{max_child_nodes}_task_{ckpt_cfg.task_str}.pkl"
                    )
                    step_df_path = (
                        output_folder
                        / "dfs"
                        / f"sdf_vb_{cfg.value_bootstrap}_hor_{planning_horizon}_nodes_{max_child_nodes}_task_{ckpt_cfg.task_str}.pkl"
                    )
                    ckpt_cfg.baseline_df_path = baseline_df_path.as_posix()
                    ckpt_cfg.episode_df_path = ep_df_path.as_posix()
                    ckpt_cfg.step_df_path = step_df_path.as_posix()
                    # load df and get ckpt/seed version of it to see if data exists in horizon/node loop
                    baseline_dataframes = []
                    episode_dataframes = []
                    step_dataframes = []
                    if baseline_df_path.exists():
                        existing_bdf = pd.read_pickle(baseline_df_path)
                        baseline_dataframes.append(existing_bdf)
                    else:
                        existing_bdf = pd.DataFrame()
                    if ep_df_path.exists():
                        existing_edf = pd.read_pickle(ep_df_path)
                        episode_dataframes.append(existing_edf)
                    else:
                        existing_edf = pd.DataFrame()
                    if step_df_path.exists():
                        existing_sdf = pd.read_pickle(step_df_path)
                        step_dataframes.append(existing_sdf)
                    else:
                        existing_sdf = pd.DataFrame()
                    # see if we have this data; skipping implies the baseline exists too
                    if skip(
                        existing_bdf,
                        existing_edf,
                        existing_sdf,
                        ckpt_num,
                        seed,
                        planning_horizon,
                        max_child_nodes,
                        cfg.num_episodes,
                        cfg.value_bootstrap,
                        ckpt_cfg.task_str,
                        ckpt_cfg.exp_name,
                    ):
                        run_baseline = False
                        continue
                    print(
                        f"Running Tree Search for planning horizon {planning_horizon} and child nodes {max_child_nodes}."
                    )
                    # write final config to output folder for reproducibility
                    config_output_file = (
                        config_output
                        / f"config_vb_{cfg.value_bootstrap}_hor_{planning_horizon}_nodes_{max_child_nodes}_task_{ckpt_cfg.task_str}.yaml"
                    )
                    with open(config_output_file, "w") as file:  # pylint: disable=W1514
                        dict_config = dotdict_to_dict(deepcopy(ckpt_cfg))
                        yaml.dump(dict_config, file)
                    # NOTE(review): presumably imported lazily on purpose (e.g. to
                    # avoid import cycles or heavy startup cost) — confirm
                    task = importlib.import_module("dreamer_vfd.dreamer_v3_eval.evaluate")
                    command = task.__dict__["evaluate"]
                    baseline_df, ep_df, step_df = fabric.launch(command, ckpt_cfg, state, True, run_baseline)
                    run_baseline = False
                    # persist results, appending to any previously saved dataframes
                    if not baseline_df.empty:
                        baseline_dataframes.append(baseline_df)
                        bdf = pd.concat(baseline_dataframes, ignore_index=True)
                        bdf.to_pickle(baseline_df_path)
                    episode_dataframes.append(ep_df)
                    edf = pd.concat(episode_dataframes, ignore_index=True)
                    edf.to_pickle(ep_df_path)
                    step_dataframes.append(step_df)
                    sdf = pd.concat(step_dataframes, ignore_index=True)
                    sdf.to_pickle(step_df_path)

    # Call high level evaluation function
    evaluation()  # pylint: disable=E1120