Skip to content

Commit 7e2cf17

Browse files
committed
Remaining pieces for pca shape scalar
1 parent bed1237 commit 7e2cf17

File tree

6 files changed

+97
-13
lines changed

6 files changed

+97
-13
lines changed

Python/shapeworks/shapeworks/shape_scalars.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def run_find_num_components(x, y, max_components, cv=5):
9090
return figdata_png
9191

9292

93-
def pred_from_mbpls(new_x, n_components=3):
93+
def pred_from_mbpls(new_x):
9494
""" Predict new_y from new_x using existing mbpls fit """
9595

9696
if not does_mbpls_model_exist():
@@ -112,4 +112,16 @@ def does_mbpls_model_exist():
112112
except NameError:
113113
return False
114114

115-
return True
115+
return True
116+
117+
def clear_mbpls_model():
118+
""" Clear mbpls model """
119+
120+
global mbpls_model
121+
try:
122+
mbpls_model
123+
except NameError:
124+
return
125+
126+
del mbpls_model
127+
return

Studio/Analysis/AnalysisTool.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,23 @@ void AnalysisTool::network_analysis_clicked() {
463463
app_->get_py_worker()->run_job(network_analysis_job_);
464464
}
465465

466+
//-----------------------------------------------------------------------------
467+
Eigen::VectorXd AnalysisTool::extract_positions(Eigen::VectorXd& data) {
468+
/*
469+
auto positions = data;
470+
471+
if (pca_shape_plus_scalar_mode()) {
472+
positions = extract_shape_only(data);
473+
} else if (pca_scalar_only_mode()) {
474+
computed_scalars_ = temp_shape_;
475+
if (ui_->pca_predict_shape->isChecked()) {
476+
positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
477+
} else {
478+
positions = construct_mean_shape();
479+
}
480+
}*/
481+
}
482+
466483
//-----------------------------------------------------------------------------
467484
bool AnalysisTool::compute_stats() {
468485
if (stats_ready_) {
@@ -637,6 +654,10 @@ Particles AnalysisTool::get_mean_shape_points() {
637654
return Particles();
638655
}
639656

657+
if (ui_->pca_scalar_only->isChecked()) {
658+
return convert_from_combined(construct_mean_shape());
659+
}
660+
640661
if (ui_->group1_button->isChecked() || ui_->difference_button->isChecked()) {
641662
return convert_from_combined(stats_.get_group1_mean());
642663
} else if (ui_->group2_button->isChecked()) {
@@ -700,7 +721,11 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
700721
computed_scalars_ = extract_scalar_only(temp_shape_);
701722
} else if (pca_scalar_only_mode()) {
702723
computed_scalars_ = temp_shape_;
703-
positions = construct_mean_shape();
724+
if (ui_->pca_predict_shape->isChecked()) {
725+
positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
726+
} else {
727+
positions = construct_mean_shape();
728+
}
704729
}
705730

706731
return convert_from_combined(positions);
@@ -1245,7 +1270,7 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) {
12451270
shape->set_reconstruction_transforms(reconstruction_transforms_);
12461271

12471272
if (feature_map_ != "") {
1248-
if (ui_->pca_predict_scalar->isChecked()) {
1273+
if (ui_->pca_scalar_shape_only->isChecked() && ui_->pca_predict_scalar->isChecked()) {
12491274
auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
12501275
points.get_combined_global_particles());
12511276
shape->set_point_features(feature_map_, scalars);
@@ -1730,6 +1755,11 @@ void AnalysisTool::change_pca_analysis_type() {
17301755
stats_ready_ = false;
17311756
evals_ready_ = false;
17321757
stats_ = ParticleShapeStatistics();
1758+
ShapeScalarJob::clear_model();
1759+
1760+
ui_->pca_predict_scalar->setEnabled(ui_->pca_scalar_shape_only->isChecked());
1761+
ui_->pca_predict_shape->setEnabled(ui_->pca_scalar_only->isChecked());
1762+
17331763
compute_stats();
17341764
Q_EMIT pca_update();
17351765
}

Studio/Analysis/AnalysisTool.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,14 @@ class AnalysisTool : public QWidget {
194194

195195
void change_pca_analysis_type();
196196

197+
//Eigen::VectorXd get_mean_shape();
198+
197199
//! Compute the mean shape outside of the PCA in case we are using scalars only
198200
Eigen::VectorXd construct_mean_shape();
199201

202+
203+
Eigen::VectorXd extract_positions(Eigen::VectorXd& data);
204+
200205
Q_SIGNALS:
201206

202207
void update_view();

Studio/Analysis/AnalysisTool.ui

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,14 +1451,14 @@ QWidget#particles_panel {
14511451
<item row="1" column="1">
14521452
<widget class="QCheckBox" name="pca_predict_scalar">
14531453
<property name="text">
1454-
<string>Predict scalar</string>
1454+
<string>Predict Scalar</string>
14551455
</property>
14561456
</widget>
14571457
</item>
14581458
<item row="2" column="1">
14591459
<widget class="QCheckBox" name="pca_predict_shape">
14601460
<property name="text">
1461-
<string>Predict shape</string>
1461+
<string>Predict Shape</string>
14621462
</property>
14631463
</widget>
14641464
</item>

Studio/Job/ShapeScalarJob.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ using namespace pybind11::literals; // to bring in the `_a` literal
1616

1717
namespace shapeworks {
1818

19+
std::atomic<bool> ShapeScalarJob::needs_clear_ = false;
20+
1921
//---------------------------------------------------------------------------
2022
ShapeScalarJob::ShapeScalarJob(QSharedPointer<Session> session, QString target_feature,
2123
Eigen::MatrixXd target_particles, JobType job_type)
22-
: session_(session), target_feature_(target_feature), target_particles_(target_particles), job_type_(job_type) {}
24+
: session_(session), target_feature_(target_feature), target_values_(target_particles), job_type_(job_type) {}
2325

2426
//---------------------------------------------------------------------------
2527
void ShapeScalarJob::run() {
@@ -68,10 +70,23 @@ QPixmap ShapeScalarJob::get_plot() { return plot_; }
6870
//---------------------------------------------------------------------------
6971
Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session, QString target_feature,
7072
Eigen::MatrixXd target_particles) {
73+
return predict(session, target_feature, target_particles, Direction::To_Scalar);
74+
}
75+
76+
//---------------------------------------------------------------------------
77+
Eigen::VectorXd ShapeScalarJob::predict_shape(QSharedPointer<Session> session, QString target_feature,
78+
Eigen::MatrixXd target_scalars) {
79+
return predict(session, target_feature, target_scalars, Direction::To_Shape);
80+
}
81+
82+
//---------------------------------------------------------------------------
83+
Eigen::VectorXd ShapeScalarJob::predict(QSharedPointer<Session> session, QString target_feature,
84+
Eigen::MatrixXd target_values, Direction direction) {
7185
// blocking call to predict scalars for given target particles
7286

73-
auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_particles, JobType::Predict);
87+
auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_values, JobType::Predict);
7488
job->set_quiet_mode(true);
89+
job->set_direction(direction);
7590

7691
Eigen::VectorXd prediction;
7792

@@ -104,8 +119,15 @@ void ShapeScalarJob::run_fit() {
104119
py::module np = py::module::import("numpy");
105120
py::module sw = py::module::import("shapeworks");
106121

107-
py::object A = np.attr("array")(all_particles_);
108-
py::object B = np.attr("array")(all_scalars_);
122+
py::object A;
123+
py::object B;
124+
if (direction_ == Direction::To_Scalar) {
125+
A = np.attr("array")(all_particles_);
126+
B = np.attr("array")(all_scalars_);
127+
} else {
128+
A = np.attr("array")(all_scalars_);
129+
B = np.attr("array")(all_particles_);
130+
}
109131

110132
// returns a tuple of (png_raw_bytes, y_pred, mse)
111133
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;
@@ -131,12 +153,13 @@ void ShapeScalarJob::run_prediction() {
131153
py::module sw = py::module::import("shapeworks");
132154

133155
py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist");
134-
if (!does_mbpls_model_exist().cast<bool>()) {
156+
if (needs_clear_ == true || !does_mbpls_model_exist().cast<bool>()) {
135157
SW_LOG("No MBPLS model exists, running fit");
136158
run_fit();
159+
needs_clear_ = false;
137160
}
138161

139-
py::object new_x = np.attr("array")(target_particles_.transpose());
162+
py::object new_x = np.attr("array")(target_values_.transpose());
140163
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");
141164

142165
using ResultType = Eigen::VectorXd;

Studio/Job/ShapeScalarJob.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class ShapeScalarJob : public Job {
1313
Q_OBJECT
1414
public:
1515
enum class JobType { Find_Components, MSE_Plot, Predict };
16+
enum class Direction { To_Shape, To_Scalar };
1617

1718
ShapeScalarJob(QSharedPointer<Session> session, QString target_feature, Eigen::MatrixXd target_particles,
1819
JobType job_type);
@@ -30,12 +31,22 @@ class ShapeScalarJob : public Job {
3031
static Eigen::VectorXd predict_scalars(QSharedPointer<Session> session, QString target_feature,
3132
Eigen::MatrixXd target_particles);
3233

34+
static Eigen::VectorXd predict_shape(QSharedPointer<Session> session, QString target_feature,
35+
Eigen::MatrixXd target_particles);
36+
37+
static void clear_model() { needs_clear_ = true; };
38+
39+
void set_direction(Direction direction) { direction_ = direction; }
40+
3341
private:
3442
void prep_data();
3543

3644
void run_fit();
3745
void run_prediction();
3846

47+
static Eigen::VectorXd predict(QSharedPointer<Session> session, QString target_feature,
48+
Eigen::MatrixXd target_particles, Direction direction);
49+
3950
QSharedPointer<Session> session_;
4051

4152
ParticleShapeStatistics stats_;
@@ -47,13 +58,16 @@ class ShapeScalarJob : public Job {
4758
Eigen::MatrixXd all_particles_;
4859
Eigen::MatrixXd all_scalars_;
4960

50-
Eigen::MatrixXd target_particles_;
61+
Eigen::MatrixXd target_values_;
5162
Eigen::VectorXd prediction_;
5263

5364
bool num_components_ = 3;
5465
int num_folds_ = 5;
5566
int max_components_ = 20;
5667

68+
Direction direction_{Direction::To_Scalar};
5769
JobType job_type_;
70+
71+
static std::atomic<bool> needs_clear_;
5872
};
5973
} // namespace shapeworks

0 commit comments

Comments
 (0)