33import copy
44import math
55import statistics
6- from typing import TypedDict
6+ from typing import Any , TypedDict
77
88import numpy as np
99import pandas as pd
@@ -85,7 +85,7 @@ class Training:
8585 training_weight : int = field (default = 1 , validator = [validators .instance_of (int )])
8686 """Training weight for ensembling"""
8787
88- model = field (default = None )
88+ model : Any = field (default = None )
8989 """Model."""
9090
9191 predictions : dict = field (default = Factory (dict ), validator = [validators .instance_of (dict )])
@@ -100,16 +100,16 @@ class Training:
100100 """Features used."""
101101
102102 outlier_samples : list = field (default = Factory (list ), validator = [validators .instance_of (list )])
103- """Outlie samples identified."""
103+ """Outlier samples identified."""
104104
105105 is_fitted : bool = field (default = False , init = False )
106106 """Flag indicating whether the training has been completed."""
107107
108- preprocessing_pipeline = field (init = False )
108+ preprocessing_pipeline : ColumnTransformer | Pipeline = field (init = False )
109109 """Preprocessing pipeline for data scaling, imputation, and categorical encoding."""
110110
111- x_train_processed = field (default = None , init = False )
112- """Training data after pre-processing (outlier, impuation , scaling)."""
111+ x_train_processed : pd . DataFrame | None = field (default = None , init = False )
112+ """Training data after pre-processing (outlier, imputation , scaling)."""
113113
114114 @property
115115 def outl_reduction (self ) -> int :
@@ -169,7 +169,7 @@ def y_dev(self):
169169
170170 @property
171171 def y_test (self ):
172- """y_dev ."""
172+ """y_test ."""
173173 if self .ml_type == MLType .TIMETOEVENT :
174174 duration = self .data_test [self .target_assignments ["duration" ]]
175175 event = self .data_test [self .target_assignments ["event" ]]
@@ -186,9 +186,9 @@ def __attrs_post_init__(self):
186186
187187 def _relabel_processed_output (
188188 self ,
189- processed_data : np . ndarray ,
189+ processed_data : Any ,
190190 index : pd .Index | None = None ,
191- ) -> pd .DataFrame | np . ndarray :
191+ ) -> pd .DataFrame :
192192 """Convert pipeline output to a correctly-labeled DataFrame in self.feature_cols order.
193193
194194 Handles the ColumnTransformer column reordering issue: ColumnTransformer outputs columns
@@ -203,8 +203,12 @@ def _relabel_processed_output(
203203 Returns:
204204 DataFrame with columns in self.feature_cols order, correctly labeled.
205205 """
206+ # Convert sparse matrices to dense arrays
207+ if hasattr (processed_data , "toarray" ):
208+ processed_data = processed_data .toarray ()
209+
206210 if not (hasattr (processed_data , "shape" ) and len (processed_data .shape ) == 2 ):
207- return processed_data
211+ return pd . DataFrame ( processed_data )
208212
209213 try :
210214 output_cols = list (self .preprocessing_pipeline .get_feature_names_out ())
@@ -233,7 +237,7 @@ def _transform_to_dataframe(
233237 self ,
234238 data : pd .DataFrame | np .ndarray ,
235239 index : pd .Index | None = None ,
236- ) -> pd .DataFrame | np . ndarray :
240+ ) -> pd .DataFrame :
237241 """Transform data through preprocessing pipeline and return correctly-labeled DataFrame.
238242
239243 Args:
@@ -242,7 +246,6 @@ def _transform_to_dataframe(
242246
243247 Returns:
244248 DataFrame with columns in self.feature_cols order, correctly labeled.
245- Falls back to returning the raw array if it is not 2D.
246249 """
247250 processed_data = self .preprocessing_pipeline .transform (data )
248251 return self ._relabel_processed_output (processed_data , index = index )
@@ -536,7 +539,7 @@ def calculate_fi_group_permutation(self, partition="dev", n_repeats=10):
536539 logger .set_log_group (LogGroup .TRAINING , f"{ self .training_id } " )
537540
538541 logger .info (f"Calculating permutation feature importances ({ partition } ). This may take a while..." )
539- np .random .seed (42 ) # reproducibility
542+ rng = np .random .RandomState (42 ) # local random state for reproducibility
540543 # fixed confidence level
541544 confidence_level = 0.95
542545 feature_cols = self .feature_cols
@@ -551,6 +554,8 @@ def calculate_fi_group_permutation(self, partition="dev", n_repeats=10):
551554 data = pd .concat ([self .x_dev_processed , self .data_dev [target_cols ]], axis = 1 )
552555 elif partition == "test" :
553556 data = pd .concat ([self .x_test_processed , self .data_test [target_cols ]], axis = 1 )
557+ else :
558+ raise ValueError (f"Invalid partition: '{ partition } '. Must be 'dev' or 'test'." )
554559
555560 if not set (feature_cols ).issubset (data .columns ):
556561 raise ValueError ("Features missing in provided dataset." )
@@ -581,7 +586,7 @@ def calculate_fi_group_permutation(self, partition="dev", n_repeats=10):
581586 # replace column with random selection from that column of data_all
582587 # we use data_all as the validation dataset may be small
583588 for feat in feature :
584- data_pfi [feat ] = np . random .choice (data [feat ], len (data_pfi ), replace = False )
589+ data_pfi [feat ] = rng .choice (data [feat ], len (data_pfi ), replace = False )
585590 pfi_score = get_score_from_model (
586591 model ,
587592 data_pfi ,
@@ -625,7 +630,6 @@ def calculate_fi_group_permutation(self, partition="dev", n_repeats=10):
625630 def calculate_fi_permutation (self , partition = "dev" , n_repeats = 10 ):
626631 """Permutation feature importance."""
627632 logger .info (f"Calculating permutation feature importances ({ partition } ). This may take a while..." )
628- np .random .seed (42 ) # reproducibility
629633 if self .ml_type == MLType .TIMETOEVENT :
630634 # sksurv models only provide inbuilt scorer (CI)
631635 # more work needed to support other metrics
@@ -641,6 +645,8 @@ def calculate_fi_permutation(self, partition="dev", n_repeats=10):
641645 elif partition == "test" :
642646 x = self .x_test_processed
643647 y = self .y_test
648+ else :
649+ raise ValueError (f"Invalid partition: '{ partition } '. Must be 'dev' or 'test'." )
644650
645651 perm_importance = permutation_importance (
646652 self .model ,
@@ -659,7 +665,6 @@ def calculate_fi_permutation(self, partition="dev", n_repeats=10):
659665
660666 def calculate_fi_lofo (self ):
661667 """LOFO feature importance."""
662- np .random .seed (42 ) # reproducibility
663668 logger .info ("Calculating LOFO feature importance. This may take a while..." )
664669 # first, dev only
665670 feature_cols = self .feature_cols
@@ -690,6 +695,9 @@ def calculate_fi_lofo(self):
690695 feature_cols_dict = {x : [x ] for x in feature_cols }
691696 lofo_features = {** feature_cols_dict , ** self .feature_groups }
692697
698+ if self .x_train_processed is None :
699+ raise RuntimeError ("x_train_processed is None — model must be fitted before calculating LOFO FI." )
700+
693701 # lofo
694702 fi_dev : list [tuple [str , float ]] = []
695703 fi_test : list [tuple [str , float ]] = []