Skip to content

Commit f2fdb18

Browse files
committed
keys <--> values in encoding to close #19
1 parent 303eb2c commit f2fdb18

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ end
7979

8080
function get_encoding(classes_seen)
8181
a_cat_element = classes_seen[1]
82-
return Dict(c => MMI.int(c) for c in MMI.classes(a_cat_element))
82+
return Dict(MMI.int(c) => c for c in MMI.classes(a_cat_element))
8383
end
8484

8585
MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
@@ -538,9 +538,9 @@ To interpret the internal class labelling:
538538
```
539539
julia> fitted_params(mach).encoding
540540
Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
541-
"virginica" => 0x00000003
542-
"setosa" => 0x00000001
543-
"versicolor" => 0x00000002
541+
0x00000003 => "virginica"
542+
0x00000001 => "setosa"
543+
0x00000002 => "versicolor"
544544
```
545545
546546
See also

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
4949
fp = fitted_params(baretree, fitresult)
5050
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
5151
@test fp.features == report.features
52+
enc = fp.encoding
53+
@test Set(values(enc)) == Set(["virginica", "setosa", "versicolor"])
54+
@test enc[MLJBase.int(y[end])] == "virginica"
5255

5356
using Random: seed!
5457
seed!(0)

0 commit comments

Comments
 (0)