Skip to content

Commit 6279b25

Browse files
committed
add @functions and have LearnAPI.functions() return accessors
1 parent 8e8123a commit 6279b25

File tree

6 files changed

+56
-22
lines changed

6 files changed

+56
-22
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ julia> ridge = Ridge(lambda=0.1)
2222
Inspect available functionality:
2323

2424
```
25-
julia> LearnAPI.functions(ridge)
26-
(:(LearnAPI.fit), :(LearnAPI.learner), :(LearnAPI.strip), :(LearnAPI.obs),
27-
:(LearnAPI.features), :(LearnAPI.target), :(LearnAPI.predict), :(LearnAPI.coefficients))
25+
julia> @functions ridge
26+
(fit, LearnAPI.learner, LearnAPI.strip, obs, LearnAPI.features, LearnAPI.target, predict, LearnAPI.coefficients
2827
```
2928

3029
Train:

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ y = <some training target>
5555
Xnew = <some test or production features>
5656

5757
# List LearnaAPI functions implemented for `forest`:
58-
LearnAPI.functions(forest)
58+
@functions forest
5959

6060
# Train:
6161
model = fit(forest, X, y)

docs/src/reference.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,9 @@ minimal (but useless) implementation, see the implementation of `SmallLearner`
199199
## Utilities
200200

201201
```@docs
202+
@functions
202203
LearnAPI.clone
203-
LearnAPI.@trait
204+
@trait
204205
```
205206

206207
---

src/LearnAPI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ include("accessor_functions.jl")
1111
include("traits.jl")
1212
include("clone.jl")
1313

14-
export @trait
14+
export @trait, @functions
1515
export fit, update, update_observations, update_features
1616
export predict, transform, inverse_transform, obs
1717

src/accessor_functions.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -312,23 +312,23 @@ function training_labels end
312312

313313
# :extras intentionally excluded:
314314
const ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS = (
315-
learner,
316-
coefficients,
317-
intercept,
318-
tree,
319-
trees,
320-
feature_names,
321-
feature_importances,
322-
training_labels,
323-
training_losses,
324-
training_predictions,
325-
training_scores,
326-
components,
315+
:(LearnAPI.learner),
316+
:(LearnAPI.coefficients),
317+
:(LearnAPI.intercept),
318+
:(LearnAPI.tree),
319+
:(LearnAPI.trees),
320+
:(LearnAPI.feature_names),
321+
:(LearnAPI.feature_importances),
322+
:(LearnAPI.training_labels),
323+
:(LearnAPI.training_losses),
324+
:(LearnAPI.training_predictions),
325+
:(LearnAPI.training_scores),
326+
:(LearnAPI.components),
327327
)
328328

329329
const ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS_LIST = join(
330330
map(ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS) do f
331-
"[`LearnAPI.$f`](@ref)"
331+
"[`$f`](@ref)"
332332
end,
333333
", ",
334334
" and ",
@@ -354,11 +354,12 @@ $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.training_labels)")).
354354
"""
355355
function extras end
356356

357-
const ACCESSOR_FUNCTIONS = (extras, ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS...)
357+
const ACCESSOR_FUNCTIONS =
358+
(:(LearnAPI.extras), ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS...)
358359

359360
const ACCESSOR_FUNCTIONS_LIST = join(
360361
map(ACCESSOR_FUNCTIONS) do f
361-
"[`LearnAPI.$f`](@ref)"
362+
"[`$f`](@ref)"
362363
end,
363364
", ",
364365
" and ",

src/traits.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,18 @@ with `learner`, or an associated model (object returned by `fit(learner, ...)`,
6565
first argument. Learner traits (methods for which `learner` is the *only* argument)
6666
are excluded.
6767
68+
To return actual functions, instead of symbols, use [`@functions`](@ref)` learner`
69+
instead.
70+
6871
The returned tuple may include expressions like `:(DecisionTree.print_tree)`, which
6972
reference functions not owned by LearnAPI.jl.
7073
7174
The understanding is that `learner` is a LearnAPI-compliant object whenever the return
7275
value is non-empty.
7376
77+
Do `LearnAPI.functions()` to list all possible elements of the return value owned by
78+
LearnAPI.jl.
79+
7480
# Extended help
7581
7682
# New implementations
@@ -100,6 +106,7 @@ learner-specific ones. The LearnAPI.jl accessor functions are: $ACCESSOR_FUNCTIO
100106
(`LearnAPI.strip` is always included).
101107
102108
"""
109+
functions(::Any) = ()
103110
functions() = (
104111
:(LearnAPI.fit),
105112
:(LearnAPI.learner),
@@ -114,8 +121,34 @@ functions() = (
114121
:(LearnAPI.predict),
115122
:(LearnAPI.transform),
116123
:(LearnAPI.inverse_transform),
124+
ACCESSOR_FUNCTIONS...,
117125
)
118-
functions(::Any) = ()
126+
127+
"""
128+
@functions learner
129+
130+
Return a tuple of functions that can be meaningfully applied with `learner`, or an
131+
associated model, as the first argument. An "associated model" is an object returned by
132+
`fit(learner, ...)`. Learner traits (methods for which `learner` is the *only* argument)
133+
are excluded.
134+
135+
```
136+
julia> @functions my_feature_selector
137+
(fit, LearnAPI.learner, strip, obs, transform)
138+
139+
```
140+
141+
New learner implementations should overload [`LearnAPI.functions`](@ref).
142+
143+
See also [`LearnAPI.functions`](@ref).
144+
145+
"""
146+
macro functions(learner)
147+
quote
148+
exs = LearnAPI.functions(learner)
149+
eval.(exs)
150+
end |> esc
151+
end
119152

120153
"""
121154
LearnAPI.kinds_of_proxy(learner)

0 commit comments

Comments
 (0)