Skip to content

Commit 181baf4

Browse files
committed
Speed up 2bpls prediction.
1 parent f87ebe3 commit 181baf4

File tree

8 files changed

+154
-69
lines changed

8 files changed

+154
-69
lines changed

Python/shapeworks/shapeworks/shape_scalars.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,22 @@ def get_fig_png():
2828
def run_mbpls(x, y, n_components=3, cv=5):
2929
""" Run MBPLS on shape and scalar data """
3030

31-
model = MBPLS(n_components=n_components)
31+
# don't set cv higher than the number of samples
32+
cv = min(cv, len(x))
33+
34+
global mbpls_model
35+
mbpls_model = MBPLS(n_components=n_components)
36+
if cv != 1:
37+
y_pred = cross_val_predict(mbpls_model, x, y, cv=cv)
38+
39+
mbpls_model.fit(x, y)
40+
3241
if cv == 1:
33-
model.fit(x, y)
34-
y_pred = model.predict(x)
35-
else:
36-
y_pred = cross_val_predict(model, x, y, cv=cv)
42+
y_pred = mbpls_model.predict(x)
3743

3844
mse = mean_squared_error(y, y_pred)
3945

40-
sw_message(f'MSE: {mse}')
46+
sw_message(f'Python MSE: {mse}')
4147

4248
prediction = pd.DataFrame(np.array(y_pred))
4349

@@ -84,17 +90,26 @@ def run_find_num_components(x, y, max_components, cv=5):
8490
return figdata_png
8591

8692

87-
def pred_from_mbpls(x, y, new_x, n_components=3):
88-
""" Run MBPLS on shape and scalar data, then predict new_y from new_x """
93+
def pred_from_mbpls(new_x, n_components=3):
94+
""" Predict new_y from new_x using existing mbpls fit """
95+
96+
if not does_mbpls_model_exist():
97+
sw_message('MBPLS model does not exist, returning none')
98+
return None
99+
100+
global mbpls_model
101+
y_pred = mbpls_model.predict(new_x)
102+
# return as vector
103+
return y_pred.flatten()
104+
105+
def does_mbpls_model_exist():
106+
""" Check if mbpls model exists """
89107

90108
# check if global variable model exists, otherwise create it
91-
global model
109+
global mbpls_model
92110
try:
93-
model
111+
mbpls_model
94112
except NameError:
95-
model = MBPLS(n_components=n_components)
96-
model.fit(x, y)
113+
return False
97114

98-
y_pred = model.predict(new_x)
99-
# return as vector
100-
return y_pred.flatten()
115+
return True

Studio/Analysis/AnalysisTool.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ bool AnalysisTool::compute_stats() {
473473
return false;
474474
}
475475

476-
SW_LOG("Compute Stats!");
476+
SW_DEBUG("Compute Stats!");
477477
compute_reconstructed_domain_transforms();
478478

479479
ui_->pcaModeSpinBox->setMaximum(std::max<double>(1, session_->get_shapes().size() - 1));
@@ -691,11 +691,14 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
691691

692692
auto positions = temp_shape_;
693693

694+
computed_scalars_ = Eigen::VectorXd();
694695
if (pca_shape_plus_scalar_mode()) {
695696
positions = extract_shape_only(temp_shape_);
696-
temp_scalars_ = extract_scalar_only(temp_shape_);
697+
computed_scalars_ = extract_scalar_only(temp_shape_);
697698
} else if (pca_scalar_only_mode()) {
698699
SW_LOG("Scalar only mode not implemented yet");
700+
computed_scalars_ = temp_shape_;
701+
positions = construct_mean_shape();
699702
}
700703

701704
return convert_from_combined(positions);
@@ -908,19 +911,13 @@ AnalysisTool::GroupAnalysisType AnalysisTool::get_group_analysis_type() {
908911
}
909912

