Skip to content

Commit e879925

Browse files
committed
gamma
1 parent 94c0095 commit e879925

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

src/StatisticalTraits.jl

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module StatisticalTraits
22

33
using ScientificTypes
4-
4+
import Base.instances
55

66
## CONSTANTS
77

@@ -26,7 +26,13 @@ const TRAITS = [
2626
:prediction_type,
2727
:hyperparameters,
2828
:hyperparameter_types,
29-
:hyperparameter_ranges]
29+
:hyperparameter_ranges,
30+
:orientation,
31+
:reports_each_observation,
32+
:aggregation,
33+
:is_feature_dependent,
34+
:distribution_type
35+
]
3036

3137

3238
## EXPORT
@@ -35,9 +41,23 @@ for trait in TRAITS
3541
eval(:(export $trait))
3642
end
3743

44+
export Mean, Sum, RootMeanSquare
45+
46+
47+
## TYPES
48+
49+
# For the possible values of the `aggregation` trait:
50+
abstract type AggregationMode end
51+
struct Sum <: AggregationMode end
52+
struct Mean <: AggregationMode end
53+
struct RootMeanSquare <: AggregationMode end
54+
3855

3956
## HELPERS
4057

58+
# Some helper functions are needed to construct sensible fallbacks for
59+
# some traits.
60+
4161
"""
4262
4363
typename(T::Type)
@@ -55,7 +75,7 @@ function typename(M)
5575
if isdefined(M, :name)
5676
return M.name.name
5777
elseif isdefined(M, :body)
58-
return _typename(M.body)
78+
return typename(M.body)
5979
else
6080
return Symbol(string(M))
6181
end
@@ -95,7 +115,6 @@ end
95115
snakecase(s::Symbol) = Symbol(snakecase(string(s)))
96116

97117

98-
99118
## TRAITS
100119

101120
# The following can return any scientific type, that is, any type
@@ -131,12 +150,16 @@ prediction_type(::Type) = :unknown # used for measures too
131150

132151
# Miscellaneous:
133152

134-
is_wrapper(::Type) = false # or `true`
135-
supports_online(::Type) = false # or `true`
136-
docstring(M::Type) = string(M) # some `String`
137-
is_supervised(::Type) = false # or `true`
138-
human_name(M::Type) = snakecase(name(M), delim=' ') # `name` defined below
139-
153+
is_wrapper(::Type) = false # or `true`
154+
supports_online(::Type) = false # or `true`
155+
docstring(M::Type) = string(M) # some `String`
156+
is_supervised(::Type) = false # or `true`
157+
human_name(M::Type) = snakecase(name(M), delim=' ') # `name` defined below
158+
orientation(::Type) = :loss # or `:score`, `:other`
159+
aggregation(::Type) = Mean()
160+
is_feature_dependent(::Type) = false
161+
reports_each_observation(::Type) = false
162+
distribution_type(::Type) = missing
140163

141164
# Returns a tuple, with one entry per field of `T` (the type of some
142165
# statistical model, for example). Each entry is `nothing` or defines
@@ -162,7 +185,8 @@ for trait in TRAITS
162185
eval(ex)
163186
end
164187

165-
## INFO METHOD FOR QUERYING TRAITS
188+
189+
## INFO STUB FOR QUERYING TRAITS
166190

167191
"""
168192
info(X)
@@ -180,4 +204,3 @@ info(X) = info(X, Val(ScientificTypes.trait(X)))
180204
info(X, ::Val{:other}) = NamedTuple()
181205

182206
end # module
183-

test/runtests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ const NONCONSTANT = [:docstring,
8383
@testset "traits with constant fall-back" begin
8484
for trait in setdiff(StatisticalTraits.TRAITS, NONCONSTANT)
8585
ex = quote
86-
@test $trait(Fruit.RedApple()) == $trait(Foo(1, 'x'))
86+
a = $trait(Fruit.RedApple())
87+
b = $trait(Foo(1, 'x'))
88+
if ismissing(a)
89+
@test ismissing(b)
90+
else
91+
@test a == b
92+
end
8793
end
8894
Main.eval(ex)
8995
end

0 commit comments

Comments
 (0)