Skip to content

Commit 6312f00

Browse files
authored
Merge pull request #14 from JuliaAI/stricter-static
Make sure `:transform` is tested for transformers
2 parents 3ea48e1 + 87b99ad commit 6312f00

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

.github/codecov.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
coverage:
2+
status:
3+
project:
4+
default:
5+
threshold: 0.5%

src/attemptors.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
9393
model = fitted_machine.model
9494
operations = String[]
9595
methods = MLJBase.implemented_methods(fitted_machine.model)
96+
if model isa Static && !(:transform in methods)
97+
push!(methods, :transform)
98+
end
9699
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
97100
if :predict in methods
98101
predict(fitted_machine, first(data))

test/attemptors.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ MLJBase.transform(::DummyStatic, _, x, y) = hcat(x, y)
4242
MLJBase.package_name(::Type{<:DummyStatic}) = "DummyPackage"
4343
MLJBase.load_path(::Type{<:DummyStatic}) = "DummyPackage.Some.Thing.Different"
4444

45+
struct DummyStatic2 <: Static end
46+
MLJBase.transform(::DummyStatic2, _, x, y) = hcat(x, y)
47+
MLJBase.package_name(::Type{<:DummyStatic2}) = "DummyPackage"
48+
MLJBase.load_path(::Type{<:DummyStatic2}) = "DummyPackage.Some.Thing.Different"
49+
MLJBase.implemented_methods(::Type{<:DummyStatic2}) = Symbol[]
50+
4551
struct SupervisedTransformer <: Deterministic end
4652
MLJBase.fit(::SupervisedTransformer, verbosity, X, y) = (42, nothing, nothing)
4753
MLJBase.predict(::SupervisedTransformer, _, Xnew) = fill(4.5, length(Xnew))
@@ -63,6 +69,11 @@ MLJBase.load_path(::Type{<:SupervisedTransformer}) =
6369
operations, outcome = MLJTestInterface.operations(smach, X, y)
6470
@test operations == "transform"
6571
@test outcome == ""
72+
73+
smach = machine(DummyStatic2())
74+
operations, outcome = MLJTestInterface.operations(smach, X, y)
75+
@test operations == "transform"
76+
@test outcome == ""
6677
end
6778

6879
true

0 commit comments

Comments
 (0)