Skip to content

Commit ebe8d67

Browse files
committed
feat(deep_causality_algorithms): Generic MRMR
Refactored the mRMR (Maximum Relevance, Minimum Redundancy) feature selection algorithm to be generic over float types and their optional variants. This change eliminates code duplication by unifying the `f64` and `Option<f64>` implementations into a single generic version. The refactoring leverages the newly introduced `FloatOption` trait from the `deep_causality_num` crate, allowing `mrmr_features_selector` and its utility functions (`pearson_correlation`, `f_statistic`) to operate on any type `T` that implements `FloatOption<F>`, where `F` is a float type (`f32` or `f64`). Signed-off-by: Marvin Hansen <[email protected]>
1 parent 54d510e commit ebe8d67

File tree

17 files changed

+156
-630
lines changed

17 files changed

+156
-630
lines changed

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

deep_causality_algorithms/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ path = "../deep_causality_tensor"
3636
version = "0.1.4"
3737

3838

39+
[dependencies.deep_causality_num]
40+
path = "../deep_causality_num"
41+
version = "0.1.4"
42+
3943
[dependencies.rayon]
4044
version = "1.11"
4145
optional = true

deep_causality_algorithms/benches/mrmr_benchmark.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
*/
55

66
use criterion::{Criterion, criterion_group, criterion_main};
7-
use deep_causality_algorithms::mrmr::{mrmr_features_selector, mrmr_features_selector_cdl};
7+
use deep_causality_algorithms::mrmr::mrmr_features_selector;
88
use deep_causality_tensor::CausalTensor;
99

