Skip to content

Commit 6c633a1

Browse files
committed
dump fit_observation_scitype in favour of fit_scitype
1 parent 6ef477b commit 6c633a1

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

docs/src/traits.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ In the examples column of the table below, `Continuous` is a name owned the pack
2828
| [`LearnAPI.human_name`](@ref)`(learner)` | human name for the learner; should be a noun | type name with spaces | "elastic net regressor" |
2929
| [`LearnAPI.iteration_parameter`](@ref)`(learner)` | symbolic name of an iteration parameter | `nothing` | :epochs |
3030
| [`LearnAPI.data_interface`](@ref)`(learner)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
31-
| [`LearnAPI.fit_observation_scitype`](@ref)`(learner)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(learner, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
31+
| [`LearnAPI.fit_scitype`](@ref)`(learner)` | upper bound on `scitype(data)` ensuring `fit(learner, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
3232
| [`LearnAPI.target_observation_scitype`](@ref)`(learner)` | upper bound on the scitype of each observation of the targget | `Any` | `Continuous` |
3333
| [`LearnAPI.is_static`](@ref)`(learner)` | `true` if `fit` consumes no data | `false` | `true` |
3434

@@ -105,7 +105,7 @@ LearnAPI.nonlearners
105105
LearnAPI.human_name
106106
LearnAPI.data_interface
107107
LearnAPI.iteration_parameter
108-
LearnAPI.fit_observation_scitype
108+
LearnAPI.fit_scitype
109109
LearnAPI.target_observation_scitype
110110
LearnAPI.is_static
111111
```

src/traits.jl

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -416,16 +416,33 @@ Implement if algorithm is iterative. Returns a symbol or `nothing`.
416416
"""
417417
iteration_parameter(::Any) = nothing
418418

419+
# """
420+
# LearnAPI.fit_observation_scitype(learner)
421+
422+
# Return an upper bound `S` on the scitype of individual observations guaranteed to work
423+
# when calling `fit`: if `observations = obs(learner, data)` and
424+
# `ScientificTypes.scitype(collect(o)) <:S` for each `o` in `observations`, then the call
425+
# `fit(learner, data)` is supported.
426+
427+
# $DOC_EXPLAIN_EACHOBS
428+
429+
# See also [`LearnAPI.target_observation_scitype`](@ref).
430+
431+
# # New implementations
432+
433+
# Optional. The fallback return value is `Union{}`.
434+
435+
# """
436+
# fit_observation_scitype(::Any) = Union{}
419437

420438
"""
421-
LearnAPI.fit_observation_scitype(learner)
439+
LearnAPI.fit_scitype(learner)
422440
423-
Return an upper bound `S` on the scitype of individual observations guaranteed to work
424-
when calling `fit`: if `observations = obs(learner, data)` and
425-
`ScientificTypes.scitype(collect(o)) <:S` for each `o` in `observations`, then the call
426-
`fit(learner, data)` is supported.
441+
Return an upper bound `S` on the `scitype` (scientific type) of `data` for which the call
442+
[`fit(learner, data)`](@ref) is supported. Specifically, if `ScientificTypes.scitype(data)
443+
<: S` then the call is guaranteed to succeed. If not, the call may or may not succeed.
427444
428-
$DOC_EXPLAIN_EACHOBS
445+
See ScientificTypes.jl documentation for more on the `scitype` function.
429446
430447
See also [`LearnAPI.target_observation_scitype`](@ref).
431448
@@ -434,21 +451,25 @@ See also [`LearnAPI.target_observation_scitype`](@ref).
434451
Optional. The fallback return value is `Union{}`.
435452
436453
"""
437-
fit_observation_scitype(::Any) = Union{}
454+
fit_scitype(::Any) = Union{}
438455

439456
"""
440457
LearnAPI.target_observation_scitype(learner)
441458
442-
Return an upper bound `S` on the scitype of each observation of an applicable target
443-
variable. Specifically, both of the following is always true:
459+
Return an upper bound `S` on the `scitype` (scientific type) of each observation of any
460+
target variable associated with the learner. See LearnAPI.jl documentation for the meaning
461+
of "target variable". See ScientificTypes.jl documentation for an explanation of the
462+
`scitype` function, which it provides.
463+
464+
Specifically, both of the following is always true:
444465
445466
- If `:(LearnAPI.target) in LearnAPI.functions(learner)` (i.e., `fit` consumes target
446467
variables) then "target" means anything returned by [`LearnAPI.target(learner,
447468
observations)`](@ref), where `observations = `[`LearnAPI.obs(learner, data)`](@ref) and
448-
`data` is an admissible argument in the call [`fit(learner, data)`](@ref).
469+
`data` is a supported argument in the call [`fit(learner, data)`](@ref).
449470
450-
- `S` will always be an upper bound on the scitype of (point) observations that could be
451-
conceivably extracted from the output of [`predict`](@ref).
471+
- `S` is an upper bound on the `scitype` of (point) observations that might normally be
472+
extracted from the output of [`predict`](@ref).
452473
453474
To illustate the second property, suppose we have
454475
@@ -458,9 +479,9 @@ ŷ = predict(model, Sampleable(), data_new)
458479
```
459480
460481
Then each individual sample generated by each "observation" of `ŷ` (a vector of sampleable
461-
objects, say) will be bound in scitype by `S`.
482+
objects, say) will be bound in `scitype` by `S`.
462483
463-
See also See also [`LearnAPI.fit_observation_scitype`](@ref).
484+
See also See also [`LearnAPI.fit_scitype`](@ref).
464485
465486
# New implementations
466487

test/traits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ small = SmallLearner()
4343
@test LearnAPI.human_name(small) == "small learner"
4444
@test isnothing(LearnAPI.iteration_parameter(small))
4545
@test LearnAPI.data_interface(small) == LearnAPI.RandomAccess()
46-
@test !(6 isa LearnAPI.fit_observation_scitype(small))
46+
@test !(6 isa LearnAPI.fit_scitype(small))
4747
@test 6 isa LearnAPI.target_observation_scitype(small)
4848
@test !LearnAPI.is_static(small)
4949

0 commit comments

Comments
 (0)