11module SKCrossValidators
22
3- using PyCall
3+ import PythonCall
4+ const PYC= PythonCall
45
56# standard included modules
67using DataFrames
@@ -11,11 +12,11 @@ using ..Utils
1112import .. CrossValidators: crossvalidate
1213export crossvalidate
1314
14- const metric_dict = Dict {String,PyObject } ()
15- const SKM = PyNULL ()
15+ const metric_dict = Dict {String,PYC.Py } ()
16+ const SKM = PYC . pynew ()
1617
1718function __init__ ()
18- copy ! (SKM, pyimport_conda (" sklearn.metrics" , " scikit-learn " ))
19+ PYC . pycopy ! (SKM, PYC . pyimport (" sklearn.metrics" ))
1920
2021 metric_dict[" roc_auc_score" ] = SKM. roc_auc_score
2122 metric_dict[" accuracy_score" ] = SKM. accuracy_score
6768
6869Runs K-fold cross-validation using balanced accuracy as the default. It support the
6970following metrics for classification:
70- - accuracy_score
71- - balanced_accuracy_score
72- - cohen_kappa_score
73- - jaccard_score
74- - matthews_corrcoef
75- - hamming_loss
76- - zero_one_loss
77- - f1_score
78- - precision_score
79- - recall_score
71+ - " accuracy_score"
72+ - " balanced_accuracy_score"
73+ - " cohen_kappa_score"
74+ - " jaccard_score"
75+ - " matthews_corrcoef"
76+ - " hamming_loss"
77+ - " zero_one_loss"
78+ - " f1_score"
79+ - " precision_score"
80+ - " recall_score"
8081
8182and the following metrics for regression:
82- - mean_squared_error
83- - mean_squared_log_error
84- - median_absolute_error
85- - r2_score
86- - max_error
87- - explained_variance_score
83+ - " mean_squared_error"
84+ - " mean_squared_log_error"
85+ - " median_absolute_error"
86+ - " r2_score"
87+ - " max_error"
88+ - " explained_variance_score"
8889"""
8990function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,
9091 sfunc:: String ; nfolds= 10 ,verbose:: Bool = true )
92+
93+ YC= Y
94+ if ! (eltype (YC) <: Real )
95+ YC = Y |> Vector{String}
96+ end
97+
9198 checkfun (sfunc)
9299 pfunc = metric_dict[sfunc]
93- metric (a,b) = pfunc (a,b)
94- crossvalidate (pl,X,Y ,metric,nfolds,verbose)
100+ metric (a,b) = pfunc (a,b) |> (x -> PYC . pyconvert (Float64,x))
101+ crossvalidate (pl,X,YC ,metric,nfolds,verbose)
95102end
96103
97- function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,sfunc:: String ,folds :: Int )
98- crossvalidate (pl,X,Y,sfunc, nfolds= folds )
104+ function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,sfunc:: String ,nfolds :: Int )
105+ crossvalidate (pl,X,Y,sfunc; nfolds)
99106end
100107
101- function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,sfunc:: String ,verby :: Bool )
102- crossvalidate (pl,X,Y,sfunc, verbose= verby )
108+ function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,sfunc:: String ,verbose :: Bool )
109+ crossvalidate (pl,X,Y,sfunc; verbose)
103110end
104111
105112function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,
106- sfunc:: String , folds :: Int ,verby :: Bool )
107- crossvalidate (pl,X,Y,sfunc, nfolds= folds ,verbose= verby )
113+ sfunc:: String , nfolds :: Int ,verbose :: Bool )
114+ crossvalidate (pl,X,Y,sfunc; nfolds,verbose)
108115end
109116
110-
111-
112117function crossvalidate (pl:: Machine ,X:: DataFrame ,Y:: Vector ,
113- sfunc:: String ,averagetype:: String , nfolds= 10 ,verbose:: Bool = true )
118+ sfunc:: String ,averagetype:: String ; nfolds= 10 ,verbose:: Bool = true )
114119 checkfun (sfunc)
115120 pfunc = metric_dict[sfunc]
116- metric (a,b) = pfunc (a,b,average= averagetype)
121+ metric (a,b) = pfunc (a,b,average= averagetype) |> (x -> PYC . pyconvert (Float64,x))
117122 crossvalidate (pl,X,Y,metric,nfolds,verbose)
118123end
119124
0 commit comments