Skip to content

Commit b460b71

Browse files
committed
traits only get applied and defined on instances
1 parent c5886a7 commit b460b71

File tree

5 files changed

+70
-71
lines changed

5 files changed

+70
-71
lines changed

docs/src/algorithm_traits.md

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,41 @@
44
> *This algorithm supports per-observation weights, which must appear as the third
55
> argument of `fit`*, or *This algorithm's `transform` method predicts `Real` vectors*.
66
7-
For any (non-trivial) algorithm, [`LearnAPI.functions`](@ref)`(algorithm)` must be
8-
overloaded to list the LearnAPI methods that have been explicitly implemented/overloaded
9-
(algorithm traits excluded). Overloading other traits is optional, except where required
10-
by the implementation of some LearnAPI method and explicitly documented in that method's
11-
docstring.
12-
13-
Traits are often called on instances but are usually *defined* on algorithm *types*, as in
7+
Algorithm traits are functions whose first (and usually only) argument is an algorithm. In
8+
a new implementation, a single-argument trait is declared following this pattern:
149

1510
```julia
16-
LearnAPI.is_pure_julia(::Type{<:MyAlgorithmType}) = true
11+
LearnAPI.is_pure_julia(algorithm::MyAlgorithmType) = true
1712
```
1813

19-
which has the shorthand
14+
!!! important
2015

21-
```julia
22-
@trait MyAlgorithmType is_pure_julia=true
23-
```
16+
The value of a trait must be the same for all algorithms of the same type,
17+
even if the types differ only in type parameters. There are exceptions for
18+
some traits, if
19+
`is_wrapper(algorithm) = true` for all instances `algorithm` of some type
20+
(composite algorithms). This requirement occasionally requires that
21+
an existing algorithm implementation be split into separate LearnAPI
22+
implementations (e.g., one for regression and another for classification).
2423

25-
So, for convenience, every trait `t` is provided the fallback implementation
24+
The declaration above has the shorthand
2625

2726
```julia
28-
t(algorithm) = t(typeof(algorithm))
27+
@trait MyAlgorithmType is_pure_julia=true
2928
```
3029

31-
This means `LearnAPI.is_pure_julia(algorithm) = true` whenever `algorithm isa MyAlgorithmType` in the
32-
above example.
33-
34-
Every trait has a global fallback implementation for `::Type`. See the table below.
30+
Multiple traits can be declared like this:
3531

36-
## When traits depdend on more than algorithm type
3732

38-
Traits that vary from instance to instance of the same type are disallowed, except in the
39-
case of composite algorithms (`is_wrapper(algorithm) = true`) where this is typically
40-
unavoidable. The reason for this is so one can associate, with each non-composite
41-
algorithm type, unique trait-based "algorithm metadata", for inclusion in searchable
42-
algorithm databases. This requirement occasionally requires that an existing algorithm
43-
implementation be split into separate LearnAPI implementations (e.g., one for regression
44-
and another for classification).
33+
```julia
34+
@trait(
35+
MyAlgorithmType,
36+
is_pure_julia = true,
37+
pkg_name = "MyPackage",
38+
)
39+
```
4540

46-
## Special two-argument traits
41+
### Special two-argument traits
4742

4843
The two-argument version of [`LearnAPI.predict_output_scitype`](@ref) and
4944
[`LearnAPI.predict_output_scitype`](@ref) are the only overloadable traits with more than
@@ -55,7 +50,7 @@ one argument. They cannot be declared using the `@trait` macro.
5550
implementation. **Derived traits** are not, and should not be called by performance
5651
critical code
5752

58-
## Overloadable traits
53+
### Overloadable traits
5954

6055
In the examples column of the table below, `Table`, `Continuous`, `Sampleable` are names owned by the
6156
package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
@@ -100,7 +95,7 @@ include the variable.
10095
for the general case.
10196

10297

103-
## Derived Traits
98+
### Derived Traits
10499

105100
The following convenience methods are provided but intended for overloading:
106101