910913
//---------------------------------------------------------------------------
911-
bool AnalysisTool::pca_scalar_only_mode() {
912-
return ui_->pca_scalar_only->isChecked();
913-
}
914+
bool AnalysisTool::pca_scalar_only_mode() { return ui_->pca_scalar_only->isChecked(); }
914915

915916
//---------------------------------------------------------------------------
916-
bool AnalysisTool::pca_shape_plus_scalar_mode() {
917-
return ui_->pca_shape_and_scalar->isChecked();
918-
}
917+
bool AnalysisTool::pca_shape_plus_scalar_mode() { return ui_->pca_shape_and_scalar->isChecked(); }
919918

920919
//---------------------------------------------------------------------------
921-
bool AnalysisTool::pca_shape_only_mode() {
922-
return ui_->pca_scalar_shape_only->isChecked();
923-
}
920+
bool AnalysisTool::pca_shape_only_mode() { return ui_->pca_scalar_shape_only->isChecked(); }
924921

925922
//---------------------------------------------------------------------------
926923
void AnalysisTool::on_tabWidget_currentChanged() { update_analysis_mode(); }
@@ -976,6 +973,8 @@ void AnalysisTool::handle_pca_timer() {
976973
}
977974

978975
ui_->pcaSlider->setValue(value);
976+
977+
QApplication::processEvents();
979978
}
980979

981980
//---------------------------------------------------------------------------
@@ -1244,12 +1243,13 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) {
12441243
shape->set_reconstruction_transforms(reconstruction_transforms_);
12451244

12461245
if (feature_map_ != "") {
1247-
// auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
1248-
// points.get_combined_global_particles());
1249-
1250-
// shape->set_point_features(feature_map_, scalars);
1251-
1252-
shape->set_point_features(feature_map_, temp_scalars_);
1246+
if (ui_->pca_predict_scalar->isChecked()) {
1247+
auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
1248+
points.get_combined_global_particles());
1249+
shape->set_point_features(feature_map_, scalars);
1250+
} else {
1251+
shape->set_point_features(feature_map_, computed_scalars_);
1252+
}
12531253
}
12541254
return shape;
12551255
}
@@ -1729,6 +1729,25 @@ void AnalysisTool::change_pca_analysis_type() {
17291729
evals_ready_ = false;
17301730
stats_ = ParticleShapeStatistics();
17311731
compute_stats();
1732+
Q_EMIT pca_update();
1733+
}
1734+
1735+
//---------------------------------------------------------------------------
1736+
Eigen::VectorXd AnalysisTool::construct_mean_shape() {
1737+
if (session_->get_shapes().empty()) {
1738+
return Eigen::VectorXd();
1739+
}
1740+
1741+
Eigen::VectorXd sum_shape =
1742+
Eigen::VectorXd::Zero(session_->get_shapes()[0]->get_global_correspondence_points().size());
1743+
1744+
for (auto& shape : session_->get_shapes()) {
1745+
Eigen::VectorXd particles = shape->get_global_correspondence_points();
1746+
sum_shape += particles;
1747+
}
1748+
1749+
Eigen::VectorXd mean_shape = sum_shape / session_->get_shapes().size();
1750+
return mean_shape;
17321751
}
17331752

17341753
//---------------------------------------------------------------------------

