Skip to content

Commit 678eb95

Browse files
authored
Merge pull request #10 from JuliaAI/static-transform
Fix handling of Static transformers
2 parents ac841b8 + 03d4efd commit 678eb95

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

src/attemptors.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,11 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
101101
push!(operations, "predict")
102102
end
103103
if :transform in methods
104-
W = transform(fitted_machine, first(data))
104+
W = if model isa Static
105+
transform(fitted_machine, data...)
106+
else
107+
transform(fitted_machine, first(data))
108+
end
105109
model isa Static || transform(fitted_machine, rows=test)
106110
model isa Static || transform(fitted_machine, rows=:)
107111
push!(operations, "transform")

test/attemptors.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,32 @@ MLJBase.load_path(::Type{<:DummyModel}) = "DummyPackage.Some.Where.Over.The.Rain
3737
@test outcome == ""
3838
end
3939

40+
struct DummyStatic <: Static end
41+
MLJBase.transform(::DummyStatic, _, x, y) = hcat(x, y)
42+
MLJBase.package_name(::Type{<:DummyStatic}) = "DummyPackage"
43+
MLJBase.load_path(::Type{<:DummyStatic}) = "DummyPackage.Some.Thing.Different"
44+
45+
struct SupervisedTransformer <: Deterministic end
46+
MLJBase.fit(::SupervisedTransformer, verbosity, X, y) = (42, nothing, nothing)
47+
MLJBase.predict(::SupervisedTransformer, _, Xnew) = fill(4.5, length(Xnew))
48+
MLJBase.transform(model::SupervisedTransformer, Θ, Xnew) =
49+
predict(model, Θ, Xnew)
50+
MLJBase.package_name(::Type{<:SupervisedTransformer}) = "DummyPackage"
51+
MLJBase.load_path(::Type{<:SupervisedTransformer}) =
52+
"DummyPackage.Some.Thing.Else"
53+
54+
@testset "operations" begin
55+
X = fill(1.2, 10)
56+
y = X
57+
mach = machine(SupervisedTransformer(), X, y) |> fit!
58+
operations, outcome = MLJTestInterface.operations(mach, X, y, throw=true)
59+
@test operations == "predict, transform"
60+
@test outcome == ""
61+
62+
smach = machine(DummyStatic())
63+
operations, outcome = MLJTestInterface.operations(smach, X, y)
64+
@test operations == "transform"
65+
@test outcome == ""
66+
end
67+
4068
true

0 commit comments

Comments
 (0)