@@ -78,6 +78,9 @@ function fitted_machine(model, data...; throw=false, verbosity=1)
78
78
mach = model isa Static ? machine (model) :
79
79
machine (model, data... )
80
80
fit! (mach, verbosity= - 1 )
81
+ train, _ = MLJBase. partition (1 : MLJBase. nrows (first (data)), 0.5 )
82
+ fit! (mach, rows= train, verbosity= - 1 )
83
+ fit! (mach, rows= :, verbosity= - 1 )
81
84
MLJBase. report (mach)
82
85
MLJBase. fitted_params (mach)
83
86
mach
@@ -89,12 +92,17 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
89
92
attempt (finalize (message, verbosity); throw) do
90
93
operations = String[]
91
94
methods = MLJBase. implemented_methods (fitted_machine. model)
95
+ _, test = MLJBase. partition (1 : MLJBase. nrows (first (data)), 0.5 )
92
96
if :predict in methods
93
97
predict (fitted_machine, first (data))
98
+ predict (fitted_machine, rows= test)
99
+ predict (fitted_machine, rows= :)
94
100
push! (operations, " predict" )
95
101
end
96
102
if :transform in methods
97
103
W = transform (fitted_machine, first (data))
104
+ transform (fitted_machine, rows= test)
105
+ transform (fitted_machine, rows= :)
98
106
push! (operations, " transform" )
99
107
if :inverse_transform in methods
100
108
inverse_transform (fitted_machine, W)
@@ -104,4 +112,3 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
104
112
join (operations, " , " )
105
113
end
106
114
end
107
-
0 commit comments