1- /*
2- * SPDX-License-Identifier: MIT
3- * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4- */
5-
61use crate :: feature_selection:: mrmr:: mrmr_error:: MrmrError ;
72use crate :: mrmr:: mrmr_utils;
3+ use deep_causality_num:: { Float , FloatOption } ;
84use deep_causality_tensor:: CausalTensor ;
95use std:: collections:: HashSet ;
106
@@ -21,7 +17,8 @@ use rayon::prelude::*;
2117/// When compiled with the `parallel` feature flag, the main feature selection loops are parallelized using `rayon`
2218/// to accelerate computation on multi-core systems.
2319///
24- /// Missing values in the input `CausalTensor` are handled by column-mean imputation prior to feature selection.
20+ /// Missing values in the input `CausalTensor` are handled by pairwise deletion. `NaN` values
21+ /// in float types and `None` values in `Option<float>` types are treated as missing.
2522///
2623/// The algorithm iteratively selects features based on a score that balances relevance and redundancy.
2724/// The scoring mechanism is as follows:
@@ -45,12 +42,10 @@ use rayon::prelude::*;
4542///
4643/// # Arguments
4744///
48- /// * `tensor` - A mutable reference to a 2-dimensional `CausalTensor<f64 >` containing the features and the target variable.
45+ /// * `tensor` - A reference to a 2-dimensional `CausalTensor<T >` containing the features and the target variable.
4946/// * `num_features` - The desired number of features to select.
5047/// * `target_col` - The column index of the target variable within the `tensor`.
5148///
52- /// Missing values (NaN) in the input tensor will be imputed.
53- ///
5449/// # Returns
5550///
5651/// A `Result` containing:
@@ -78,21 +73,25 @@ use rayon::prelude::*;
7873/// 3.0, 6.2, 9.0, 5.5,
7974/// 4.0, 8.1, 12.0, 7.5,
8075/// ];
81- /// let mut tensor = CausalTensor::new(data, vec![4, 4]).unwrap();
76+ /// let tensor = CausalTensor::new(data, vec![4, 4]).unwrap();
8277///
8378/// // Select 2 features, with the target variable in column 3.
84- /// let selected_features_with_scores = mrmr_features_selector(&mut tensor, 2, 3).unwrap();
79+ /// let selected_features_with_scores = mrmr_features_selector(&tensor, 2, 3).unwrap();
8580/// // The exact output may vary slightly based on floating-point precision and data, but for this example,
8681/// // it typically selects features 2 and 0 (indices of the original columns).
8782/// assert_eq!(selected_features_with_scores.len(), 2);
8883/// // assert_eq!(selected_features_with_scores[0].0, 2); // Example expected output for index
8984/// // assert!(selected_features_with_scores[0].1.is_finite()); // Example expected output for score
9085/// ```
91- pub fn mrmr_features_selector (
92- tensor : & mut CausalTensor < f64 > ,
86+ pub fn mrmr_features_selector < T , F > (
87+ tensor : & CausalTensor < T > ,
9388 num_features : usize ,
9489 target_col : usize ,
95- ) -> Result < Vec < ( usize , f64 ) > , MrmrError > {
90+ ) -> Result < Vec < ( usize , f64 ) > , MrmrError >
91+ where
92+ T : FloatOption < F > ,
93+ F : Float ,
94+ {
9695 let shape = tensor. shape ( ) ;
9796 if shape. len ( ) != 2 {
9897 return Err ( MrmrError :: InvalidInput (
@@ -119,8 +118,6 @@ pub fn mrmr_features_selector(
119118 ) ) ;
120119 }
121120
122- mrmr_utils:: impute_missing_values ( tensor) ;
123-
124121 let mut all_features: HashSet < usize > = ( 0 ..n_cols) . collect ( ) ;
125122 all_features. remove ( & target_col) ;
126123
@@ -207,8 +204,14 @@ pub fn mrmr_features_selector(
207204 let redundancy: f64 = selected_indices
208205 . par_iter ( )
209206 . map ( |& selected_idx| {
210- mrmr_utils:: pearson_correlation ( tensor, feature_idx, selected_idx)
211- . map ( |v| v. abs ( ) )
207+ let ( correlation, _) = mrmr_utils:: pearson_correlation ( tensor, feature_idx, selected_idx) ?;
208+ if !correlation. is_finite ( ) {
209+ return Err ( MrmrError :: FeatureScoreError ( format ! (
210+ "Correlation for feature {} and selected feature {} is not finite: {}" ,
211+ feature_idx, selected_idx, correlation
212+ ) ) )
213+ }
214+ Ok ( correlation. abs ( ) )
212215 } )
213216 . sum :: < Result < f64 , _ > > ( ) ?;
214217
@@ -281,7 +284,7 @@ pub fn mrmr_features_selector(
281284 . collect ( ) ;
282285
283286 for & selected_idx in & selected_indices {
284- let correlation =
287+ let ( correlation, _ ) =
285288 mrmr_utils:: pearson_correlation ( tensor, feature_idx, selected_idx) ?;
286289 if !correlation. is_finite ( ) {
287290 return Err ( MrmrError :: FeatureScoreError ( format ! (
0 commit comments