Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 82208b3

Browse files
authored
Begin adding checkpointing support for Activity Classification (#3030)
1 parent 10bb08b commit 82208b3

File tree

3 files changed

+168
-65
lines changed

3 files changed

+168
-65
lines changed

src/toolkits/activity_classification/activity_classifier.cpp

Lines changed: 128 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void activity_classifier::save_impl(oarchive& oarc) const {
109109
variant_deep_save(state, oarc);
110110

111111
// Save neural net weights.
112-
oarc << nn_spec_->export_params_view();
112+
oarc << read_model_spec()->export_params_view();
113113
}
114114

115115
void activity_classifier::load_version(iarchive& iarc, size_t version) {
@@ -122,6 +122,7 @@ void activity_classifier::load_version(iarchive& iarc, size_t version) {
122122
bool use_random_init = false;
123123
nn_spec_ = init_model(use_random_init);
124124
nn_spec_->update_params(nn_params);
125+
nn_spec_synchronized_ = true;
125126
}
126127

127128
void activity_classifier::init_options(
@@ -294,11 +295,10 @@ std::tuple<float, float> activity_classifier::compute_validation_metrics(
294295
return std::make_tuple(average_val_accuracy, average_val_loss);
295296
}
296297

297-
void activity_classifier::init_table_printer(bool has_validation,
298-
bool show_loss) {
298+
void activity_classifier::init_table_printer(bool has_validation) {
299299
if (read_state<bool>("verbose")) {
300300
if (has_validation) {
301-
if (show_loss) {
301+
if (show_loss_) {
302302
training_table_printer_.reset(
303303
new table_printer({{"Iteration", 12},
304304
{"Train Accuracy", 12},
@@ -315,7 +315,7 @@ void activity_classifier::init_table_printer(bool has_validation,
315315
{"Elapsed Time", 12}}));
316316
}
317317
} else {
318-
if (show_loss) {
318+
if (show_loss_) {
319319
training_table_printer_.reset(
320320
new table_printer({{"Iteration", 12},
321321
{"Train Accuracy", 12},
@@ -338,38 +338,55 @@ void activity_classifier::train(
338338
turi::timer time_object;
339339
time_object.start();
340340

341-
bool show_loss = true;
342-
auto show_loss_it = opts.find("_show_loss");
343-
if (show_loss_it != opts.end()) {
344-
show_loss = show_loss_it->second;
345-
}
346-
347341
// Instantiate the training dependencies: data iterator, compute context,
348342
// backend NN model.
349-
init_train(data, target_column_name, session_id_column_name, validation_data,
350-
opts);
343+
init_training(data, target_column_name, session_id_column_name,
344+
validation_data, opts);
351345

352346
// Perform all the iterations at once.
353347
flex_int max_iterations = read_state<flex_int>("max_iterations");
354348
while (read_state<flex_int>("training_iterations") < max_iterations) {
355-
perform_training_iteration(show_loss);
349+
iterate_training();
350+
}
351+
352+
finalize_training();
353+
354+
variant_map_type state_update;
355+
state_update["training_time"] = time_object.current_time();
356+
add_or_update_state(state_update);
357+
358+
logprogress_stream << "Training complete" << std::endl;
359+
logprogress_stream << "Total Time Spent: "
360+
<< read_state<flex_float>("training_time") << std::endl;
361+
}
362+
363+
// iterate_training() performs a complete epoch, synchronizing with the GPU. As
364+
// a result, no explicit synchronization is needed. We expose this method just
365+
// for consistency with other models, like object_detector.
366+
void activity_classifier::synchronize_training() {}
367+
368+
const model_spec* activity_classifier::read_model_spec() const {
369+
if (training_model_ && !nn_spec_synchronized_) {
370+
float_array_map trained_weights = training_model_->export_weights();
371+
nn_spec_->update_params(trained_weights);
372+
nn_spec_synchronized_ = true;
356373
}
374+
return nn_spec_.get();
375+
}
357376

377+
void activity_classifier::finalize_training() {
358378
// Finish printing progress.
359379
if (training_table_printer_) {
360380
training_table_printer_->print_footer();
361381
training_table_printer_.reset();
362382
}
363383

364-
// Sync trained weights to our local storage of the NN weights.
365-
float_array_map trained_weights = training_model_->export_weights();
366-
nn_spec_->update_params(trained_weights);
367-
368384
variant_map_type state_update;
369385

370386
// Update the state with recall, precision and confusion matrix for training
371387
// data
372388
gl_sarray train_predictions = predict(training_data_, "probability_vector");
389+
flex_string target_column_name = read_state<flex_string>("target");
373390
variant_map_type train_metric = evaluation::compute_classifier_metrics(
374391
training_data_, target_column_name, "report", train_predictions,
375392
{{"classes", read_state<flex_list>("classes")}});
@@ -392,13 +409,7 @@ void activity_classifier::train(
392409
}
393410

394411
state_update["verbose"] = read_state<bool>("verbose");
395-
state_update["num_examples"] = data.size();
396-
state_update["training_time"] = time_object.current_time();
397-
398412
add_or_update_state(state_update);
399-
logprogress_stream << "Training complete" << std::endl;
400-
logprogress_stream << "Total Time Spent: " << read_state<flex_float>("training_time") << std::endl;
401-
402413
}
403414

404415
gl_sarray activity_classifier::predict(gl_sframe data,
@@ -411,7 +422,8 @@ gl_sarray activity_classifier::predict(gl_sframe data,
411422

412423
// Bind the data to a data iterator.
413424
std::unique_ptr<data_iterator> data_it =
414-
create_iterator(data, /* requires_labels */ false, /* is_train */ false,
425+
create_iterator(data, /* requires_labels */ false,
426+
/* infer_class_labels */ false, /* is_train */ false,
415427
/* use_data_augmentation */ false);
416428

417429
// Accumulate the class probabilities for each prediction window.
@@ -456,7 +468,8 @@ gl_sframe activity_classifier::predict_per_window(gl_sframe data,
456468

457469
// Bind the data to a data iterator.
458470
std::unique_ptr<data_iterator> data_it =
459-
create_iterator(data, /* requires_labels */ false, /* is_train */ false,
471+
create_iterator(data, /* requires_labels */ false,
472+
/* infer_class_labels */ false, /* is_train */ false,
460473
/* use_data_augmentation */ false);
461474

462475
// Accumulate the class probabilities for each prediction window.
@@ -496,7 +509,8 @@ gl_sframe activity_classifier::classify(gl_sframe data,
496509

497510
// perform inference
498511
std::unique_ptr<data_iterator> data_it =
499-
create_iterator(data, /* requires_labels */ false, /* is_train */ false,
512+
create_iterator(data, /* requires_labels */ false,
513+
/* infer_class_labels */ false, /* is_train */ false,
500514
/* use_data_augmentation */ false);
501515
gl_sframe raw_preds_per_window = perform_inference(data_it.get());
502516

@@ -577,7 +591,8 @@ gl_sframe activity_classifier::predict_topk(gl_sframe data,
577591

578592
// data inference
579593
std::unique_ptr<data_iterator> data_it =
580-
create_iterator(data, /* requires_labels */ false, /* is_train */ false,
594+
create_iterator(data, /* requires_labels */ false,
595+
/* infer_class_labels */ false, /* is_train */ false,
581596
/* use_data_augmentation */ false);
582597
gl_sframe raw_preds_per_window = perform_inference(data_it.get());
583598

@@ -704,12 +719,9 @@ std::shared_ptr<MLModelWrapper> activity_classifier::export_to_coreml(
704719
{
705720
std::shared_ptr<MLModelWrapper> model_wrapper =
706721
export_activity_classifier_model(
707-
*nn_spec_,
708-
read_state<flex_int>("prediction_window"),
709-
read_state<flex_list>("features"),
710-
LSTM_HIDDEN_SIZE,
711-
read_state<flex_list>("classes"),
712-
read_state<flex_string>("target"));
722+
*read_model_spec(), read_state<flex_int>("prediction_window"),
723+
read_state<flex_list>("features"), LSTM_HIDDEN_SIZE,
724+
read_state<flex_list>("classes"), read_state<flex_string>("target"));
713725

714726
const flex_list& features_list = read_state<flex_list>("features");
715727
const flex_string features_string =
@@ -839,16 +851,17 @@ void activity_classifier::import_from_custom_model(
839851
bool use_random_init = false;
840852
nn_spec_ = init_model(use_random_init);
841853
nn_spec_->update_params(nn_params);
854+
nn_spec_synchronized_ = true;
842855
model_data.erase(model_iter);
843856
}
844857

845858
std::unique_ptr<data_iterator> activity_classifier::create_iterator(
846-
gl_sframe data, bool requires_labels, bool is_train,
847-
bool use_data_augmentation) const {
859+
gl_sframe data, bool requires_labels, bool infer_class_labels,
860+
bool is_train, bool use_data_augmentation) const {
848861
data_iterator::parameters data_params;
849862
data_params.data = std::move(data);
850863

851-
if (!is_train) {
864+
if (!infer_class_labels) {
852865
data_params.class_labels = read_state<flex_list>("classes");
853866
}
854867

@@ -1020,11 +1033,10 @@ activity_classifier::init_data(gl_sframe data, variant_type validation_data,
10201033
return std::make_tuple(train_data,val_data);
10211034
}
10221035

1023-
void activity_classifier::init_train(
1036+
void activity_classifier::init_training(
10241037
gl_sframe data, std::string target_column_name,
10251038
std::string session_id_column_name, variant_type validation_data,
1026-
std::map<std::string, flexible_type> opts)
1027-
{
1039+
std::map<std::string, flexible_type> opts) {
10281040
// Extract feature names from options.
10291041
std::vector<std::string> feature_column_names;
10301042
auto features_it = opts.find("features");
@@ -1037,10 +1049,9 @@ void activity_classifier::init_train(
10371049
opts.erase(features_it);
10381050
}
10391051

1040-
bool show_loss = true;
10411052
auto show_loss_it = opts.find("_show_loss");
10421053
if (show_loss_it != opts.end()) {
1043-
show_loss = show_loss_it->second;
1054+
show_loss_ = show_loss_it->second;
10441055
opts.erase(show_loss_it);
10451056
}
10461057

@@ -1059,7 +1070,7 @@ void activity_classifier::init_train(
10591070
init_data(data, validation_data, session_id_column_name);
10601071

10611072
// Begin printing progress.
1062-
init_table_printer(!validation_data_.empty(), show_loss);
1073+
init_table_printer(!validation_data_.empty());
10631074

10641075
add_or_update_state({{"session_id", session_id_column_name},
10651076
{"target", target_column_name},
@@ -1070,15 +1081,17 @@ void activity_classifier::init_train(
10701081
bool use_data_augmentation = read_state<bool>("use_data_augmentation");
10711082
training_data_iterator_ =
10721083
create_iterator(training_data_, /* requires_labels */ true,
1073-
/* is_train */ true, use_data_augmentation);
1084+
/* infer_class_labels */ true, /* is_train */ true,
1085+
use_data_augmentation);
10741086

10751087
add_or_update_state({{"classes", training_data_iterator_->class_labels()}});
10761088

10771089
// Bind the validation data to a data iterator.
10781090
if (!validation_data_.empty()) {
1079-
validation_data_iterator_ = create_iterator(
1080-
validation_data_, /* requires_labels */ true, /* is_train */ false,
1081-
/* use_data_augmentation */ false);
1091+
validation_data_iterator_ =
1092+
create_iterator(validation_data_, /* requires_labels */ true,
1093+
/* infer_class_labels */ false, /* is_train */ false,
1094+
/* use_data_augmentation */ false);
10821095
} else {
10831096
validation_data_iterator_ = nullptr;
10841097
}
@@ -1097,6 +1110,7 @@ void activity_classifier::init_train(
10971110
add_or_update_state({
10981111
{"features", training_data_iterator_->feature_names()},
10991112
{"num_classes", training_data_iterator_->class_labels().size()},
1113+
{"num_examples", training_data_.size()},
11001114
{"num_features", training_data_iterator_->feature_names().size()},
11011115
{"num_sessions", training_data_iterator_->num_sessions()},
11021116
{"training_iterations", 0},
@@ -1106,6 +1120,7 @@ void activity_classifier::init_train(
11061120
// the data iterator.
11071121
bool use_random_init = true;
11081122
nn_spec_ = init_model(use_random_init);
1123+
nn_spec_synchronized_ = true;
11091124

11101125
// Defining the struct for ac parameters
11111126
ac_parameters ac_params;
@@ -1116,7 +1131,7 @@ void activity_classifier::init_train(
11161131
ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
11171132
ac_params.random_seed = read_state<int>("random_seed");
11181133
ac_params.is_training = true;
1119-
ac_params.weights = nn_spec_->export_params_view();
1134+
ac_params.weights = read_model_spec()->export_params_view();
11201135

11211136
// Instantiate the NN backend.
11221137
training_model_ =
@@ -1128,11 +1143,72 @@ void activity_classifier::init_train(
11281143
}
11291144
}
11301145

1131-
void activity_classifier::perform_training_iteration(bool show_loss) {
1146+
void activity_classifier::resume_training(gl_sframe data,
1147+
variant_type validation_data) {
1148+
// Perform validation split if necessary.
1149+
flex_string session_id_column_name = read_state<flex_string>("session_id");
1150+
std::tie(training_data_, validation_data_) =
1151+
init_data(data, validation_data, session_id_column_name);
1152+
1153+
// Begin printing progress.
1154+
init_table_printer(!validation_data_.empty());
1155+
1156+
// Bind the data to a data iterator.
1157+
bool use_data_augmentation = read_state<bool>("use_data_augmentation");
1158+
training_data_iterator_ =
1159+
create_iterator(training_data_, /* requires_labels */ true,
1160+
/* infer_class_labels */ false, /* is_train */ true,
1161+
use_data_augmentation);
1162+
1163+
// Bind the validation data to a data iterator.
1164+
if (!validation_data_.empty()) {
1165+
validation_data_iterator_ =
1166+
create_iterator(validation_data_, /* requires_labels */ true,
1167+
/* infer_class_labels */ false, /* is_train */ false,
1168+
/* use_data_augmentation */ false);
1169+
} else {
1170+
validation_data_iterator_ = nullptr;
1171+
}
1172+
1173+
// Instantiate the compute context.
1174+
training_compute_context_ = create_compute_context();
1175+
if (training_compute_context_ == nullptr) {
1176+
log_and_throw("No neural network compute context provided");
1177+
}
1178+
1179+
// Report to the user what GPU(s) is being used.
1180+
std::vector<std::string> gpu_names = training_compute_context_->gpu_names();
1181+
print_training_device(gpu_names);
1182+
1183+
// Defining the struct for ac parameters
1184+
ac_parameters ac_params;
1185+
ac_params.batch_size = read_state<int>("batch_size");
1186+
ac_params.num_features = read_state<int>("num_features");
1187+
ac_params.prediction_window = read_state<int>("prediction_window");
1188+
ac_params.num_classes = read_state<int>("num_classes");
1189+
ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
1190+
ac_params.random_seed = read_state<int>("random_seed");
1191+
ac_params.is_training = true;
1192+
ac_params.weights = read_model_spec()->export_params_view();
1193+
1194+
// Instantiate the NN backend.
1195+
training_model_ =
1196+
training_compute_context_->create_activity_classifier(ac_params);
1197+
1198+
// Print the header last, after any logging triggered by initialization above.
1199+
if (training_table_printer_) {
1200+
training_table_printer_->print_header();
1201+
}
1202+
}
1203+
1204+
void activity_classifier::iterate_training() {
11321205
// Training must have been initialized.
11331206
ASSERT_TRUE(training_data_iterator_ != nullptr);
11341207
ASSERT_TRUE(training_model_ != nullptr);
11351208

1209+
// Invalidate any local copy of the model.
1210+
nn_spec_synchronized_ = false;
1211+
11361212
const size_t batch_size = read_state<flex_int>("batch_size");
11371213
const size_t iteration_idx = read_state<flex_int>("training_iterations");
11381214

@@ -1224,7 +1300,7 @@ void activity_classifier::perform_training_iteration(bool show_loss) {
12241300

12251301
if (training_table_printer_) {
12261302
if (validation_data_iterator_) {
1227-
if (show_loss) {
1303+
if (show_loss_) {
12281304
training_table_printer_->print_progress_row(
12291305
iteration_idx, iteration_idx + 1, average_batch_accuracy,
12301306
average_batch_loss, average_val_accuracy, average_val_loss,
@@ -1235,7 +1311,7 @@ void activity_classifier::perform_training_iteration(bool show_loss) {
12351311
average_val_accuracy, progress_time());
12361312
}
12371313
} else {
1238-
if (show_loss) {
1314+
if (show_loss_) {
12391315
training_table_printer_->print_progress_row(
12401316
iteration_idx, iteration_idx + 1, average_batch_accuracy,
12411317
average_batch_loss, progress_time());
@@ -1273,7 +1349,7 @@ gl_sframe activity_classifier::perform_inference(data_iterator *data) const {
12731349
ac_params.num_predictions_per_chunk = NUM_PREDICTIONS_PER_CHUNK;
12741350
ac_params.random_seed = read_state<int>("random_seed");
12751351
ac_params.is_training = false;
1276-
ac_params.weights = nn_spec_->export_params_view();
1352+
ac_params.weights = read_model_spec()->export_params_view();
12771353

12781354
// Initialize the NN backend.
12791355
std::unique_ptr<compute_context> ctx = create_compute_context();

0 commit comments

Comments (0)