docs/src/anatomy_of_an_implementation.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
> returning the absolute values of the linear coefficients. The ridge regressor has a
99
> target variable and outputs literal predictions of the target (rather than, say,
1010
> probabilistic predictions); accordingly the overloaded `predict` method is dispatched on
11-
> the `LiteralTarget` subtype of `KindOfProxy`. An **algorithm trait** declares this type as the
12-
> preferred kind of target proxy. Other traits articulate the algorithm's training data type
13-
> requirements and the input/output type of `predict`.
11+
> the `LiteralTarget` subtype of `KindOfProxy`. An **algorithm trait** declares this type
12+
> as the preferred kind of target proxy. Other traits articulate the algorithm's training
13+
> data type requirements and the input/output type of `predict`.
1414
1515
We begin by describing an implementation of LearnAPI.jl for basic ridge regression
1616
(without intercept) to introduce the main actors in any implementation.
@@ -159,7 +159,7 @@ list). Accordingly, we are required to declare a preferred target proxy, which w
159159
[`LearnAPI.preferred_kind_of_proxy`](@ref):
160160

161161
```@example anatomy
162-
LearnAPI.preferred_kind_of_proxy(::Type{<:MyRidge}) = LearnAPI.LiteralTarget()
162+
LearnAPI.preferred_kind_of_proxy(::MyRidge) = LearnAPI.LiteralTarget()
163163
nothing # hide
164164
```
165165
Or, you can use the shorthand

