Skip to content

Commit d82eaa5

Browse files
committed
add a test for predict and transform slurping fallbacks
oops
1 parent 9b9e4d4 commit d82eaa5

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

src/predict_transform.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ function DOC_IMPLEMENTED_METHODS(name; overloaded=false)
44
"[`LearnAPI.functions`](@ref) trait. "
55
end
66

7-
const OPERATIONS = (:predict, :transform, :inverse_transform)
8-
const DOC_OPERATIONS_LIST_SYMBOL = join(map(op -> "`:$op`", OPERATIONS), ", ")
9-
const DOC_OPERATIONS_LIST_FUNCTION = join(map(op -> "`LearnAPI.$op`", OPERATIONS), ", ")
10-
117
DOC_MUTATION(op) =
128
"""
139
@@ -171,8 +167,8 @@ $(DOC_MUTATION(:transform))
171167
$(DOC_DATA_INTERFACE(:transform))
172168
173169
"""
174-
transform(model, data1, data2...; kwargs...) =
175-
transform(model, (data1, datas...); kwargs...) # automatic slurping
170+
transform(model, data1, data2, datas...; kwargs...) =
171+
transform(model, (data1, data2, datas...); kwargs...) # automatic slurping
176172

177173
"""
178174
inverse_transform(model, data)

test/predict_transform.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Test
2+
using LearnAPI
3+
4+
struct Goose end
5+
6+
LearnAPI.fit(algorithm::Goose) = Ref(algorithm)
7+
LearnAPI.algorithm(::Base.RefValue{Goose}) = Goose()
8+
LearnAPI.predict(::Base.RefValue{Goose}, ::Point, data) = sum(data)
9+
LearnAPI.transform(::Base.RefValue{Goose}, data) = prod(data)
10+
@trait Goose kinds_of_proxy = (Point(),)
11+
12+
@testset "predict and transform argument slurping" begin
13+
model = fit(Goose())
14+
@test predict(model, Point(), 2, 3, 4) == 9
15+
@test predict(model, 2, 3, 4) == 9
16+
@test transform(model, 2, 3, 4) == 24
17+
end
18+
19+
true

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ test_files = [
44
"tools.jl",
55
"traits.jl",
66
"clone.jl",
7+
"predict_transform.jl",
78
"patterns/regression.jl",
89
"patterns/static_algorithms.jl",
910
"patterns/ensembling.jl",

0 commit comments

Comments
 (0)