Skip to content

Commit 0f86996

Browse files
authored
Merge pull request #37 from JuliaAI/MLJModelInterface-compat-fix
Reverse roles keys <--> values in encoding for classifier
2 parents 39135de + f2fdb18 commit 0f86996

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ end
7979

8080
function get_encoding(classes_seen)
8181
a_cat_element = classes_seen[1]
82-
return Dict(c => MMI.int(c) for c in MMI.classes(a_cat_element))
82+
return Dict(MMI.int(c) => c for c in MMI.classes(a_cat_element))
8383
end
8484

8585
MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
@@ -537,9 +537,9 @@ To interpret the internal class labelling:
537537
```
538538
julia> fitted_params(mach).encoding
539539
Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
540-
"virginica" => 0x00000003
541-
"setosa" => 0x00000001
542-
"versicolor" => 0x00000002
540+
0x00000003 => "virginica"
541+
0x00000001 => "setosa"
542+
0x00000002 => "versicolor"
543543
```
544544
545545
See also

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
4949
fp = fitted_params(baretree, fitresult)
5050
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
5151
@test fp.features == report.features
52+
enc = fp.encoding
53+
@test Set(values(enc)) == Set(["virginica", "setosa", "versicolor"])
54+
@test enc[MLJBase.int(y[end])] == "virginica"
5255

5356
using Random: seed!
5457
seed!(0)

0 commit comments

Comments
 (0)