src/algorithm_traits.jl

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ this list, do `LearnAPI.functions()`.
6666
See also [`LearnAPI.Algorithm`](@ref).
6767
6868
"""
69-
functions(::Type) = ()
69+
functions(::Any) = ()
7070

7171

7272
"""
@@ -104,13 +104,13 @@ Then we can declare
104104
which is shorthand for
105105
106106
```julia
107-
LearnAPI.preferred_kind_of_proxy(::Type{<:MyNewAlgorithmType}) = LearnAPI.Distribution()
107+
LearnAPI.preferred_kind_of_proxy(::MyNewAlgorithmType) = LearnAPI.Distribution()
108108
```
109109
110110
For more on target variables and target proxies, refer to the LearnAPI documentation.
111111
112112
"""
113-
preferred_kind_of_proxy(::Type) = nothing
113+
preferred_kind_of_proxy(::Any) = nothing
114114

115115
"""
116116
LearnAPI.position_of_target(algorithm)
@@ -122,7 +122,7 @@ If this number is `0`, then no target is expected. If this number exceeds `lengt
122122
then `data` is understood to exclude the target variable.
123123
124124
"""
125-
position_of_target(::Type) = 0
125+
position_of_target(::Any) = 0
126126

127127
"""
128128
LearnAPI.position_of_weights(algorithm)
@@ -135,7 +135,7 @@ If this number is `0`, then no weights are expected. If this number exceeds
135135
uniform.
136136
137137
"""
138-
position_of_weights(::Type) = 0
138+
position_of_weights(::Any) = 0
139139

140140
descriptors() = [
141141
:regression,
@@ -180,7 +180,7 @@ Lists one or more suggestive algorithm descriptors from this list: $DOC_DESCRIPT
180180
This trait should return a tuple of symbols, as in `(:classifier, :probabilistic)`.
181181
182182
"""
183-
descriptors(::Type) = ()
183+
descriptors(::Any) = ()
184184

185185
"""
186186
LearnAPI.is_pure_julia(algorithm)
@@ -192,7 +192,7 @@ Returns `true` if training `algorithm` requires evaluation of pure Julia code on
192192
The fallback is `false`.
193193
194194
"""
195-
is_pure_julia(::Type) = false
195+
is_pure_julia(::Any) = false
196196

197197
"""
198198
LearnAPI.pkg_name(algorithm)
@@ -208,7 +208,7 @@ $DOC_UNKNOWN
208208
Must return a string, as in `"DecisionTree"`.
209209
210210
"""
211-
pkg_name(::Type) = "unknown"
211+
pkg_name(::Any) = "unknown"
212212

213213
"""
214214
LearnAPI.pkg_license(algorithm)
@@ -217,7 +217,7 @@ Return the name of the software license, such as `"MIT"`, applying to the packag
217217
core algorithm for `algorithm` is implemented.
218218
219219
"""
220-
pkg_license(::Type) = "unknown"
220+
pkg_license(::Any) = "unknown"
221221

222222
"""
223223
LearnAPI.doc_url(algorithm)
@@ -231,7 +231,7 @@ $DOC_UNKNOWN
231231
Must return a string, such as `"https://en.wikipedia.org/wiki/Decision_tree_learning"`.
232232
233233
"""
234-
doc_url(::Type) = "unknown"
234+
doc_url(::Any) = "unknown"
235235

236236
"""
237237
LearnAPI.load_path(algorithm)
@@ -250,7 +250,7 @@ $DOC_UNKNOWN
250250
251251
252252
"""
253-
load_path(::Type) = "unknown"
253+
load_path(::Any) = "unknown"
254254

255255

256256
"""
@@ -268,7 +268,7 @@ $DOC_ON_TYPE
268268
269269
270270
"""
271-
is_wrapper(::Type) = false
271+
is_wrapper(::Any) = false
272272

273273
"""
274274
LearnAPI.human_name(algorithm)
@@ -284,7 +284,7 @@ to return `"K-nearest neighbors regressor"`. Ideally, this is a "concrete" noun
284284
`"ridge regressor"` rather than an "abstract" noun like `"ridge regression"`.
285285
286286
"""
287-
human_name(M::Type{}) = snakecase(name(M), delim=' ') # `name` defined below
287+
human_name(M) = snakecase(name(M), delim=' ') # `name` defined below
288288

289289
"""
290290
LearnAPI.iteration_parameter(algorithm)
@@ -297,7 +297,7 @@ iterative.
297297
Implement if algorithm is iterative. Returns a symbol or `nothing`.
298298
299299
"""
300-
iteration_parameter(::Type) = nothing
300+
iteration_parameter(::Any) = nothing
301301

302302
"""
303303
LearnAPI.fit_keywords(algorithm)
@@ -314,7 +314,7 @@ Here's a sample implementation for a classifier that implements a `LearnAPI.fit`
314314
with signature `fit(algorithm::MyClassifier, verbosity, X, y; class_weights=nothing)`:
315315
316316
```
317-
LearnAPI.fit_keywords(::Type{<:MyClassifier}) = (:class_weights,)
317+
LearnAPI.fit_keywords(::Any{<:MyClassifier}) = (:class_weights,)
318318
```
319319
320320
or the shorthand
@@ -325,7 +325,7 @@ or the shorthand
325325
326326
327327
"""
328-
fit_keywords(::Type) = ()
328+
fit_keywords(::Any) = ()
329329

330330
"""
331331
LearnAPI.fit_scitype(algorithm)
@@ -353,7 +353,7 @@ See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_observation_scitype`](@ref)
353353
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
354354
355355
"""
356-
fit_scitype(::Type) = Union{}
356+
fit_scitype(::Any) = Union{}
357357

358358
"""
359359
LearnAPI.fit_observation_scitype(algorithm)
@@ -386,7 +386,7 @@ See also See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_scitype`](@ref),
386386
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
387387
388388
"""
389-
fit_observation_scitype(::Type) = Union{}
389+
fit_observation_scitype(::Any) = Union{}
390390

391391
"""
392392
LearnAPI.fit_type(algorithm)
@@ -413,7 +413,7 @@ See also [`LearnAPI.fit_scitype`](@ref), [`LearnAPI.fit_observation_type`](@ref)
413413
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
414414
415415
"""
416-
fit_type(::Type) = Union{}
416+
fit_type(::Any) = Union{}
417417

418418
"""
419419
LearnAPI.fit_observation_type(algorithm)
@@ -446,7 +446,7 @@ See also See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_scitype`](@ref),
446446
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
447447
448448
"""
449-
fit_observation_type(::Type) = Union{}
449+
fit_observation_type(::Any) = Union{}
450450

451451
DOC_INPUT_SCITYPE(op) =
452452
"""
@@ -543,22 +543,22 @@ DOC_OUTPUT_TYPE(op) =
543543
"""
544544

545545
"$(DOC_INPUT_SCITYPE(:predict))"
546-
predict_input_scitype(::Type) = Union{}
546+
predict_input_scitype(::Any) = Union{}
547547

548548
"$(DOC_INPUT_TYPE(:predict))"
549-
predict_input_type(::Type) = Union{}
549+
predict_input_type(::Any) = Union{}
550550

551551
"$(DOC_INPUT_SCITYPE(:transform))"
552-
transform_input_scitype(::Type) = Union{}
552+
transform_input_scitype(::Any) = Union{}
553553

554554
"$(DOC_OUTPUT_SCITYPE(:transform))"
555-
transform_output_scitype(::Type) = Any
555+
transform_output_scitype(::Any) = Any
556556

557557
"$(DOC_INPUT_TYPE(:transform))"
558-
transform_input_type(::Type) = Union{}
558+
transform_input_type(::Any) = Union{}
559559

560560
"$(DOC_OUTPUT_TYPE(:transform))"
561-
transform_output_type(::Type) = Any
561+
transform_output_type(::Any) = Any
562562

563563

564564
# # TWO-ARGUMENT TRAITS
@@ -591,7 +591,7 @@ const DOC_PREDICT_OUTPUT(s) =
591591
regressor type `MyRgs` that only predicts actual values of the target:
592592
593593
LearnAPI.predict(alogrithm::MyRgs, ::LearnAPI.LiteralTarget, data...) = ...
594-
LearnAPI.predict_output_$(s)(::Type{<:MyRgs}, ::LearnAPI.LiteralTarget) =
594+
LearnAPI.predict_output_$(s)(::MyRgs, ::LearnAPI.LiteralTarget) =
595595
AbstractVector{ScientificTypesBase.Continuous}
596596
597597
The fallback method returns `Any`.
@@ -607,9 +607,9 @@ predict_output_type(algorithm, kind_of_proxy) = Any
607607

608608
# # DERIVED TRAITS
609609

610-
name(A::Type) = string(typename(A))
610+
name(A) = string(typename(A))
611611

612-
is_algorithm(A::Type) = !isempty(functions(A))
612+
is_algorithm(A) = !isempty(functions(A))
613613

614614
const DOC_PREDICT_OUTPUT2(s) =
615615
"""
@@ -651,11 +651,3 @@ predict_output_type(algorithm) =
651651
for T in CONCRETE_TARGET_PROXY_TYPES)
652652

