Skip to content

Commit 702c7b9

Browse files
authored
Merge pull request #6 from JuliaAI/stricter-tests-around-resampling
Strengthen tests to catch more `selectrows`/`reformat` issues
2 parents e988c4a + f192102 commit 702c7b9

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTestInterface"
22
uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.1"
4+
version = "0.2.0"
55

66
[deps]
77
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

src/attemptors.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ function fitted_machine(model, data...; throw=false, verbosity=1)
7878
mach = model isa Static ? machine(model) :
7979
machine(model, data...)
8080
fit!(mach, verbosity=-1)
81+
train, _ = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
82+
fit!(mach, rows=train, verbosity=-1)
83+
fit!(mach, rows=:, verbosity=-1)
8184
MLJBase.report(mach)
8285
MLJBase.fitted_params(mach)
8386
mach
@@ -89,12 +92,17 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
8992
attempt(finalize(message, verbosity); throw) do
9093
operations = String[]
9194
methods = MLJBase.implemented_methods(fitted_machine.model)
95+
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
9296
if :predict in methods
9397
predict(fitted_machine, first(data))
98+
predict(fitted_machine, rows=test)
99+
predict(fitted_machine, rows=:)
94100
push!(operations, "predict")
95101
end
96102
if :transform in methods
97103
W = transform(fitted_machine, first(data))
104+
transform(fitted_machine, rows=test)
105+
transform(fitted_machine, rows=:)
98106
push!(operations, "transform")
99107
if :inverse_transform in methods
100108
inverse_transform(fitted_machine, W)
@@ -104,4 +112,3 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
104112
join(operations, ", ")
105113
end
106114
end
107-

test/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ expected_report2 = (
3333
y;
3434
mod=@__MODULE__,
3535
level=2,
36-
verbosity=0
36+
verbosity=0,
3737
)
3838
@test isempty(fails)
3939
@test report[1] == expected_report1

0 commit comments

Comments
 (0)