Skip to content

Commit 4d071cc

Browse files
committed
strengthen two of the attemptors
1 parent e988c4a commit 4d071cc

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/attemptors.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ function fitted_machine(model, data...; throw=false, verbosity=1)
7878
mach = model isa Static ? machine(model) :
7979
machine(model, data...)
8080
fit!(mach, verbosity=-1)
81+
train, _ = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
82+
fit!(mach, rows=train, verbosity=-1)
83+
fit!(mach, rows=:, verbosity=-1)
8184
MLJBase.report(mach)
8285
MLJBase.fitted_params(mach)
8386
mach
@@ -89,12 +92,17 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
8992
attempt(finalize(message, verbosity); throw) do
9093
operations = String[]
9194
methods = MLJBase.implemented_methods(fitted_machine.model)
95+
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
9296
if :predict in methods
9397
predict(fitted_machine, first(data))
98+
predict(fitted_machine, rows=test)
99+
predict(fitted_machine, rows=:)
94100
push!(operations, "predict")
95101
end
96102
if :transform in methods
97103
W = transform(fitted_machine, first(data))
104+
transform(fitted_machine, rows=test)
105+
transform(fitted_machine, rows=:)
98106
push!(operations, "transform")
99107
if :inverse_transform in methods
100108
inverse_transform(fitted_machine, W)
@@ -104,4 +112,3 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
104112
join(operations, ", ")
105113
end
106114
end
107-

test/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ expected_report2 = (
3333
y;
3434
mod=@__MODULE__,
3535
level=2,
36-
verbosity=0
36+
verbosity=0,
3737
)
3838
@test isempty(fails)
3939
@test report[1] == expected_report1

0 commit comments

Comments
 (0)