Skip to content

Commit 712b319

Browse files
committed
feat(sweep): auto hyperparameter discovery
1 parent 44094b6 commit 712b319

File tree

3 files changed

+133
-47
lines changed

3 files changed

+133
-47
lines changed

exp/exp_main.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from pathlib import Path
22
import datetime
33
import warnings
4-
import yaml
54
import json
65
from collections import OrderedDict
76
from typing import Generator
8-
from dataclasses import asdict
97
import importlib
108

119
import numpy as np
@@ -195,37 +193,9 @@ def vali(
195193

196194
def train(self) -> None:
197195
logger.info('>>>>>>> training start <<<<<<<')
198-
# save training config file for reference
199196
path = Path(self.configs.checkpoints) / self.configs.dataset_name / self.configs.model_name / self.configs.model_id / f"{self.configs.seq_len}_{self.configs.pred_len}" / self.configs.subfolder_train / f"iter{self.configs.itr_i}"
200-
path.mkdir(parents=True, exist_ok=True)
201-
logger.info(f"Training iter{self.configs.itr_i} save to: {path}")
202-
with open(path / "configs.yaml", 'w', encoding='utf-8') as f:
203-
yaml.dump(asdict(self.configs), f, default_flow_style=False)
204-
205-
accelerator.project_configuration.set_directories(project_dir=path)
206-
207-
# init exp tracker
208197
if (self.configs.wandb and accelerator.is_main_process) or self.configs.sweep:
209198
import wandb
210-
run = wandb.init(
211-
# Set the project where this run will be logged
212-
project="YOUR_PROJECT_NAME",
213-
# Track hyperparameters and run metadata
214-
config={
215-
"model_name": self.configs.model_name,
216-
"model_id": self.configs.model_id,
217-
"dataset_name": self.configs.dataset_name,
218-
"seq_len": self.configs.seq_len,
219-
"pred_len": self.configs.pred_len,
220-
"learning_rate": self.configs.learning_rate,
221-
"batch_size": self.configs.batch_size
222-
},
223-
dir=path
224-
)
225-
if self.configs.sweep:
226-
# overwrite default configs by wandb.config when sweeping
227-
self.configs.learning_rate = wandb.config.learning_rate
228-
self.configs.batch_size = wandb.config.batch_size
229199

230200
train_data, train_loader = self._get_data(flag='train')
231201
vali_data, vali_loader = self._get_data(flag='val')
@@ -278,7 +248,7 @@ def train(self) -> None:
278248
outputs: dict[str, Tensor] = model_train(
279249
exp_stage="train",
280250
train_stage=train_stage,
281-
current_epoch=epoch
251+
current_epoch=epoch,
282252
**batch
283253
)
284254

main.py

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import random
2+
from pathlib import Path
23
import datetime
4+
import importlib
5+
import yaml
6+
from dataclasses import asdict
7+
import pprint
38

49
import torch
510
import numpy as np
@@ -8,6 +13,8 @@
813
from utils.globals import logger, accelerator
914
from utils.configs import configs
1015

16+
hyperparameters_sweep: dict[str, dict[str, list]] = {}
17+
1118
def main():
1219
# random seed
1320
fix_seed_list = range(2024, 2024 + configs.itr)
@@ -16,6 +23,41 @@ def main():
1623

1724
Exp = Exp_Main
1825

26+
def start_exp_train() -> Exp_Main:
27+
# save training config file for reference
28+
path = Path(configs.checkpoints) / configs.dataset_name / configs.model_name / configs.model_id / f"{configs.seq_len}_{configs.pred_len}" / configs.subfolder_train / f"iter{configs.itr_i}" # same as the one in Exp_Main.train()
29+
path.mkdir(parents=True, exist_ok=True)
30+
logger.info(f"Training iter{configs.itr_i} save to: {path}")
31+
with open(path / "configs.yaml", 'w', encoding='utf-8') as f:
32+
yaml.dump(asdict(configs), f, default_flow_style=False)
33+
# init exp tracker
34+
if (configs.wandb and accelerator.is_main_process) or configs.sweep:
35+
import wandb
36+
run = wandb.init(
37+
# Set the project where this run will be logged
38+
project="YOUR_PROJECT_NAME",
39+
# Track hyperparameters and run metadata
40+
config={
41+
"model_name": configs.model_name,
42+
"model_id": configs.model_id,
43+
"dataset_name": configs.dataset_name,
44+
"seq_len": configs.seq_len,
45+
"pred_len": configs.pred_len,
46+
"learning_rate": configs.learning_rate,
47+
"batch_size": configs.batch_size
48+
},
49+
dir=path
50+
)
51+
# overwrite model hyperparameters when sweeping
52+
for attribute_name in hyperparameters_sweep.keys():
53+
setattr(configs, attribute_name, getattr(wandb.config, attribute_name))
54+
55+
accelerator.project_configuration.set_directories(project_dir=path)
56+
57+
exp = Exp(configs)
58+
exp.train()
59+
return exp
60+
1961
if configs.sweep:
2062
'''
2163
Currently, wandb sweep with huggingface accelerate multi-GPU is tricky; use at your own risk.
@@ -42,12 +84,12 @@ def main():
4284
torch.manual_seed(fix_seed_list[configs.itr_i])
4385
np.random.seed(fix_seed_list[configs.itr_i])
4486

45-
exp = Exp(configs)
46-
47-
exp.train()
87+
exp = start_exp_train()
4888
exp.test()
49-
5089
elif configs.is_training:
90+
'''
91+
Normal train&test
92+
'''
5193
subfolder = datetime.datetime.now().strftime("%Y_%m%d_%H%M")
5294
configs.subfolder_train = subfolder
5395
for i in range(configs.itr):
@@ -57,12 +99,13 @@ def main():
5799
torch.manual_seed(fix_seed_list[i])
58100
np.random.seed(fix_seed_list[i])
59101

60-
exp = Exp(configs)
61-
exp.train()
62-
102+
exp = start_exp_train()
63103
torch.cuda.empty_cache()
64104
exp.test()
65105
else:
106+
'''
107+
test only
108+
'''
66109
exp = Exp(configs)
67110
exp.test()
68111
torch.cuda.empty_cache()
@@ -73,14 +116,40 @@ def main():
73116
if not configs.sweep:
74117
main()
75118
else:
119+
# first determine the hyperparameters actually accessed by model
120+
from utils.ExpConfigs import ExpConfigsTracker
121+
configs_tracker = ExpConfigsTracker(configs)
122+
model_module = importlib.import_module("models." + configs.model_name)
123+
model = model_module.Model(configs_tracker)
124+
del model
125+
accessed_configs: set[str] = configs_tracker.get_accessed_attributes()
126+
max_count = 1
127+
for accessed_config in accessed_configs:
128+
try:
129+
ref_values = configs.get_sweep_values(accessed_config)
130+
if ref_values and (type(ref_values) is list):
131+
hyperparameters_sweep[accessed_config] = {
132+
"values": ref_values
133+
}
134+
max_count *= len(ref_values)
135+
except Exception as e:
136+
continue
137+
138+
if hyperparameters_sweep == {}:
139+
logger.error(f"No hyperparameters to search, stopping now...")
140+
logger.debug(f"{configs.model_name} access these attributes in ExpConfigs:")
141+
configs_tracker.print_access_report()
142+
logger.debug("""Possible reasons: (1) The model does not access any hyperparameters in ExpConfigs; (2) The accessed hyperparameters have not set their metadata properly. Check the ExpConfigs class in utils/ExpConfigs.py. Example metadata setting:
143+
d_model: int = field(metadata={"sweep": [32, 64, 128, 256]})""")
144+
exit(0)
145+
else:
146+
logger.info(f"""{len(hyperparameters_sweep)} hyperparameters and {max_count} runs: \n{pprint.pformat(hyperparameters_sweep)}""")
147+
76148
import wandb
77149
sweep_configuration = {
78150
"method": "grid",
79-
"metric": {"goal": "minimize", "name": "loss_val"},
80-
"parameters": {
81-
"learning_rate": {"values": [0.01, 0.001, 0.0001, 0.00001]},
82-
"batch_size": {"values": [16, 32, 64, 128]},
83-
},
151+
"metric": {"goal": "minimize", "name": "loss_val_best"},
152+
"parameters": hyperparameters_sweep
84153
}
85154
temp_file_path = "storage/tmp.txt"
86155
if accelerator.is_main_process:
@@ -95,7 +164,7 @@ def main():
95164
sweep_id,
96165
function=main,
97166
project="YOUR_PROJECT_NAME",
98-
count=16
167+
count=max_count
99168
)
100169
except KeyboardInterrupt:
101170
if accelerator.is_main_process:

utils/ExpConfigs.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
2+
from typing import Any, Optional
23

34
@dataclass
45
class ExpConfigs:
@@ -7,6 +8,13 @@ class ExpConfigs:
78
89
Make sure to update this dataclass after adding new args in argparse
910
'''
11+
@classmethod
12+
def get_sweep_values(cls, attr_name: str) -> Optional[list]:
13+
for field_info in cls.__dataclass_fields__.values():
14+
if field_info.name == attr_name:
15+
return field_info.metadata.get('sweep')
16+
return None
17+
1018
# basic config
1119
task_name: str
1220
is_training: int
@@ -85,10 +93,10 @@ class ExpConfigs:
8593
enc_in: int
8694
dec_in: int
8795
c_out: int
88-
d_model: int
96+
d_model: int = field(metadata={"sweep": [32, 64, 128, 256]})
8997
d_timesteps: int
9098
n_heads: int
91-
n_layers: int
99+
n_layers: int = field(metadata={"sweep": [1, 2, 3, 4]})
92100
e_layers: int
93101
d_layers: int
94102
hidden_layers: int
@@ -155,3 +163,42 @@ class ExpConfigs:
155163
patch_len_max_irr: int | None = None # maximum number of observations along time dimension in a patch of x, set in irregular time series datasets
156164
subfolder_train: str = "" # timestamp of training in format %Y_%m%d_%H%M
157165
itr_i: int = 0 # current training iteration. [0, itr-1]
166+
167+
class ExpConfigsTracker:
168+
"""Wrapper that tracks which ExpConfigs attributes are accessed"""
169+
170+
def __init__(self, configs: ExpConfigs):
171+
object.__setattr__(self, '_config', configs)
172+
object.__setattr__(self, '_accessed_attrs', set())
173+
174+
def __getattr__(self, name: str) -> Any:
175+
if hasattr(self._config, name):
176+
self._accessed_attrs.add(name)
177+
return getattr(self._config, name)
178+
raise AttributeError(f"'{type(self._config).__name__}' object has no attribute '{name}'")
179+
180+
def __setattr__(self, name: str, value: Any) -> None:
181+
if name.startswith('_'):
182+
object.__setattr__(self, name, value)
183+
else:
184+
self._accessed_attrs.add(name)
185+
setattr(self._config, name, value)
186+
187+
def get_accessed_attributes(self) -> set[str]:
188+
"""Return set of accessed attribute names"""
189+
return self._accessed_attrs.copy()
190+
191+
def get_unused_attributes(self) -> set[str]:
192+
"""Return set of unused attribute names"""
193+
all_attrs = {field.name for field in self._config.__dataclass_fields__.values()}
194+
return all_attrs - self._accessed_attrs
195+
196+
def print_access_report(self):
197+
"""Print a report of accessed vs unused attributes"""
198+
accessed = self.get_accessed_attributes()
199+
unused = self.get_unused_attributes()
200+
201+
print("=== ExpConfigs Access Report ===")
202+
print(f"Accessed attributes ({len(accessed)}):")
203+
for attr in sorted(accessed):
204+
print(f" ✓ {attr}")

0 commit comments

Comments
 (0)