-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
50 lines (36 loc) · 1.64 KB
/
train.py
File metadata and controls
50 lines (36 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import cu__grid_cell.data_gen as dg
import cu__grid_cell.data_gen_threaded as thr
from cu__grid_cell.custom_loss import *
from cu__grid_cell.preparation import *
# Training entry script: build the model wrapper, load validation/training
# data, define the custom loss, prepare the model and run training.
model_obj = preparation()
# Per the original note, get_working_DIR() performs the preparation work.
# Its return value is kept in `working_path` (presumably the working
# directory -- NOTE(review): confirm against preparation()).
working_path = model_obj.get_working_DIR()  # make all prep. work!
config = model_obj.config

# Validation data: not shuffled, not augmented, batchsize='all'.
# Caution: this loads the whole validation dataset into RAM at once.
val_data_obj = dg.data_gen(
    dataset=config.CU_val_hdf5_path,
    shuffle=False,
    augment=False,
    config=config,
    batchsize='all',
    percentage_of_data=config.val_sample_percentage,
)
# Training data: generator from the data_gen_threaded module, with
# shuffling and augmentation enabled.
train_data_obj = thr.DataGenerator(
    dataset=config.CU_tr_hdf5_path,
    shuffle=True,
    augment=True,
    config=config,
    percentage_of_data=config.train_sample_percentage,
)

# Define the custom loss and prepare the model for training.
custom_los_obj = loss(config)
# Read the flag from the loaded config instance (like every other setting
# above) rather than from the `Config` name, which may not be in scope and
# would ignore any per-instance value. Instance lookup still falls back to
# a class-level attribute if one is defined there.
if config.splitted:
    model_obj.prep_for_training_splitted(config, train_data_obj, val_data_obj, custom_los_obj)
else:
    model_obj.prep_for_training(config, train_data_obj, val_data_obj, custom_los_obj)
model_obj.train()
#def validate(config, model, val_data, validation_steps, metrics_id, epoch):
# prediction = model.predict(val_data, batch_size=batch)
# # list all data in history
# print(history.history.keys())
# summarize history for accuracy
# plt.plot(history.history['loss'])
# plt.plot(history.history['val_loss'])
# plt.title('model accuracy')
# plt.ylabel('accuracy')
# plt.xlabel('epoch')
# plt.legend(['train', 'test'], loc='upper left')
# plt.show()
# summarize history for loss
# plt.plot(history.history['loss'])
# plt.plot(history.history['val_loss'])
# plt.title('model loss')
# plt.ylabel('loss')
# plt.xlabel('epoch')
# plt.legend(['train', 'test'], loc='upper left')
# plt.show()