Skip to content

Commit 3d41b21

Browse files
committed
adapt to revised target contract
1 parent cef45a0 commit 3d41b21

File tree

5 files changed

+50
-45
lines changed

5 files changed

+50
-45
lines changed

src/learners/regression.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ function LearnAPI.obs(::Ridge, data)
5050
names = Tables.columnnames(table) |> collect
5151
RidgeFitObs(Tables.matrix(table)', names, y)
5252
end
53+
54+
# for involutivity:
5355
LearnAPI.obs(::Ridge, data::RidgeFitObs) = data
5456

5557
# for observations:
@@ -82,9 +84,10 @@ LearnAPI.fit(learner::Ridge, data; kwargs...) =
8284
fit(learner, obs(learner, data); kwargs...)
8385

8486
# extracting stuff from training data:
85-
LearnAPI.target(::Ridge, data) = last(data)
8687
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
8788
LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
89+
LearnAPI.target(learner::Ridge, data) =
90+
LearnAPI.target(learner, obs(learner, data))
8891

8992
# observations for consumption by `predict`:
9093
LearnAPI.obs(::RidgeFitted, X) = Tables.matrix(X)'
@@ -130,7 +133,7 @@ LearnAPI.fit(learner::Ridge, X, y; kwargs...) =
130133
fit(learner, (X, y); kwargs...)
131134

132135

133-
# # VARIATION OF RIDGE REGRESSION THAT USES FALLBACK OF LearnAPI.obs
136+
# # VARIATION OF RIDGE REGRESSION WITHOUT DATA FRONT END
134137

135138
# no docstring here - that goes with the constructor
136139
struct BabyRidge
@@ -169,9 +172,6 @@ function LearnAPI.fit(learner::BabyRidge, data; verbosity=1)
169172

170173
end
171174

172-
# extracting stuff from training data:
173-
LearnAPI.target(::BabyRidge, data) = last(data)
174-
175175
LearnAPI.learner(model::BabyRidgeFitted) = model.learner
176176

177177
LearnAPI.predict(model::BabyRidgeFitted, ::Point, Xnew) =

src/logging.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ const OBS = """
9393
9494
Attempting to call `observations = obs(learner, data)`.
9595
96+
"""
97+
const FEATURES0 = """
98+
99+
Attempting to call `LearnAPI.features(learner, data)`.
100+
96101
"""
97102
const FEATURES = """
98103
@@ -249,28 +254,23 @@ const TRANSFORM_ON_SELECTIONS2 = """
249254
`verbosity=1` for further explanation of `model3` and `X3`.
250255
251256
"""
252-
const TARGET = """
257+
const TARGET0 = """
253258
254-
Attempting to call `LearnAPI.target(learner, observations)` (fallback returns
255-
`nothing`).
259+
Attempting to call `LearnAPI.target(learner, data)` (fallback returns
260+
`last(data)`).
256261
257262
"""
258-
const TARGET_IN_FUNCTIONS = """
259-
260-
Checking that `:(LearnAPI.target)` is included in `LearnAPI.functions(learner)`.
261-
262-
"""
263-
const TARGET_NOT_IN_FUNCTIONS = """
263+
const TARGET = """
264264
265-
Checking that `:(LearnAPI.target)` is excluded from `LearnAPI.functions(learner)`, as
266-
`LearnAPI.target` has not been overloaded.
265+
Attempting to call `LearnAPI.target(learner, observations)` (fallback returns
266+
`last(observations)`).
267267
268268
"""
269269
const TARGET_SELECTIONS = """
270270
271-
Checking that all observations can be extracted from `LearnAPI.target(learner,
272-
observations)` using the data interface declared by
273-
`LearnAPI.data_interface(learner)`.
271+
Checking that all observations can be extracted from `LearnAPI.target(learner, data)`
272+
using the data interface declared by `LearnAPI.data_interface(learner)`. Doing the
273+
same with `data` replaced with `observations`.
274274
275275
"""
276276
const WEIGHTS = """

src/testapi.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ If `LearnAPI.is_static(learner) == false`, then:
6868
`transform` are called with no data. Otherwise, they are called with `X`.
6969
7070
If instead `LearnAPI.is_static(learner) == true`, then `fit` and its cousins are called
71-
without any data, and `dataset` is passed directly to `fit` and/or `transform`.
71+
without any data, and `dataset` is passed directly to `predict` and/or `transform`.
7272
7373
"""
7474
macro testapi(learner, data...)
@@ -197,8 +197,15 @@ macro testapi(learner, data...)
197197
obs(learner, data)
198198
end
199199

200-
X = @logged_testset $FEATURES verbosity begin
201-
LearnAPI.features(learner, observations)
200+
X = if _is_static
201+
data
202+
else
203+
@logged_testset $FEATURES0 verbosity begin
204+
LearnAPI.features(learner, data)
205+
end
206+
@logged_testset $FEATURES verbosity begin
207+
LearnAPI.features(learner, observations)
208+
end
202209
end
203210

204211
if !(isnothing(X))
@@ -437,24 +444,24 @@ macro testapi(learner, data...)
437444

438445
# target
439446

440-
_y = @logged_testset $TARGET verbosity begin
441-
LearnAPI.target(learner, observations)
442-
end
443-
444-
if !(isnothing(_y))
445-
@logged_testset $TARGET_IN_FUNCTIONS verbosity begin
446-
Test.@test :(LearnAPI.target) in _functions
447+
if :(LearnAPI.target) in _functions
448+
_y = @logged_testset $TARGET0 verbosity begin
449+
LearnAPI.target(learner, data)
450+
end
451+
@logged_testset $TARGET verbosity begin
452+
LearnAPI.target(learner, observations)
447453
end
448-
y = @logged_testset $TARGET_SELECTIONS verbosity begin
454+
@logged_testset $TARGET_SELECTIONS verbosity begin
449455
LearnTestAPI.learner_get(
450456
learner,
451457
data,
452458
data->LearnAPI.target(learner, data),
453459
)
454-
end
455-
else
456-
@logged_testset $TARGET_NOT_IN_FUNCTIONS verbosity begin
457-
Test.@test !(:(LearnAPI.target) in _functions)
460+
LearnTestAPI.learner_get(
461+
learner,
462+
observations,
463+
data->LearnAPI.target(learner, data),
464+
)
458465
end
459466
end
460467

@@ -645,19 +652,19 @@ macro testapi(learner, data...)
645652
Test.@test _human_name isa String
646653
end
647654

648-
@logged_testset $FIT_SCITYPE verbosity begin
649-
S = LearnAPI.fit_scitype(learner)
650-
if S == Union{}
651-
push!(missing_traits, :(LearnAPI.fit_scitype))
652-
else
655+
S = LearnAPI.fit_scitype(learner)
656+
if S == Union{}
657+
push!(missing_traits, :(LearnAPI.fit_scitype))
658+
else
659+
@logged_testset $FIT_SCITYPE verbosity begin
653660
Test.@test ScientificTypes.scitype(data) <: S
654661
end
655662
end
656663

657664
S = LearnAPI.target_observation_scitype(learner)
658665
testable = :(LearnAPI.target) in _functions &&
659666
_data_interface in (LearnAPI.RandomAccess(), LearnAPI.FiniteIterable())
660-
if S == Any
667+
if S == Any && (LearnAPI.target) in _functions
661668
push!(missing_traits, :(LearnAPI.target_observation_scitype))
662669
elseif testable
663670
@logged_testset $TARGET_OBSERVATION_SCITYPE verbosity begin

src/tools.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ end
8383
*Private method.*
8484
8585
Extract from `LearnAPI.obs(learner, data)`, after applying `apply`, all observations,
86-
using the data access API specified by `LearnAPI.data_interface(learner)`. Used to test
87-
that the output of `data` indeed implements the specified interface.
86+
using the data access API specified by `LearnAPI.data_interface(learner)`.
8887
8988
"""
9089
learner_get(learner, data, apply=identity) =
@@ -97,8 +96,7 @@ learner_get(learner, data, apply=identity) =
9796
9897
Extract from `LearnAPI.obs(model, data)`, after applying `apply`, all observations, using
9998
the data access API specified by `LearnAPI.data_interface(learner)`, where `learner =
100-
LearnAPI.learner(model)`. Used to test that the output of `data` indeed implements the
101-
specified interface.
99+
LearnAPI.learner(model)`.
102100
103101
"""
104102
model_get(model, data, apply =identity) =

test/learners/dimension_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ U, Vt = r.U, r.Vt
1111
X = U*diagm([1, 2, 3, 0.01, 0.01])*Vt
1212

1313
learner = LearnTestAPI.TruncatedSVD(codim=2)
14-
@testapi learner X verbosity=1
14+
@testapi learner X verbosity=0
1515

1616
@testset "extra test for truncated SVD" begin
1717
model = @test_logs(

0 commit comments

Comments
 (0)