1010
fn generate_test_tensor(rows: usize, cols: usize) -> CausalTensor<f64> {
@@ -38,17 +38,17 @@ fn mrmr_benchmark(c: &mut Criterion) {
3838

3939
// Benchmark the standard implementation
4040
group.bench_function("mrmr_features_selector", |b| {
41-
let mut tensor = generate_test_tensor(rows, cols);
41+
let tensor = generate_test_tensor(rows, cols);
4242
b.iter(|| {
43-
mrmr_features_selector(&mut tensor, num_features_to_select, target_col).unwrap();
43+
mrmr_features_selector(&tensor, num_features_to_select, target_col).unwrap();
4444
});
4545
});
4646

4747
// Benchmark the same generic selector on the cdl test tensor
4848
group.bench_function("mrmr_features_selector_cdl", |b| {
4949
let tensor = generate_test_tensor_cdl(rows, cols);
5050
b.iter(|| {
51-
mrmr_features_selector_cdl(&tensor, num_features_to_select, target_col).unwrap();
51+
mrmr_features_selector(&tensor, num_features_to_select, target_col).unwrap();
5252
});
5353
});
5454

deep_causality_algorithms/examples/example_mrmr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ fn main() {
1414
11.0, // F0 is close to Target, F1 is also somewhat close, F2 is random
1515
20.0, 21.0, 5.0, 22.0, 30.0, 33.0, 2.0, 31.0, 40.0, 40.0, 8.0, 43.0, 50.0, 55.0, 3.0, 52.0,
1616
];
17-
let mut tensor = CausalTensor::new(data, vec![5, 4]).unwrap();
17+
let tensor = CausalTensor::new(data, vec![5, 4]).unwrap();
1818

1919
// 2. Run the feature selector
2020
// Select 3 features, with the target variable in column 3.
21-
let selected_features_with_scores = mrmr_features_selector(&mut tensor, 3, 3).unwrap();
21+
let selected_features_with_scores = mrmr_features_selector(&tensor, 3, 3).unwrap();
2222

2323
// 3. Interpret the results
2424
println!("Selected features and their scores:");

deep_causality_algorithms/examples/example_mrmr_cdl.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
44
*/
55

6-
use deep_causality_algorithms::mrmr::mrmr_features_selector_cdl;
6+
use deep_causality_algorithms::mrmr::mrmr_features_selector;
77
use deep_causality_tensor::CausalTensor;
88

99
fn main() {
@@ -33,12 +33,12 @@ fn main() {
3333
];
3434
let tensor = CausalTensor::new(data, vec![5, 4]).unwrap();
3535

36-
// 2. Run the feature selector for CDL version
36+
// 2. Run the feature selector
3737
// Select 3 features, with the target variable in column 3.
38-
let selected_features_with_scores = mrmr_features_selector_cdl(&tensor, 3, 3).unwrap();
38+
let selected_features_with_scores = mrmr_features_selector(&tensor, 3, 3).unwrap();
3939

4040
// 3. Interpret the results
41-
println!("Selected features and their normalized scores (CDL):");
41+
println!("Selected features and their normalized scores (Generic MRMR):");
4242
for (index, score) in selected_features_with_scores {
4343
println!("- Feature Index: {}, Importance Score: {:.4}", index, score);
4444
}

deep_causality_algorithms/src/feature_selection/mrmr/mod.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@ pub mod mrmr_algo;
77
pub mod mrmr_error;
88
pub mod mrmr_utils;
99

10-
pub mod mrmr_algo_cdl;
11-
pub mod mrmr_utils_cdl;
12-
1310
pub use mrmr_algo::*;
14-
pub use mrmr_algo_cdl::*;
1511
pub use mrmr_error::MrmrError;
1612

1713
#[cfg(test)]

deep_causality_algorithms/src/feature_selection/mrmr/mrmr_algo.rs

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
/*
2-
* SPDX-License-Identifier: MIT
3-
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4-
*/
5-
61
use crate::feature_selection::mrmr::mrmr_error::MrmrError;
72
use crate::mrmr::mrmr_utils;
3+
use deep_causality_num::{Float, FloatOption};
84
use deep_causality_tensor::CausalTensor;
95
use std::collections::HashSet;
106

@@ -21,7 +17,8 @@ use rayon::prelude::*;
2117
/// When compiled with the `parallel` feature flag, the main feature selection loops are parallelized using `rayon`
2218
/// to accelerate computation on multi-core systems.
2319
///
24-
/// Missing values in the input `CausalTensor` are handled by column-mean imputation prior to feature selection.
20+
/// Missing values in the input `CausalTensor` are handled by pairwise deletion. `NaN` values
21+
/// in float types and `None` values in `Option<float>` types are treated as missing.
2522
///
2623
/// The algorithm iteratively selects features based on a score that balances relevance and redundancy.
2724
/// The scoring mechanism is as follows:
@@ -45,12 +42,10 @@ use rayon::prelude::*;
4542
///
4643
/// # Arguments
4744
///
48-
/// * `tensor` - A mutable reference to a 2-dimensional `CausalTensor<f64>` containing the features and the target variable.
45+
/// * `tensor` - A reference to a 2-dimensional `CausalTensor<T>` containing the features and the target variable.
4946
/// * `num_features` - The desired number of features to select.
5047
/// * `target_col` - The column index of the target variable within the `tensor`.
5148
///
52-
/// Missing values (NaN) in the input tensor will be imputed.
53-
///
5449
/// # Returns
5550
///
5651
/// A `Result` containing:
@@ -78,21 +73,25 @@ use rayon::prelude::*;
7873
/// 3.0, 6.2, 9.0, 5.5,
7974
/// 4.0, 8.1, 12.0, 7.5,
8075
/// ];
81-
/// let mut tensor = CausalTensor::new(data, vec![4, 4]).unwrap();
76+
/// let tensor = CausalTensor::new(data, vec![4, 4]).unwrap();
8277
///
8378
/// // Select 2 features, with the target variable in column 3.
84-
/// let selected_features_with_scores = mrmr_features_selector(&mut tensor, 2, 3).unwrap();
79+
/// let selected_features_with_scores = mrmr_features_selector(&tensor, 2, 3).unwrap();
8580
/// // The exact output may vary slightly based on floating-point precision and data, but for this example,
8681
/// // it typically selects features 2 and 0 (indices of the original columns).
8782
/// assert_eq!(selected_features_with_scores.len(), 2);
8883
/// // assert_eq!(selected_features_with_scores[0].0, 2); // Example expected output for index
8984
/// // assert!(selected_features_with_scores[0].1.is_finite()); // Example expected output for score
9085
/// ```
91-
pub fn mrmr_features_selector(
92-
tensor: &mut CausalTensor<f64>,
86+
pub fn mrmr_features_selector<T, F>(
87+
tensor: &CausalTensor<T>,
9388
num_features: usize,
9489
target_col: usize,
95-
) -> Result<Vec<(usize, f64)>, MrmrError> {
90+
) -> Result<Vec<(usize, f64)>, MrmrError>
91+
where
92+
T: FloatOption<F>,
93+
F: Float,
94+
{
9695
let shape = tensor.shape();
9796
if shape.len() != 2 {
9897
return Err(MrmrError::InvalidInput(
@@ -119,8 +118,6 @@ pub fn mrmr_features_selector(
119118
));
120119
}
121120

122-
mrmr_utils::impute_missing_values(tensor);
123-
124121
let mut all_features: HashSet<usize> = (0..n_cols).collect();
125122
all_features.remove(&target_col);
126123

@@ -207,8 +204,14 @@ pub fn mrmr_features_selector(
207204
let redundancy: f64 = selected_indices
208205
.par_iter()
209206
.map(|&selected_idx| {
210-
mrmr_utils::pearson_correlation(tensor, feature_idx, selected_idx)
211-
.map(|v| v.abs())
207+
let (correlation, _) = mrmr_utils::pearson_correlation(tensor, feature_idx, selected_idx)?;
208+
if !correlation.is_finite() {
209+
return Err(MrmrError::FeatureScoreError(format!(
210+
"Correlation for feature {} and selected feature {} is not finite: {}",
211+
feature_idx, selected_idx, correlation
212+
)))
213+
}
214+
Ok(correlation.abs())
212215
})
213216
.sum::<Result<f64, _>>()?;
214217

@@ -281,7 +284,7 @@ pub fn mrmr_features_selector(
281284
.collect();
282285

283286
for &selected_idx in &selected_indices {
284-
let correlation =
287+
let (correlation, _) =
285288
mrmr_utils::pearson_correlation(tensor, feature_idx, selected_idx)?;
286289
if !correlation.is_finite() {
287290
return Err(MrmrError::FeatureScoreError(format!(

0 commit comments

Comments
 (0)