Skip to content

Commit 20de940

Browse files
committed
improve Obs doc-string to close #9
1 parent 851c8f5 commit 20de940

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

src/backends.jl

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,85 @@ If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then
3232
additionally have:
3333
3434
- `observations.levels`: A categorical vector of the ordered target levels, as actually
35-
seen in the user-supplied target, with the full pool of levels available by applying
36-
`Categorical.levels` to the result. The corresponding integer codes will be
37-
`sort(unique(observations.target))`.
35+
seen in the user-supplied target. The corresponding integer codes will be
36+
`sort(unique(observations.target))`. To get the full pool of levels, apply
37+
`CategoricalArrays.levels` to `observations.levels_seen`; see the example below.
3838
3939
- `observations.decoder`: A callable function that converts an integer code back to the
4040
original `CategoricalValue` it represents.
4141
4242
Pass the first onto `predict` for making probabilistic predictions, and the second for
4343
point predictions; see [`Sage`](@ref) for details.
4444
45+
# Extended help
46+
47+
In the example below, `observations` implements the full `Obs` interface described above,
48+
for a learner implementing the `Sage` front end:
49+
50+
```julia-repl
51+
using LearnAPI, LearnDataFrontEnds, LearnTestAPI
52+
using CategoricalDistributions, CategoricalArrays, DataFrames
53+
X = DataFrame(rand(10, 3), :auto)
54+
y = categorical(collect("ababababac"))
55+
learner = LearnTestAPI.ConstantClassifier()
56+
observations = obs(learner, (X[1:9,:], y[1:9]))
57+
58+
julia> observations.features
59+
3×9 Matrix{Float64}:
60+
0.234043 0.526468 0.227417 0.956471 … 0.00587146 0.169291 0.353518 0.402631
61+
0.631083 0.151317 0.781049 0.00320728 0.756519 0.15317 0.452169 0.127005
62+
0.285315 0.347433 0.69174 0.516915 0.900343 0.404006 0.448986 0.962649
63+
64+
julia> yint = observations.target
65+
9-element Vector{UInt32}:
66+
0x00000001
67+
0x00000002
68+
0x00000001
69+
0x00000002
70+
0x00000001
71+
0x00000002
72+
0x00000001
73+
0x00000002
74+
0x00000001
75+
76+
julia> observations.levels_seen
77+
2-element CategoricalArray{Char,1,UInt32}:
78+
'a'
79+
'b'
80+
81+
julia> sort(unique(observations.target))
82+
2-element Vector{UInt32}:
83+
0x00000001
84+
0x00000002
85+
86+
julia> observations.levels_seen |> levels
87+
3-element CategoricalArray{Char,1,UInt32}:
88+
'a'
89+
'b'
90+
'c'
91+
92+
julia> observations.decoder.(yint)
93+
9-element CategoricalArray{Char,1,UInt32}:
94+
'a'
95+
'b'
96+
'a'
97+
'b'
98+
'a'
99+
'b'
100+
'a'
101+
'b'
102+
'a'
103+
104+
julia> d = UnivariateFinite(observations.levels_seen, [0.4, 0.6])
105+
UnivariateFinite{Multiclass{3}}(a=>0.4, b=>0.6)
106+
107+
julia> levels(d)
108+
3-element CategoricalArray{Char,1,UInt32}:
109+
'a'
110+
'b'
111+
'c'
112+
```
113+
45114
"""
46115
abstract type Obs end
47116

0 commit comments

Comments
 (0)