forked from jamestszhim/modals
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch.py
More file actions
94 lines (76 loc) · 2.71 KB
/
search.py
File metadata and controls
94 lines (76 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import random
import numpy as np
import ray
import ray.tune as tune
from modals.setup import create_hparams, create_parser
from modals.trainer import TextModelTrainer
from ray.tune.schedulers import PopulationBasedTraining
class RayModel(tune.Trainable):
    """Ray Tune trainable that wraps a ``TextModelTrainer`` for PBT search.

    Method names follow the (legacy) ``tune.Trainable`` callback API:
    Ray itself invokes ``_setup``/``_train``/``_save``/``_restore``.
    """

    def _setup(self, *args):
        # Build the underlying trainer from this trial's config dict.
        self.trainer = TextModelTrainer(self.config)

    def _train(self):
        # One PBT iteration: train + validate, then score the test split.
        print(f'Starting Ray Iteration: {self._iteration}')
        step = self._iteration
        acc_train, acc_valid = self.trainer.run_model(step)
        acc_test, _ = self.trainer._test(step, mode='test')
        return {
            'train_acc': acc_train,
            'valid_acc': acc_valid,
            'test_acc': acc_test,
        }

    def _save(self, checkpoint_dir):
        # Persist model state; Ray expects the checkpoint path back.
        print(checkpoint_dir)
        ckpt_path = self.trainer.save_model(checkpoint_dir, self._iteration)
        print(ckpt_path)
        return ckpt_path

    def _restore(self, checkpoint_path):
        # Reload model state from a PBT exploit/checkpoint.
        self.trainer.load_model(checkpoint_path)

    def reset_config(self, new_config):
        # Reuse this actor for a new config (enabled via reuse_actors=True).
        self.config = new_config
        self.trainer.reset_config(self.config)
        return True
def search():
    """Run a Population Based Training search over augmentation policies.

    Parses 'search'-mode command-line flags, builds the hyperparameter
    dict, then launches a Ray Tune PBT run of :class:`RayModel`, selecting
    on validation accuracy.
    """
    FLAGS = create_parser('search')
    hparams = create_hparams('search', FLAGS)

    def explore(config):
        """Custom explore function.

        Args:
            config: dictionary containing ray config params.

        Returns:
            Copy of config with modified augmentation policy.
        """
        new_params = []
        for param in config["hp_policy"]:
            if random.random() < 0.2:
                # With prob 0.2, resample the value uniformly in [0, 10].
                new_params.append(random.randint(0, 10))
            else:
                # Otherwise perturb by a uniform step in {0, 1, 2, 3}
                # (equivalent to the former np.random.choice with uniform p),
                # clamped to the valid [0, 10] range.
                amt = random.randint(0, 3)
                if random.random() < 0.5:
                    new_params.append(max(0, param - amt))
                else:
                    new_params.append(min(10, param + amt))
        config["hp_policy"] = new_params
        return config

    ray.init()
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="valid_acc",
        mode='max',
        perturbation_interval=FLAGS.perturbation_interval,
        custom_explore_fn=explore,
        log_config=True)
    tune.run(
        RayModel,
        name=hparams['ray_name'],
        scheduler=pbt,
        reuse_actors=True,
        verbose=True,
        checkpoint_score_attr="valid_acc",
        checkpoint_freq=FLAGS.checkpoint_freq,
        resources_per_trial={"gpu": FLAGS.gpu, "cpu": FLAGS.cpu},
        stop={"training_iteration": hparams['num_epochs']},
        config=hparams,
        local_dir=FLAGS.ray_dir,
        num_samples=FLAGS.num_samples
    )
# Script entry point: launch the PBT search when run directly.
if __name__ == "__main__":
    search()