Skip to content

Commit e2be77d

Browse files
new changes
1 parent fd0e1ba commit e2be77d

File tree

7 files changed

+178
-7
lines changed

7 files changed

+178
-7
lines changed

Libs/Particles/ParticleShapeStatistics.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ class ParticleShapeStatistics {
135135
//! Set the meshes for each sample (used for some evaluation metrics)
136136
void set_meshes(const std::vector<Mesh>& meshes) { meshes_ = meshes; }
137137

138+
// import estimated parameters for regression
139+
inline bool import_regression_parameters(Eigen::VectorXd slope, Eigen::VectorXd intercept) { slope_ = slope; intercept_ = intercept; };
140+
138141
private:
139142
unsigned int num_samples_group1_;
140143
unsigned int num_samples_group2_;
@@ -147,9 +150,14 @@ class ParticleShapeStatistics {
147150
std::vector<double> eigenvalues_;
148151
Eigen::VectorXd mean_;
149152
Eigen::VectorXd mean1_;
150-
Eigen::VectorXd mean2_;
153+
Eigen::VectorXd mean2_;x
151154
Eigen::MatrixXd points_minus_mean_;
152155
Eigen::MatrixXd shapes_;
156+
157+
// for regression tasks
158+
Eigen::VectorXd slope_;
159+
Eigen::VectorXd intercept_;
160+
bool regression_enabled_;
153161

154162
std::vector<double> percent_variance_by_mode_;
155163
Eigen::MatrixXd principals_;

Libs/Project/Subject.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class Subject {
123123
std::string display_name_;
124124
bool fixed_ = false;
125125
bool excluded_ = false;
126-
double explanatory_variable_ = 0.0;
126+
double explanatory_variable_ = std::numeric_limits<double>::lowest();
127127
StringList original_filenames_;
128128
StringList groomed_filenames_;
129129
StringList local_particle_filenames_;

Studio/Analysis/AnalysisTool.cpp

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,11 @@ void AnalysisTool::on_reconstructionButton_clicked() {
269269
//---------------------------------------------------------------------------
270270
int AnalysisTool::get_pca_mode() { return ui_->pcaModeSpinBox->value() - 1; }
271271

272+
//---------------------------------------------------------------------------
273+
bool AnalysisTool::get_regression_analysis_status() {
274+
return ui_->enableRegressionCheckBox->isChecked();
275+
}
276+
272277
//---------------------------------------------------------------------------
273278
double AnalysisTool::get_group_ratio() {
274279
double group_slider_value = ui_->group_slider->value();
@@ -505,6 +510,35 @@ void AnalysisTool::network_analysis_clicked() {
505510
app_->get_py_worker()->run_job(network_analysis_job_);
506511
}
507512

513+
Eigen::VectorXd load_regression_parameters(std::string filepath) {
514+
std::ifstream infile(slope_file_path);
515+
if (!infile.good()) {
516+
throw std::runtime_error("Unable to open regression parameter file: \"" +
517+
filepath + "\" for reading");
518+
}
519+
try {
520+
std::vector<double> temp_values;
521+
double value;
522+
while (infile >> value) {
523+
temp_values.push_back(value);
524+
}
525+
if (temp_values.empty()) {
526+
std::cerr << "Error: No data found in file " << slope_file_path
527+
<< std::endl;
528+
return Eigen::VectorXd();
529+
}
530+
Eigen::VectorXd param_vector(temp_values.size());
531+
for (std::size_t i = 0; i < temp_values.size(); ++i) {
532+
param_vector[i] = temp_values[i];
533+
}
534+
return param_vector;
535+
536+
} catch (json::exception& e) {
537+
throw std::runtime_error("Unabled to parse regression parameter file " +
538+
filepath + " : " + e.what());
539+
}
540+
}
541+
508542
//-----------------------------------------------------------------------------
509543
bool AnalysisTool::compute_stats() {
510544
if (stats_ready_) {
@@ -627,6 +661,20 @@ bool AnalysisTool::compute_stats() {
627661
compute_shape_evaluations();
628662
}
629663

664+
can_run_regression_ = check_explanatory_variable_limits();
665+
if (can_run_regression_) {
666+
auto slope = load_regression_parameters(
667+
session_->get_regression_param_file("slope"));
668+
auto intercept = load_regression_parameters(
669+
session_->get_regression_param_file("intercept"));
670+
stats_.import_regression_parameters(slope, intercept);
671+
}
672+
else {
673+
ui_->regression_groupbox->setVisible(false);
674+
ui_->explanatoryVariableSlider->setVisible(false);
675+
ui_->enableRegressionCheckBox->setVisible(false);
676+
}
677+
630678
stats_ready_ = true;
631679

632680
/// Set this to true to export long format sample data (e.g. for import into R)
@@ -667,6 +715,19 @@ bool AnalysisTool::compute_stats() {
667715
return true;
668716
}
669717

718+
bool check_explanatory_variable_limits() {
719+
auto subjects = session_->get_project()->get_subjects();
720+
explanatory_variable_limits_.resize(2);
721+
explanatory_variable_limits_[0] = std::numeric_limits<double>::max();
722+
explanatory_variable_limits_[1] = std::numeric_limits<double>::lowest();
723+
for (auto sub : subjects) {
724+
double exp_val = sub->get_explanatory_variable();
725+
if (exp_val == std::numeric_limits<double>::lowest()) return false;
726+
explanatory_variable_limits_[0] = std::min(explanatory_variable_limits_[0], exp_val);
727+
explanatory_variable_limits_[1] = std::max(explanatory_variable_limits_[1], exp_val);
728+
}
729+
return true;
730+
}
670731
//-----------------------------------------------------------------------------
671732
Particles AnalysisTool::get_mean_shape_points() {
672733
if (!compute_stats()) {
@@ -729,8 +790,8 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
729790
ui_->explained_variance->setText("");
730791
ui_->cumulative_explained_variance->setText("");
731792
}
732-
733-
temp_shape_ = stats_.get_mean() + (e * (value * lambda));
793+
auto mean = !regression_enabled_ ? stats_.get_mean() : stats_.get_regression_mean(ui_->get_explanatory_variable_value());
794+
temp_shape_ = mean + (e * (value * lambda));
734795

735796
auto positions = temp_shape_;
736797

@@ -829,11 +890,17 @@ ShapeHandle AnalysisTool::get_current_shape() {
829890
int pca_mode = get_pca_mode();
830891
double pca_value = get_pca_value();
831892
auto mca_level = get_mca_level();
832-
if (mca_level == AnalysisTool::McaMode::Vanilla) {
833-
return get_mode_shape(pca_mode, pca_value);
893+
bool regression_analysis_enabled = get_regression_analysis_status();
894+
if (!regression_analysis_enabled) {
895+
if (mca_level == AnalysisTool::McaMode::Vanilla) {
896+
return get_mode_shape(pca_mode, pca_value);
897+
} else {
898+
return get_mca_mode_shape(pca_mode, pca_value, mca_level);
899+
}
834900
} else {
835-
return get_mca_mode_shape(pca_mode, pca_value, mca_level);
901+
836902
}
903+
837904
}
838905

839906
//---------------------------------------------------------------------------
@@ -1095,6 +1162,13 @@ double AnalysisTool::get_pca_value() {
10951162
return value;
10961163
}
10971164

1165+
1166+
double AnalysisTool::get_explanatory_variable_value() {
1167+
int slider_value = ui_->explanatoryVariableSlider->value();
1168+
return t_min + (static_cast<double>(slider_value) / 100.0) * (t_max - t_min);
1169+
1170+
}
1171+
10981172
//---------------------------------------------------------------------------
10991173
void AnalysisTool::pca_labels_changed(QString value, QString eigen, QString lambda) {
11001174
set_labels(QString("pca"), value);

Studio/Analysis/AnalysisTool.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class AnalysisTool : public QWidget {
8080
bool pca_animate();
8181
McaMode get_mca_level() const;
8282

83+
bool get_regression_analysis_status();
84+
85+
bool check_explanatory_variable_limits();
86+
8387
int get_sample_number();
8488

8589
bool compute_stats();
@@ -236,6 +240,8 @@ class AnalysisTool : public QWidget {
236240
void update_difference_particles();
237241

238242
Eigen::VectorXd get_mean_shape_particles();
243+
244+
Eigen::VectorXd load_regression_parameters(std::string filepath);
239245

240246
ShapeHandle create_shape_from_points(Particles points);
241247

@@ -275,6 +281,9 @@ class AnalysisTool : public QWidget {
275281

276282
std::string feature_map_;
277283

284+
std::vector<double> explanatory_variable_limits_;
285+
bool can_run_regression_;
286+
278287
std::vector<std::string> current_group_names_;
279288
std::vector<std::string> current_group_values_;
280289

Studio/Analysis/AnalysisTool.ui

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,69 @@ QWidget#particles_panel {
15461546
</layout>
15471547
</widget>
15481548
</item>
1549+
<item>
1550+
<widget class="QGroupBox" name="regression_groupbox">
1551+
<property name="title">
1552+
<string>Regression Analysis</string>
1553+
</property>
1554+
<layout class="QGridLayout" name="gridLayout_regression">
1555+
<property name="leftMargin">
1556+
<number>0</number>
1557+
</property>
1558+
<property name="topMargin">
1559+
<number>0</number>
1560+
</property>
1561+
<property name="rightMargin">
1562+
<number>0</number>
1563+
</property>
1564+
<property name="bottomMargin">
1565+
<number>0</number>
1566+
</property>
1567+
<item row="0" column="0">
1568+
<widget class="QCheckBox" name="enableRegressionCheckBox">
1569+
<property name="toolTip">
1570+
<string>Enable regression analysis if explanatory variables are provided</string>
1571+
</property>
1572+
<property name="text">
1573+
<string>Enable Regression Analysis</string>
1574+
</property>
1575+
</widget>
1576+
</item>
1577+
<item row="1" column="0">
1578+
<widget class="QLabel" name="label_exp_slider">
1579+
<property name="text">
1580+
<string>Exp.Var.</string>
1581+
</property>
1582+
</widget>
1583+
</item>
1584+
<item row="1" column="1">
1585+
<widget class="CustomSlider" name="explanatoryVariableSlider">
1586+
<property name="enabled">
1587+
<bool>false</bool>
1588+
</property>
1589+
<property name="toolTip">
1590+
<string>Explanatory variable slider</string>
1591+
</property>
1592+
<property name="minimum">
1593+
<number>-20</number>
1594+
</property>
1595+
<property name="maximum">
1596+
<number>20</number>
1597+
</property>
1598+
<property name="orientation">
1599+
<enum>Qt::Horizontal</enum>
1600+
</property>
1601+
<property name="tickPosition">
1602+
<enum>QSlider::TicksBelow</enum>
1603+
</property>
1604+
<property name="tickInterval">
1605+
<number>1</number>
1606+
</property>
1607+
</widget>
1608+
</item>
1609+
</layout>
1610+
</widget>
1611+
</item>
15491612
<item>
15501613
<layout class="QGridLayout" name="gridLayout_3">
15511614
<property name="verticalSpacing">

Studio/Data/Session.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,4 +1418,19 @@ void Session::recompute_surfaces() {
14181418
}
14191419
Q_EMIT update_display();
14201420
}
1421+
1422+
std::string get_regression_param_file(std::string param_name) {
1423+
QFileInfo fileInfo(filename_);
1424+
QString baseName = fileInfo.completeBaseName();
1425+
1426+
QDir projectDir = fileInfo.absoluteDir();
1427+
QString particlesDir = baseName + "_particles";
1428+
QString paramFilePath = projectDir.filePath(particlesDir);
1429+
paramFilePath = QDir(paramFilePath).filePath(param_name);
1430+
1431+
if (!QFile::exists(paramFilePath)) {
1432+
return "";
1433+
}
1434+
return paramFilePath.toStdString();
1435+
}
14211436
} // namespace shapeworks

Studio/Data/Session.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ class Session : public QObject, public QEnableSharedFromThis<Session> {
339339

340340
void new_plane_point(PickResult result);
341341

342+
std::string get_regression_param_file(std::string param_name = "slope");
343+
342344
QWidget* parent_{nullptr};
343345

344346
Preferences& preferences_;

0 commit comments

Comments
 (0)