@@ -109,7 +109,7 @@ void activity_classifier::save_impl(oarchive& oarc) const {
109
109
variant_deep_save (state, oarc);
110
110
111
111
// Save neural net weights.
112
- oarc << nn_spec_ ->export_params_view ();
112
+ oarc << read_model_spec () ->export_params_view ();
113
113
}
114
114
115
115
void activity_classifier::load_version (iarchive& iarc, size_t version) {
@@ -122,6 +122,7 @@ void activity_classifier::load_version(iarchive& iarc, size_t version) {
122
122
bool use_random_init = false ;
123
123
nn_spec_ = init_model (use_random_init);
124
124
nn_spec_->update_params (nn_params);
125
+ nn_spec_synchronized_ = true ;
125
126
}
126
127
127
128
void activity_classifier::init_options (
@@ -294,11 +295,10 @@ std::tuple<float, float> activity_classifier::compute_validation_metrics(
294
295
return std::make_tuple (average_val_accuracy, average_val_loss);
295
296
}
296
297
297
- void activity_classifier::init_table_printer (bool has_validation,
298
- bool show_loss) {
298
+ void activity_classifier::init_table_printer (bool has_validation) {
299
299
if (read_state<bool >(" verbose" )) {
300
300
if (has_validation) {
301
- if (show_loss ) {
301
+ if (show_loss_ ) {
302
302
training_table_printer_.reset (
303
303
new table_printer ({{" Iteration" , 12 },
304
304
{" Train Accuracy" , 12 },
@@ -315,7 +315,7 @@ void activity_classifier::init_table_printer(bool has_validation,
315
315
{" Elapsed Time" , 12 }}));
316
316
}
317
317
} else {
318
- if (show_loss ) {
318
+ if (show_loss_ ) {
319
319
training_table_printer_.reset (
320
320
new table_printer ({{" Iteration" , 12 },
321
321
{" Train Accuracy" , 12 },
@@ -338,38 +338,55 @@ void activity_classifier::train(
338
338
turi::timer time_object;
339
339
time_object.start ();
340
340
341
- bool show_loss = true ;
342
- auto show_loss_it = opts.find (" _show_loss" );
343
- if (show_loss_it != opts.end ()) {
344
- show_loss = show_loss_it->second ;
345
- }
346
-
347
341
// Instantiate the training dependencies: data iterator, compute context,
348
342
// backend NN model.
349
- init_train (data, target_column_name, session_id_column_name, validation_data ,
350
- opts);
343
+ init_training (data, target_column_name, session_id_column_name,
344
+ validation_data, opts);
351
345
352
346
// Perform all the iterations at once.
353
347
flex_int max_iterations = read_state<flex_int>(" max_iterations" );
354
348
while (read_state<flex_int>(" training_iterations" ) < max_iterations) {
355
- perform_training_iteration (show_loss);
349
+ iterate_training ();
350
+ }
351
+
352
+ finalize_training ();
353
+
354
+ variant_map_type state_update;
355
+ state_update[" training_time" ] = time_object.current_time ();
356
+ add_or_update_state (state_update);
357
+
358
+ logprogress_stream << " Training complete" << std::endl;
359
+ logprogress_stream << " Total Time Spent: "
360
+ << read_state<flex_float>(" training_time" ) << std::endl;
361
+ }
362
+
363
+ // iterate_training() performs a complete epoch, synchronizing with the GPU. As
364
+ // a result, no explicit synchronization is needed. We expose this method just
365
+ // for consistency with other models, like object_detector.
366
+ void activity_classifier::synchronize_training () {}
367
+
368
+ const model_spec* activity_classifier::read_model_spec () const {
369
+ if (training_model_ && !nn_spec_synchronized_) {
370
+ float_array_map trained_weights = training_model_->export_weights ();
371
+ nn_spec_->update_params (trained_weights);
372
+ nn_spec_synchronized_ = true ;
356
373
}
374
+ return nn_spec_.get ();
375
+ }
357
376
377
+ void activity_classifier::finalize_training () {
358
378
// Finish printing progress.
359
379
if (training_table_printer_) {
360
380
training_table_printer_->print_footer ();
361
381
training_table_printer_.reset ();
362
382
}
363
383
364
- // Sync trained weights to our local storage of the NN weights.
365
- float_array_map trained_weights = training_model_->export_weights ();
366
- nn_spec_->update_params (trained_weights);
367
-
368
384
variant_map_type state_update;
369
385
370
386
// Update the state with recall, precision and confusion matrix for training
371
387
// data
372
388
gl_sarray train_predictions = predict (training_data_, " probability_vector" );
389
+ flex_string target_column_name = read_state<flex_string>(" target" );
373
390
variant_map_type train_metric = evaluation::compute_classifier_metrics (
374
391
training_data_, target_column_name, " report" , train_predictions,
375
392
{{" classes" , read_state<flex_list>(" classes" )}});
@@ -392,13 +409,7 @@ void activity_classifier::train(
392
409
}
393
410
394
411
state_update[" verbose" ] = read_state<bool >(" verbose" );
395
- state_update[" num_examples" ] = data.size ();
396
- state_update[" training_time" ] = time_object.current_time ();
397
-
398
412
add_or_update_state (state_update);
399
- logprogress_stream << " Training complete" << std::endl;
400
- logprogress_stream << " Total Time Spent: " << read_state<flex_float>(" training_time" ) << std::endl;
401
-
402
413
}
403
414
404
415
gl_sarray activity_classifier::predict (gl_sframe data,
@@ -411,7 +422,8 @@ gl_sarray activity_classifier::predict(gl_sframe data,
411
422
412
423
// Bind the data to a data iterator.
413
424
std::unique_ptr<data_iterator> data_it =
414
- create_iterator (data, /* requires_labels */ false , /* is_train */ false ,
425
+ create_iterator (data, /* requires_labels */ false ,
426
+ /* infer_class_labels */ false , /* is_train */ false ,
415
427
/* use_data_augmentation */ false );
416
428
417
429
// Accumulate the class probabilities for each prediction window.
@@ -456,7 +468,8 @@ gl_sframe activity_classifier::predict_per_window(gl_sframe data,
456
468
457
469
// Bind the data to a data iterator.
458
470
std::unique_ptr<data_iterator> data_it =
459
- create_iterator (data, /* requires_labels */ false , /* is_train */ false ,
471
+ create_iterator (data, /* requires_labels */ false ,
472
+ /* infer_class_labels */ false , /* is_train */ false ,
460
473
/* use_data_augmentation */ false );
461
474
462
475
// Accumulate the class probabilities for each prediction window.
@@ -496,7 +509,8 @@ gl_sframe activity_classifier::classify(gl_sframe data,
496
509
497
510
// perform inference
498
511
std::unique_ptr<data_iterator> data_it =
499
- create_iterator (data, /* requires_labels */ false , /* is_train */ false ,
512
+ create_iterator (data, /* requires_labels */ false ,
513
+ /* infer_class_labels */ false , /* is_train */ false ,
500
514
/* use_data_augmentation */ false );
501
515
gl_sframe raw_preds_per_window = perform_inference (data_it.get ());
502
516
@@ -577,7 +591,8 @@ gl_sframe activity_classifier::predict_topk(gl_sframe data,
577
591
578
592
// data inference
579
593
std::unique_ptr<data_iterator> data_it =
580
- create_iterator (data, /* requires_labels */ false , /* is_train */ false ,
594
+ create_iterator (data, /* requires_labels */ false ,
595
+ /* infer_class_labels */ false , /* is_train */ false ,
581
596
/* use_data_augmentation */ false );
582
597
gl_sframe raw_preds_per_window = perform_inference (data_it.get ());
583
598
@@ -704,12 +719,9 @@ std::shared_ptr<MLModelWrapper> activity_classifier::export_to_coreml(
704
719
{
705
720
std::shared_ptr<MLModelWrapper> model_wrapper =
706
721
export_activity_classifier_model (
707
- *nn_spec_,
708
- read_state<flex_int>(" prediction_window" ),
709
- read_state<flex_list>(" features" ),
710
- LSTM_HIDDEN_SIZE,
711
- read_state<flex_list>(" classes" ),
712
- read_state<flex_string>(" target" ));
722
+ *read_model_spec (), read_state<flex_int>(" prediction_window" ),
723
+ read_state<flex_list>(" features" ), LSTM_HIDDEN_SIZE,
724
+ read_state<flex_list>(" classes" ), read_state<flex_string>(" target" ));
713
725
714
726
const flex_list& features_list = read_state<flex_list>(" features" );
715
727
const flex_string features_string =
@@ -839,16 +851,17 @@ void activity_classifier::import_from_custom_model(
839
851
bool use_random_init = false ;
840
852
nn_spec_ = init_model (use_random_init);
841
853
nn_spec_->update_params (nn_params);
854
+ nn_spec_synchronized_ = true ;
842
855
model_data.erase (model_iter);
843
856
}
844
857
845
858
std::unique_ptr<data_iterator> activity_classifier::create_iterator (
846
- gl_sframe data, bool requires_labels, bool is_train ,
847
- bool use_data_augmentation) const {
859
+ gl_sframe data, bool requires_labels, bool infer_class_labels ,
860
+ bool is_train, bool use_data_augmentation) const {
848
861
data_iterator::parameters data_params;
849
862
data_params.data = std::move (data);
850
863
851
- if (!is_train ) {
864
+ if (!infer_class_labels ) {
852
865
data_params.class_labels = read_state<flex_list>(" classes" );
853
866
}
854
867
@@ -1020,11 +1033,10 @@ activity_classifier::init_data(gl_sframe data, variant_type validation_data,
1020
1033
return std::make_tuple (train_data,val_data);
1021
1034
}
1022
1035
1023
- void activity_classifier::init_train (
1036
+ void activity_classifier::init_training (
1024
1037
gl_sframe data, std::string target_column_name,
1025
1038
std::string session_id_column_name, variant_type validation_data,
1026
- std::map<std::string, flexible_type> opts)
1027
- {
1039
+ std::map<std::string, flexible_type> opts) {
1028
1040
// Extract feature names from options.
1029
1041
std::vector<std::string> feature_column_names;
1030
1042
auto features_it = opts.find (" features" );
@@ -1037,10 +1049,9 @@ void activity_classifier::init_train(
1037
1049
opts.erase (features_it);
1038
1050
}
1039
1051
1040
- bool show_loss = true ;
1041
1052
auto show_loss_it = opts.find (" _show_loss" );
1042
1053
if (show_loss_it != opts.end ()) {
1043
- show_loss = show_loss_it->second ;
1054
+ show_loss_ = show_loss_it->second ;
1044
1055
opts.erase (show_loss_it);
1045
1056
}
1046
1057
@@ -1059,7 +1070,7 @@ void activity_classifier::init_train(
1059
1070
init_data (data, validation_data, session_id_column_name);
1060
1071
1061
1072
// Begin printing progress.
1062
- init_table_printer (!validation_data_.empty (), show_loss );
1073
+ init_table_printer (!validation_data_.empty ());
1063
1074
1064
1075
add_or_update_state ({{" session_id" , session_id_column_name},
1065
1076
{" target" , target_column_name},
@@ -1070,15 +1081,17 @@ void activity_classifier::init_train(
1070
1081
bool use_data_augmentation = read_state<bool >(" use_data_augmentation" );
1071
1082
training_data_iterator_ =
1072
1083
create_iterator (training_data_, /* requires_labels */ true ,
1073
- /* is_train */ true , use_data_augmentation);
1084
+ /* infer_class_labels */ true , /* is_train */ true ,
1085
+ use_data_augmentation);
1074
1086
1075
1087
add_or_update_state ({{" classes" , training_data_iterator_->class_labels ()}});
1076
1088
1077
1089
// Bind the validation data to a data iterator.
1078
1090
if (!validation_data_.empty ()) {
1079
- validation_data_iterator_ = create_iterator (
1080
- validation_data_, /* requires_labels */ true , /* is_train */ false ,
1081
- /* use_data_augmentation */ false );
1091
+ validation_data_iterator_ =
1092
+ create_iterator (validation_data_, /* requires_labels */ true ,
1093
+ /* infer_class_labels */ false , /* is_train */ false ,
1094
+ /* use_data_augmentation */ false );
1082
1095
} else {
1083
1096
validation_data_iterator_ = nullptr ;
1084
1097
}
@@ -1097,6 +1110,7 @@ void activity_classifier::init_train(
1097
1110
add_or_update_state ({
1098
1111
{" features" , training_data_iterator_->feature_names ()},
1099
1112
{" num_classes" , training_data_iterator_->class_labels ().size ()},
1113
+ {" num_examples" , training_data_.size ()},
1100
1114
{" num_features" , training_data_iterator_->feature_names ().size ()},
1101
1115
{" num_sessions" , training_data_iterator_->num_sessions ()},
1102
1116
{" training_iterations" , 0 },
@@ -1106,6 +1120,7 @@ void activity_classifier::init_train(
1106
1120
// the data iterator.
1107
1121
bool use_random_init = true ;
1108
1122
nn_spec_ = init_model (use_random_init);
1123
+ nn_spec_synchronized_ = true ;
1109
1124
1110
1125
// Defining the struct for ac parameters
1111
1126
ac_parameters ac_params;
@@ -1116,7 +1131,7 @@ void activity_classifier::init_train(
1116
1131
ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
1117
1132
ac_params.random_seed = read_state<int >(" random_seed" );
1118
1133
ac_params.is_training = true ;
1119
- ac_params.weights = nn_spec_ ->export_params_view ();
1134
+ ac_params.weights = read_model_spec () ->export_params_view ();
1120
1135
1121
1136
// Instantiate the NN backend.
1122
1137
training_model_ =
@@ -1128,11 +1143,72 @@ void activity_classifier::init_train(
1128
1143
}
1129
1144
}
1130
1145
1131
- void activity_classifier::perform_training_iteration (bool show_loss) {
1146
+ void activity_classifier::resume_training (gl_sframe data,
1147
+ variant_type validation_data) {
1148
+ // Perform validation split if necessary.
1149
+ flex_string session_id_column_name = read_state<flex_string>(" session_id" );
1150
+ std::tie (training_data_, validation_data_) =
1151
+ init_data (data, validation_data, session_id_column_name);
1152
+
1153
+ // Begin printing progress.
1154
+ init_table_printer (!validation_data_.empty ());
1155
+
1156
+ // Bind the data to a data iterator.
1157
+ bool use_data_augmentation = read_state<bool >(" use_data_augmentation" );
1158
+ training_data_iterator_ =
1159
+ create_iterator (training_data_, /* requires_labels */ true ,
1160
+ /* infer_class_labels */ false , /* is_train */ true ,
1161
+ use_data_augmentation);
1162
+
1163
+ // Bind the validation data to a data iterator.
1164
+ if (!validation_data_.empty ()) {
1165
+ validation_data_iterator_ =
1166
+ create_iterator (validation_data_, /* requires_labels */ true ,
1167
+ /* infer_class_labels */ false , /* is_train */ false ,
1168
+ /* use_data_augmentation */ false );
1169
+ } else {
1170
+ validation_data_iterator_ = nullptr ;
1171
+ }
1172
+
1173
+ // Instantiate the compute context.
1174
+ training_compute_context_ = create_compute_context ();
1175
+ if (training_compute_context_ == nullptr ) {
1176
+ log_and_throw (" No neural network compute context provided" );
1177
+ }
1178
+
1179
+ // Report to the user what GPU(s) is being used.
1180
+ std::vector<std::string> gpu_names = training_compute_context_->gpu_names ();
1181
+ print_training_device (gpu_names);
1182
+
1183
+ // Defining the struct for ac parameters
1184
+ ac_parameters ac_params;
1185
+ ac_params.batch_size = read_state<int >(" batch_size" );
1186
+ ac_params.num_features = read_state<int >(" num_features" );
1187
+ ac_params.prediction_window = read_state<int >(" prediction_window" );
1188
+ ac_params.num_classes = read_state<int >(" num_classes" );
1189
+ ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
1190
+ ac_params.random_seed = read_state<int >(" random_seed" );
1191
+ ac_params.is_training = true ;
1192
+ ac_params.weights = read_model_spec ()->export_params_view ();
1193
+
1194
+ // Instantiate the NN backend.
1195
+ training_model_ =
1196
+ training_compute_context_->create_activity_classifier (ac_params);
1197
+
1198
+ // Print the header last, after any logging triggered by initialization above.
1199
+ if (training_table_printer_) {
1200
+ training_table_printer_->print_header ();
1201
+ }
1202
+ }
1203
+
1204
+ void activity_classifier::iterate_training () {
1132
1205
// Training must have been initialized.
1133
1206
ASSERT_TRUE (training_data_iterator_ != nullptr );
1134
1207
ASSERT_TRUE (training_model_ != nullptr );
1135
1208
1209
+ // Invalidate any local copy of the model.
1210
+ nn_spec_synchronized_ = false ;
1211
+
1136
1212
const size_t batch_size = read_state<flex_int>(" batch_size" );
1137
1213
const size_t iteration_idx = read_state<flex_int>(" training_iterations" );
1138
1214
@@ -1224,7 +1300,7 @@ void activity_classifier::perform_training_iteration(bool show_loss) {
1224
1300
1225
1301
if (training_table_printer_) {
1226
1302
if (validation_data_iterator_) {
1227
- if (show_loss ) {
1303
+ if (show_loss_ ) {
1228
1304
training_table_printer_->print_progress_row (
1229
1305
iteration_idx, iteration_idx + 1 , average_batch_accuracy,
1230
1306
average_batch_loss, average_val_accuracy, average_val_loss,
@@ -1235,7 +1311,7 @@ void activity_classifier::perform_training_iteration(bool show_loss) {
1235
1311
average_val_accuracy, progress_time ());
1236
1312
}
1237
1313
} else {
1238
- if (show_loss ) {
1314
+ if (show_loss_ ) {
1239
1315
training_table_printer_->print_progress_row (
1240
1316
iteration_idx, iteration_idx + 1 , average_batch_accuracy,
1241
1317
average_batch_loss, progress_time ());
@@ -1273,7 +1349,7 @@ gl_sframe activity_classifier::perform_inference(data_iterator *data) const {
1273
1349
ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
1274
1350
ac_params.random_seed = read_state<int >(" random_seed" );
1275
1351
ac_params.is_training = false ;
1276
- ac_params.weights = nn_spec_ ->export_params_view ();
1352
+ ac_params.weights = read_model_spec () ->export_params_view ();
1277
1353
1278
1354
// Initialize the NN backend.
1279
1355
std::unique_ptr<compute_context> ctx = create_compute_context ();
0 commit comments