|
23 | 23 | Base.show(stream::IO, c::TreePrinter) = |
24 | 24 | print(stream, "TreePrinter object (call with display depth)") |
25 | 25 |
|
26 | | -function classes(y) |
27 | | - p = CategoricalArrays.pool(y) |
28 | | - [p[i] for i in 1:length(p)] |
29 | | -end |
30 | 26 |
|
31 | 27 | # # DECISION TREE CLASSIFIER |
32 | 28 |
|
@@ -79,7 +75,7 @@ function MMI.fit( |
79 | 75 | end |
80 | 76 |
|
81 | 77 | # returns a dictionary of categorical elements keyed on ref integer: |
82 | | -get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen)) |
| 78 | +get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in levels(classes_seen)) |
83 | 79 |
|
84 | 80 | # given such a dictionary, return printable class labels, ordered by corresponding ref |
85 | 81 | # integer: |
@@ -459,7 +455,7 @@ _columnnames(X, ::Val{false}) = Tables.columnnames(first(Tables.rows(X))) |
459 | 455 |
|
460 | 456 | # for fit: |
461 | 457 | MMI.reformat(::Classifier, X, y) = |
462 | | - (Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y)) |
| 458 | + (Tables.matrix(X), MMI.int(y), _columnnames(X), levels(y)) |
463 | 459 | MMI.reformat(::Regressor, X, y) = |
464 | 460 | (Tables.matrix(X), float(y), _columnnames(X)) |
465 | 461 | MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) = |
|
0 commit comments