-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_utils.py
More file actions
59 lines (43 loc) · 1.88 KB
/
train_utils.py
File metadata and controls
59 lines (43 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pickle
import os
import numpy as np
import bz2
from utils_load import load_file
def sample_weights_and_metrics(evolved_weights, run_metrics, sample_epochs,
                               *, n_metrics=4):
    """Take the leading weight snapshots and metric rows of a run.

    Only ``len(sample_epochs)`` matters: the first
    ``min(len(evolved_weights), len(sample_epochs))`` entries of
    ``evolved_weights`` and ``run_metrics`` are copied positionally.

    NOTE(review): the epoch values inside ``sample_epochs`` are not used as
    indices -- confirm that ``evolved_weights`` / ``run_metrics`` are already
    stored only at the sampled epochs.

    Args:
        evolved_weights: Indexable sequence of weight snapshots.
        run_metrics: Indexable sequence of metric rows; each row must be
            assignable into a NumPy row of length ``n_metrics``.
        sample_epochs: Sized, iterable collection whose length sets the
            number of output metric rows.
        n_metrics: Width of each metric row (default 4, the value that was
            previously hard-coded).

    Returns:
        ``(sample_weights, sample_metrics)``: a list of at most
        ``len(sample_epochs)`` snapshots, and a float array of shape
        ``(len(sample_epochs), n_metrics)``. When there are fewer snapshots
        than sample epochs, the trailing metric rows are left as zeros.

    Raises:
        IndexError: if ``run_metrics`` is shorter than the number of
            snapshots being kept.
    """
    # Stop at whichever runs out first: the snapshots or the requested epochs.
    n_kept = min(len(evolved_weights), len(sample_epochs))
    sample_weights = []
    sample_metrics = np.zeros((len(sample_epochs), n_metrics))
    for i in range(n_kept):
        sample_weights.append(evolved_weights[i])
        sample_metrics[i] = run_metrics[i]
    return sample_weights, sample_metrics
def load_set_trained(fname, sample_epochs, use_all=True) -> dict:
    """Load a pretrained SET results file, optionally subsampling each run.

    Args:
        fname: Path to the results file; it is read with ``load_file``.
        sample_epochs: Passed to ``sample_weights_and_metrics`` for every run;
            only consulted when ``use_all`` is False.
        use_all: When True, return the loaded object unchanged.

    Returns:
        The loaded object itself when ``use_all`` is True. Otherwise a new
        dict with ``'density_levels'`` and ``'runs'``; each run entry keeps
        ``'set_sparsity'`` and a reduced ``'run'`` dict (``run_id``,
        ``set_params``, ``training_time`` plus the sampled
        ``evolved_weights`` / ``set_metrics``). Returns ``{}`` when ``fname``
        is not an existing file or ``load_file`` raises ``EOFError``.
    """
    if not os.path.isfile(fname):
        print(f"FILE: {fname} already processed or non-existent -> skipping")
        return {}

    # NOTE(review): this module imports pickle/bz2, so load_file presumably
    # unpickles -- only load result files from trusted sources.
    try:
        set_pretrained = load_file(fname)
    except EOFError:
        # EOFError from load_file is treated as a malformed (presumably
        # truncated) file: report it and skip rather than abort the caller.
        print(f"FILE malformed: {fname} ")
        return {}

    if use_all:
        return set_pretrained

    # Read density_levels first, so a missing key fails at the same point
    # as before.
    density_levels = set_pretrained['density_levels']
    runs = []
    for run in set_pretrained['runs']:
        old_run = run['run']
        new_run = {
            'run_id': old_run['run_id'],
            'set_params': old_run['set_params'],
            'training_time': old_run['training_time'],
        }
        new_run['evolved_weights'], new_run['set_metrics'] = sample_weights_and_metrics(
            old_run['evolved_weights'], old_run['set_metrics'], sample_epochs)
        runs.append({'set_sparsity': run['set_sparsity'], 'run': new_run})

    return {'density_levels': density_levels, 'runs': runs}