Studio/Analysis/AnalysisTool.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ class AnalysisTool : public QWidget {
194194

195195
void change_pca_analysis_type();
196196

197+
//! Compute the mean shape outside of the PCA in case we are using scalars only
198+
Eigen::VectorXd construct_mean_shape();
199+
197200
Q_SIGNALS:
198201

199202
void update_view();
@@ -253,7 +256,7 @@ class AnalysisTool : public QWidget {
253256
Eigen::VectorXd temp_shape_mca;
254257
std::vector<int> number_of_particles_array_;
255258

256-
Eigen::VectorXd temp_scalars_;
259+
Eigen::VectorXd computed_scalars_;
257260

258261
bool pca_animate_direction_ = true;
259262
QTimer pca_animate_timer_;

Studio/Analysis/AnalysisTool.ui

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,13 @@ QWidget#particles_panel {
14011401
</widget>
14021402
</item>
14031403
<item row="0" column="1">
1404-
<widget class="QComboBox" name="pca_scalar_combo"/>
1404+
<widget class="QComboBox" name="pca_scalar_combo">
1405+
<property name="font">
1406+
<font>
1407+
<family>.AppleSystemUIFont</family>
1408+
</font>
1409+
</property>
1410+
</widget>
14051411
</item>
14061412
<item row="1" column="0">
14071413
<widget class="QRadioButton" name="pca_scalar_shape_only">

Studio/Job/Job.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class Job : public QObject {
4141
//! was the job aborted?
4242
bool is_aborted() const { return abort_; }
4343

44+
//! set to quiet mode (no progress messages)
45+
void set_quiet_mode(bool quiet) { quiet_mode_ = quiet; }
46+
47+
//! get quiet mode
48+
bool get_quiet_mode() { return quiet_mode_; }
49+
4450
public Q_SLOTS:
4551

4652
Q_SIGNALS:
@@ -51,6 +57,7 @@ class Job : public QObject {
5157
private:
5258
std::atomic<bool> complete_ = false;
5359
std::atomic<bool> abort_ = false;
60+
std::atomic<bool> quiet_mode_ = false;
5461

5562
QElapsedTimer timer_;
5663
};

Studio/Job/ShapeScalarJob.cpp

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <pybind11/stl.h>
77

88
#include <Eigen/Dense>
9+
#include <QApplication>
910
#include <QImage>
1011

1112
namespace py = pybind11;
@@ -22,46 +23,22 @@ ShapeScalarJob::ShapeScalarJob(QSharedPointer<Session> session, QString target_f
2223

2324
//---------------------------------------------------------------------------
2425
void ShapeScalarJob::run() {
25-
SW_DEBUG("Running shape scalar job");
26+
// SW_DEBUG("Running shape scalar job");
2627

2728
try {
28-
prep_data();
29-
3029
py::module np = py::module::import("numpy");
31-
py::object A = np.attr("array")(all_particles_);
32-
py::object B = np.attr("array")(all_scalars_);
3330

3431
py::module sw = py::module::import("shapeworks");
3532

3633
if (job_type_ == JobType::MSE_Plot) {
37-
// returns a tuple of (png_raw_bytes, y_pred, mse)
38-
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;
39-
40-
py::object run_mbpls = sw.attr("shape_scalars").attr("run_mbpls");
41-
ResultType result = run_mbpls(A, B, num_components_, num_folds_).cast<ResultType>();
42-
43-
py::array png_raw_bytes = std::get<0>(result);
44-
Eigen::MatrixXd y_pred = std::get<1>(result);
45-
double mse = std::get<2>(result);
46-
47-
// interpret png_raw_bytes as a QImage
48-
QImage image;
49-
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
50-
plot_ = QPixmap::fromImage(image);
51-
52-
SW_LOG("mse = {}", mse);
53-
34+
run_fit();
5435
} else if (job_type_ == JobType::Predict) {
55-
py::object new_x = np.attr("array")(target_particles_.transpose());
56-
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");
57-
58-
using ResultType = Eigen::VectorXd;
59-
ResultType result = run_prediction(A, B, new_x).cast<ResultType>();
60-
61-
auto y_pred = result;
62-
63-
prediction_ = y_pred;
36+
run_prediction();
6437
} else if (job_type_ == JobType::Find_Components) {
38+
prep_data();
39+
py::object A = np.attr("array")(all_particles_);
40+
py::object B = np.attr("array")(all_scalars_);
41+
6542
// returns a tuple of (png_raw_bytes, y_pred, mse)
6643
using ResultType = py::array;
6744

@@ -75,7 +52,7 @@ void ShapeScalarJob::run() {
7552
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
7653
plot_ = QPixmap::fromImage(image);
7754
}
78-
SW_DEBUG("End shape scalar job");
55+
// SW_DEBUG("End shape scalar job");
7956

8057
} catch (const std::exception& e) {
8158
SW_ERROR("Exception in shape scalar job: {}", e.what());
@@ -94,6 +71,7 @@ Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session,
9471
// blocking call to predict scalars for given target particles
9572

9673
auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_particles, JobType::Predict);
74+
job->set_quiet_mode(true);
9775

9876
Eigen::VectorXd prediction;
9977

@@ -106,8 +84,9 @@ Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session,
10684

10785
session->get_py_worker()->run_job(job);
10886

87+
// wait for job to finish without using sleep
10988
while (!finished) {
110-
QThread::msleep(100);
89+
QApplication::processEvents();
11190
}
11291

11392
return prediction;
@@ -119,5 +98,55 @@ void ShapeScalarJob::prep_data() {
11998
all_scalars_ = session_->get_all_scalars(target_feature_.toStdString());
12099
}
121100

101+
//---------------------------------------------------------------------------
102+
void ShapeScalarJob::run_fit() {
103+
prep_data();
104+
py::module np = py::module::import("numpy");
105+
py::module sw = py::module::import("shapeworks");
106+
107+
py::object A = np.attr("array")(all_particles_);
108+
py::object B = np.attr("array")(all_scalars_);
109+
110+
// returns a tuple of (png_raw_bytes, y_pred, mse)
111+
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;
112+
113+
py::object run_mbpls = sw.attr("shape_scalars").attr("run_mbpls");
114+
ResultType result = run_mbpls(A, B, num_components_, num_folds_).cast<ResultType>();
115+
116+
py::array png_raw_bytes = std::get<0>(result);
117+
Eigen::MatrixXd y_pred = std::get<1>(result);
118+
double mse = std::get<2>(result);
119+
120+
// interpret png_raw_bytes as a QImage
121+
QImage image;
122+
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
123+
plot_ = QPixmap::fromImage(image);
124+
125+
SW_LOG("mse = {}", mse);
126+
}
127+
128+
//---------------------------------------------------------------------------
129+
void ShapeScalarJob::run_prediction() {
130+
py::module np = py::module::import("numpy");
131+
py::module sw = py::module::import("shapeworks");
132+
133+
py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist");
134+
if (!does_mbpls_model_exist().cast<bool>()) {
135+
SW_LOG("No MBPLS model exists, running fit");
136+
run_fit();
137+
}
138+
139+
py::object new_x = np.attr("array")(target_particles_.transpose());
140+
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");
141+
142+
using ResultType = Eigen::VectorXd;
143+
144+
ResultType result = run_prediction(new_x).cast<ResultType>();
145+
146+
auto y_pred = result;
147+
148+
prediction_ = y_pred;
149+
}
150+
122151
//---------------------------------------------------------------------------
123152
} // namespace shapeworks

Studio/Job/ShapeScalarJob.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class ShapeScalarJob : public Job {
3333
private:
3434
void prep_data();
3535

36+
void run_fit();
37+
void run_prediction();
3638

3739
QSharedPointer<Session> session_;
3840

Studio/Python/PythonWorker.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,16 @@ void PythonWorker::start_job(QSharedPointer<Job> job) {
7676
if (init()) {
7777
try {
7878
job->start_timer();
79-
SW_LOG("Running Task: " + job->name().toStdString());
79+
if (!job->get_quiet_mode()) {
80+
SW_LOG("Running Task: " + job->name().toStdString());
81+
}
8082
Q_EMIT job->progress(0);
8183
current_job_ = job;
8284
current_job_->run();
8385
current_job_->set_complete(true);
84-
SW_LOG(current_job_->get_completion_message().toStdString());
86+
if (!job->get_quiet_mode()) {
87+
SW_LOG(current_job_->get_completion_message().toStdString());
88+
}
8589
} catch (py::error_already_set& e) {
8690
SW_ERROR(e.what());
8791
}

0 commit comments

Comments
 (0)