@@ -49,7 +49,7 @@ def cross_validate(
4949 cv : int ,
5050 n_repeats : int ,
5151 gpr : DiffusionGPR ,
52- ) -> dict [ int , list [ tuple [ np .ndarray , np . ndarray , np . ndarray , np . ndarray ]]] :
52+ ) -> np .ndarray :
5353 """
5454 Perform the experiment by estimating the dMRI signal using a Gaussian process model.
5555
@@ -211,10 +211,10 @@ def main() -> None:
211211
212212 if args .kfold :
213213 # Use Scikit-learn cross validation
214- scores = defaultdict (list , {})
214+ scores : dict [ str , list ] = defaultdict (list , {})
215215 for n in args .kfold :
216216 for i in range (args .repeats ):
217- cv_scores = - 1.0 * cross_validate (X , y .T , n , gpr )
217+ cv_scores = - 1.0 * cross_validate (X , y .T , n , i , gpr )
218218 scores ["rmse" ] += cv_scores .tolist ()
219219 scores ["repeat" ] += [i ] * len (cv_scores )
220220 scores ["n_folds" ] += [n ] * len (cv_scores )
@@ -224,7 +224,7 @@ def main() -> None:
224224 print (f"Finished { n } -fold cross-validation" )
225225
226226 scores_df = pd .DataFrame (scores )
227- scores_df .to_csv (args .output_scores , sep = "\t " , index = None , na_rep = "n/a" )
227+ scores_df .to_csv (args .output_scores , sep = "\t " , index = False , na_rep = "n/a" )
228228
229229 grouped = scores_df .groupby (["n_folds" ])
230230 print (grouped [["rmse" ]].mean ())
0 commit comments