@@ -19,19 +19,20 @@ const ALL_MODELS = (REG_MODELS..., CLF_MODELS...)
19
19
20
20
function MMI. fit (m:: Union{REG_MODELS...} , verb:: Int , X, y)
21
21
Xmatrix = MMI. matrix (X)
22
+ features = (sch = MMI. schema (X)) === nothing ? nothing : sch. names
22
23
reg = glr (m)
23
24
solver = m. solver === nothing ? _solver (reg, size (Xmatrix)) : m. solver
24
25
# get the parameters
25
26
θ = fit (reg, Xmatrix, y; solver= solver)
26
27
# return
27
- return θ , nothing , NamedTuple {} ()
28
+ return (θ, features) , nothing , NamedTuple {} ()
28
29
end
29
30
30
- MMI. predict (m:: Union{REG_MODELS...} , θ , Xnew) = apply_X (MMI. matrix (Xnew), θ)
31
+ MMI. predict (m:: Union{REG_MODELS...} , (θ, features) , Xnew) = apply_X (MMI. matrix (Xnew), θ)
31
32
32
- function MMI. fitted_params (m:: Union{REG_MODELS...} , θ )
33
- m. fit_intercept && return (coefs = θ[1 : end - 1 ], intercept = θ[end ])
34
- return (coefs = θ , intercept = nothing )
33
+ function MMI. fitted_params (m:: Union{REG_MODELS...} , (θ, features) )
34
+ m. fit_intercept && return (coefs = coef_vec ( θ[1 : end - 1 ], features) , intercept = θ[end ])
35
+ return (coefs = coef_vec (θ, features) , intercept = nothing )
35
36
end
36
37
37
38
#= ===========
40
41
41
42
function MMI. fit (m:: Union{CLF_MODELS...} , verb:: Int , X, y)
42
43
Xmatrix = MMI. matrix (X)
44
+ features = (sch = MMI. schema (X)) === nothing ? nothing : sch. names
43
45
yplain = convert .(Int, MMI. int (y))
44
46
classes = MMI. classes (y[1 ])
45
47
nclasses = length (classes)
@@ -56,10 +58,10 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
56
58
# get the parameters
57
59
θ = fit (clf, Xmatrix, yplain, solver= solver)
58
60
# return
59
- return (θ, c, classes), nothing , NamedTuple {} ()
61
+ return (θ, features, c, classes), nothing , NamedTuple {} ()
60
62
end
61
63
62
- function MMI. predict (m:: Union{CLF_MODELS...} , (θ, c, classes), Xnew)
64
+ function MMI. predict (m:: Union{CLF_MODELS...} , (θ, features, c, classes), Xnew)
63
65
Xmatrix = MMI. matrix (Xnew)
64
66
preds = apply_X (Xmatrix, θ, c)
65
67
# binary classification
@@ -73,20 +75,29 @@ function MMI.predict(m::Union{CLF_MODELS...}, (θ, c, classes), Xnew)
73
75
return [MMI. UnivariateFinite (classes, preds[i, :]) for i in 1 : size (Xmatrix,1 )]
74
76
end
75
77
76
- function MMI. fitted_params (m:: Union{CLF_MODELS...} , (θ, c, classes))
78
+ function MMI. fitted_params (m:: Union{CLF_MODELS...} , (θ, features, c, classes))
79
+ function _fitted_params (coefs, features, intercept)
80
+ return (classes = classes, coefs = coef_vec (coefs, features), intercept = intercept)
81
+ end
77
82
if c > 1
83
+ W = reshape (θ, :, c)
78
84
if m. fit_intercept
79
- W = reshape (θ, div (length (θ), c), c)
80
- return (coefs = W, intercept = nothing )
85
+ return _fitted_params (W, features, W[end , :])
81
86
end
82
- W = reshape (θ, p+ 1 , c)
83
- return (coefs = W[1 : p, :], intercept = W[end , :])
87
+ return _fitted_params (W[1 : end - 1 , :], features, nothing )
84
88
end
85
89
# single class
86
- m. fit_intercept && return (coefs = θ[1 : end - 1 ], intercept = θ[end ])
87
- return (coefs = θ, intercept = nothing )
90
+ m. fit_intercept && return _fitted_params ( θ[1 : end - 1 ], features, θ[end ])
91
+ return _fitted_params ( θ, features, nothing )
88
92
end
89
93
94
+ @static VERSION < v " 1.1" && (eachrow (A:: AbstractVecOrMat ) = (view (A, i, :) for i in axes (A, 1 )))
95
+
96
+ coef_vec (W:: AbstractMatrix , features) = [feature => coef for (feature, coef) in zip (features, eachrow (W))]
97
+ coef_vec (θ:: AbstractVector , features) = [feature => coef for (feature, coef) in zip (features, θ)]
98
+ coef_vec (W:: AbstractMatrix , :: Nothing ) = W
99
+ coef_vec (θ:: AbstractVector , :: Nothing ) = θ
100
+
90
101
#= =======================
91
102
METADATA FOR ALL MODELS
92
103
======================= =#
0 commit comments