Skip to content

Commit 942e9bb

Browse files
committed
Fix DeepSSM "restore defaults", Fix loss function.
1 parent 5bcd95d commit 942e9bb

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

Studio/DeepSSM/DeepSSMParameters.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,24 +358,35 @@ void DeepSSMParameters::restore_split_defaults() {
358358
params_.remove_entry(Keys::TRAINING_SPLIT);
359359
params_.remove_entry(Keys::VALIDATION_SPLIT);
360360
params_.remove_entry(Keys::TESTING_SPLIT);
361+
params_.remove_entry(Keys::AUG_PERCENT_VARIABILITY);
362+
params_.remove_entry(Keys::SPACING);
361363
}
362364

363365
//---------------------------------------------------------------------------
364366
void DeepSSMParameters::restore_augmentation_defaults() {
365367
params_.remove_entry(Keys::AUG_NUM_SAMPLES);
366-
params_.remove_entry(Keys::AUG_NUM_DIMS);
367-
params_.remove_entry(Keys::AUG_PERCENT_VARIABILITY);
368368
params_.remove_entry(Keys::AUG_SAMPLER_TYPE);
369369
}
370370

371371
//---------------------------------------------------------------------------
372372
void DeepSSMParameters::restore_training_defaults() {
373+
params_.remove_entry(Keys::LOSS_FUNCTION);
373374
params_.remove_entry(Keys::TRAIN_EPOCHS);
374375
params_.remove_entry(Keys::TRAIN_LEARNING_RATE);
376+
params_.remove_entry(Keys::TRAIN_BATCH_SIZE);
375377
params_.remove_entry(Keys::TRAIN_DECAY_LEARNING_RATE);
376378
params_.remove_entry(Keys::TRAIN_FINE_TUNING);
377379
params_.remove_entry(Keys::TRAIN_FINE_TUNING_EPOCHS);
378-
params_.remove_entry(Keys::TRAIN_BATCH_SIZE);
379380
params_.remove_entry(Keys::TRAIN_FINE_TUNING_LEARNING_RATE);
381+
382+
params_.remove_entry(Keys::TL_NET_ENABLED);
383+
params_.remove_entry(Keys::TL_NET_AE_EPOCHS);
384+
params_.remove_entry(Keys::TL_NET_TF_EPOCHS);
385+
params_.remove_entry(Keys::TL_NET_JOINT_EPOCHS);
386+
params_.remove_entry(Keys::TL_NET_ALPHA);
387+
params_.remove_entry(Keys::TL_NET_A_AE);
388+
params_.remove_entry(Keys::TL_NET_C_AE);
389+
params_.remove_entry(Keys::TL_NET_A_LAT);
390+
params_.remove_entry(Keys::TL_NET_C_LAT);
380391
}
381392
} // namespace shapeworks

Studio/DeepSSM/DeepSSMTool.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ void DeepSSMTool::load_params() {
173173
ui_->tl_lat_a->setText(QString::number(params.get_tl_net_a_lat()));
174174
ui_->tl_lat_c->setText(QString::number(params.get_tl_net_c_lat()));
175175

176+
ui_->loss_function->setCurrentText(QString::fromStdString(params.get_loss_function()));
176177
update_panels();
177178
update_meshes();
178179
}
@@ -209,6 +210,8 @@ void DeepSSMTool::store_params() {
209210
params.set_tl_net_a_lat(ui_->tl_lat_a->text().toDouble());
210211
params.set_tl_net_c_lat(ui_->tl_lat_c->text().toDouble());
211212

213+
params.set_loss_function(ui_->loss_function->currentText().toStdString());
214+
212215
params.save_to_project();
213216
}
214217

0 commit comments

Comments
 (0)