653653

654-
# # FALLBACK FOR INSTANCES
655-
656-
for trait in TRAITS
657-
ex = quote
658-
$trait(x) = $trait(typeof(x))
659-
end
660-
eval(ex)
661-
end

src/tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ macro trait(algorithm_ex, exs...)
1414
trait_ex, value_ex = name_value_pair(ex)
1515
push!(
1616
program.args,
17-
:($LearnAPI.$trait_ex(::Type{<:$algorithm_ex}) = $value_ex),
17+
:($LearnAPI.$trait_ex(::$algorithm_ex) = $value_ex),
1818
)
1919
end
2020
return esc(program)

test/tools.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
module Fruit
2+
using LearnAPI
23

34
struct RedApple{T}
45
x::T
56
end
67

8+
@trait(
9+
RedApple,
10+
is_pure_julia = true,
11+
pkg_name = "Fruity",
12+
)
13+
714
end
815

916
import .Fruit
@@ -30,4 +37,9 @@ end
3037
@test LearnAPI.snakecase(:TheLASERBeam) == :the_laser_beam
3138
end
3239

40+
@testset "@trait" begin
41+
@test LearnAPI.is_pure_julia(Fruit.RedApple(1))
42+
@test LearnAPI.pkg_name(Fruit.RedApple(1)) == "Fruity"
43+
end
44+
3345
true

0 commit comments

Comments
 (0)