Skip to content

Commit 6eb9637

Browse files
add comupte regression mean logic
1 parent 04c25ea commit 6eb9637

File tree

4 files changed

+37
-9
lines changed

4 files changed

+37
-9
lines changed

Libs/Particles/ParticleShapeStatistics.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,27 @@ ParticleShapeStatistics::ParticleShapeStatistics(std::shared_ptr<Project> projec
492492
groups.push_back(1);
493493
}
494494
import_points(points, groups);
495+
// TODO: importing regression params doesn't make sense here. take a look again later.
496+
}
497+
498+
Eigen::VectorXd ParticleShapeStatistics::compute_regression_mean(const std::vector<double>& explanatory_variables) const
499+
{
500+
// Map explanatory variables to an Eigen vector
501+
Eigen::VectorXd t = Eigen::Map<const Eigen::VectorXd>(explanatory_variables.data(), explanatory_variables.size());
502+
503+
// Ensure slope and intercept are initialized
504+
if (slope_.size() == 0 || intercept_.size() == 0) {
505+
throw std::runtime_error("Slope and Intercept not initialized yet!");
506+
}
507+
508+
// Handle scalar and vector cases for t
509+
if (t.size() == 1) {
510+
return slope_ + intercept_ * t[0]; // Scalar broadcasting
511+
} else if (t.size() == slope_.size()) {
512+
return slope_ + intercept_.cwiseProduct(t); // Element-wise multiplication
513+
} else {
514+
throw std::invalid_argument("Size mismatch: t must be either scalar or match dimensions of slope and intercept.");
515+
}
495516
}
496517

497518
//---------------------------------------------------------------------------

Libs/Particles/ParticleShapeStatistics.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ class ParticleShapeStatistics {
8787
const Eigen::VectorXd& get_mean_rel_pos() { return mean_rel_pose_; }
8888

8989
//! Returns the mean shape.
90-
const Eigen::VectorXd& get_mean() const { return mean_; }
90+
const Eigen::VectorXd get_mean() const { return mean_; }
91+
92+
Eigen::VectorXd compute_regression_mean(const std::vector<double>& explanatory_variables) const;
9193
const Eigen::VectorXd& get_group1_mean() const { return mean1_; }
9294
const Eigen::VectorXd& get_group2_mean() const { return mean2_; }
9395

@@ -136,7 +138,7 @@ class ParticleShapeStatistics {
136138
void set_meshes(const std::vector<Mesh>& meshes) { meshes_ = meshes; }
137139

138140
// import estimated parameters for regression
139-
inline bool import_regression_parameters(Eigen::VectorXd slope, Eigen::VectorXd intercept) { slope_ = slope; intercept_ = intercept; };
141+
inline bool import_regression_parameters(Eigen::VectorXd& slope, Eigen::VectorXd& intercept) { slope_ = slope; intercept_ = intercept; return true;};
140142

141143
private:
142144
unsigned int num_samples_group1_;

Studio/Analysis/AnalysisTool.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -664,10 +664,13 @@ bool AnalysisTool::compute_stats() {
664664
can_run_regression_ = check_explanatory_variable_limits();
665665
if (can_run_regression_) {
666666
auto slope = load_regression_parameters(
667-
session_->get_regression_param_file("slope"));
667+
session_->get_regression_param_file("slope")); // dM vector
668668
auto intercept = load_regression_parameters(
669-
session_->get_regression_param_file("intercept"));
670-
stats_.import_regression_parameters(slope, intercept);
669+
session_->get_regression_param_file("intercept")); // dM vector
670+
stats_.import_regression_parameters(slope, intercept); // set slope and intercept in stats object
671+
ui_->regression_groupbox->setVisible(true);
672+
ui_->explanatoryVariableSlider->setVisible(true);
673+
ui_->enableRegressionCheckBox->setVisible(true);
671674
}
672675
else {
673676
ui_->regression_groupbox->setVisible(false);
@@ -790,7 +793,7 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
790793
ui_->explained_variance->setText("");
791794
ui_->cumulative_explained_variance->setText("");
792795
}
793-
auto mean = !regression_enabled_ ? stats_.get_mean() : stats_.get_regression_mean(ui_->get_explanatory_variable_value());
796+
auto mean = !regression_enabled_ ? stats_.get_mean() : stats_.compute_regression_mean(ui_->get_explanatory_variable_value());
794797
temp_shape_ = mean + (e * (value * lambda));
795798

796799
auto positions = temp_shape_;
@@ -1157,9 +1160,9 @@ double AnalysisTool::get_pca_value() {
11571160
}
11581161

11591162

1160-
double AnalysisTool::get_explanatory_variable_value() {
1163+
std::vector<double> AnalysisTool::get_explanatory_variable_value() {
11611164
int slider_value = ui_->explanatoryVariableSlider->value();
1162-
return t_min + (static_cast<double>(slider_value) / 100.0) * (t_max - t_min);
1165+
return {t_min + (static_cast<double>(slider_value) / 100.0) * (t_max - t_min)};
11631166

11641167
}
11651168

Studio/Analysis/AnalysisTool.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class AnalysisTool : public QWidget {
7777

7878
double get_pca_value();
7979

80+
std::vector<double> get_explanatory_variable_value();
81+
8082
bool pca_animate();
8183
McaMode get_mca_level() const;
8284

@@ -282,7 +284,7 @@ class AnalysisTool : public QWidget {
282284
std::string feature_map_;
283285

284286
std::vector<double> explanatory_variable_limits_;
285-
bool can_run_regression_;
287+
bool can_run_regression_; // decide if necessary variables are present to run regression in analysis
286288

287289
std::vector<std::string> current_group_names_;
288290
std::vector<std::string> current_group_values_;

0 commit comments

Comments
 (0)