@@ -32,16 +32,85 @@ If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then
3232additionally have:
3333
3434- `observations.levels`: A categorical vector of the ordered target levels, as actually
35- seen in the user-supplied target, with the full pool of levels available by applying
36- `Categorical.levels` to the result. The corresponding integer codes will be
37- `sort(unique( observations.target))` .
35+ seen in the user-supplied target. The corresponding integer codes will be
36+ `sort(unique(observations.target))`. To get the full pool of levels, apply
37+ `CategoricalArrays.levels` to ` observations.levels_seen`; see the example below .
3838
3939- `observations.decoder`: A callable function that converts an integer code back to the
4040 original `CategoricalValue` it represents.
4141
4242Pass the first onto `predict` for making probabilistic predictions, and the second for
4343point predictions; see [`Sage`](@ref) for details.
4444
45+ # Extended help
46+
47+ In the example below, `observations` implements the full `Obs` interface described above,
48+ for a learner implementing the `Sage` front end:
49+
50+ ```julia-repl
51+ using LearnAPI, LearnDataFrontEnds, LearnTestAPI
52+ using CategoricalDistributions, CategoricalArrays, DataFrames
53+ X = DataFrame(rand(10, 3), :auto)
54+ y = categorical(collect("ababababac"))
55+ learner = LearnTestAPI.ConstantClassifier()
56+ observations = obs(learner, (X[1:9,:], y[1:9]))
57+
58+ julia> observations.features
59+ 3×9 Matrix{Float64}:
60+ 0.234043 0.526468 0.227417 0.956471 … 0.00587146 0.169291 0.353518 0.402631
61+ 0.631083 0.151317 0.781049 0.00320728 0.756519 0.15317 0.452169 0.127005
62+ 0.285315 0.347433 0.69174 0.516915 0.900343 0.404006 0.448986 0.962649
63+
64+ julia> yint = observations.target
65+ 9-element Vector{UInt32}:
66+ 0x00000001
67+ 0x00000002
68+ 0x00000001
69+ 0x00000002
70+ 0x00000001
71+ 0x00000002
72+ 0x00000001
73+ 0x00000002
74+ 0x00000001
75+
76+ julia> observations.levels_seen
77+ 2-element CategoricalArray{Char,1,UInt32}:
78+ 'a'
79+ 'b'
80+
81+ julia> sort(unique(observations.target))
82+ 2-element Vector{UInt32}:
83+ 0x00000001
84+ 0x00000002
85+
86+ julia> observations.levels_seen |> levels
87+ 3-element CategoricalArray{Char,1,UInt32}:
88+ 'a'
89+ 'b'
90+ 'c'
91+
92+ julia> observations.decoder.(yint)
93+ 9-element CategoricalArray{Char,1,UInt32}:
94+ 'a'
95+ 'b'
96+ 'a'
97+ 'b'
98+ 'a'
99+ 'b'
100+ 'a'
101+ 'b'
102+ 'a'
103+
104+ julia> d = UnivariateFinite(observations.levels_seen, [0.4, 0.6])
105+ UnivariateFinite{Multiclass{3}}(a=>0.4, b=>0.6)
106+
107+ julia> levels(d)
108+ 3-element CategoricalArray{Char,1,UInt32}:
109+ 'a'
110+ 'b'
111+ 'c'
112+ ```
113+
45114"""
46115abstract type Obs end
47116
0 commit comments