
Commit 98677fc

implement data front end: reformat and selectrows
1 parent fc48b45 commit 98677fc
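
The "data front end" referred to here is the optional part of the MLJ model API in which a model overloads MMI.reformat, to convert user-supplied data (a Tables.jl table plus a categorical target) into a model-specific representation exactly once, and MMI.selectrows, to subsample rows of that representation cheaply. Once these are defined, MMI.fit and MMI.predict receive the reformatted data instead of the raw table. A minimal sketch of the resulting call sequence, with illustrative variable names not taken from this commit:

    model = DecisionTreeClassifier()
    data = MMI.reformat(model, X, y)    # -> (Xmatrix, yplain, features, classes)
    fitresult, cache, report = MMI.fit(model, verbosity, data...)

    newdata = MMI.reformat(model, Xnew)             # -> (Xmatrix_new,)
    yhat = MMI.predict(model, fitresult, newdata...)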

3 files changed: +153 additions, −149 deletions

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -4,6 +4,7 @@ authors = ["Anthony D. Blaom <[email protected]>"]
 version = "0.3.0"

 [deps]
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -16,10 +17,9 @@ Tables = "1.6"
 julia = "1.6"

 [extras]
-CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["CategoricalArrays", "MLJBase", "StableRNGs", "Test"]
+test = ["MLJBase", "StableRNGs", "Test"]

src/MLJDecisionTreeInterface.jl

Lines changed: 111 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import MLJModelInterface
44
using MLJModelInterface.ScientificTypesBase
55
import DecisionTree
66
import Tables
7+
using CategoricalArrays
78

89
using Random
910
import Random.GLOBAL_RNG
@@ -21,13 +22,13 @@ end
 Base.show(stream::IO, c::TreePrinter) =
     print(stream, "TreePrinter object (call with display depth)")

+function classes(y)
+    p = CategoricalArrays.pool(y)
+    [p[i] for i in 1:length(p)]
+end

 # # DECISION TREE CLASSIFIER

-# The following meets the MLJ standard for a `Model` docstring and is
-# created without the use of interpolation so it can be used a
-# template for authors of other MLJ model interfaces. The other
-# doc-strings, defined later, are generated using the `doc_header`
-# utility to automatically generate the header, another option.
 MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
     min_samples_leaf::Int = 1::(_ ≥ 0)
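
The new classes helper extracts every level of a categorical vector's pool as a CategoricalValue, including levels that never appear in the data. A hedged illustration (the vector y below is hypothetical, not from this commit):

    using CategoricalArrays
    y = categorical(["a", "b", "a"], levels=["a", "b", "c"])
    p = CategoricalArrays.pool(y)
    [p[i] for i in 1:length(p)]    # CategoricalValues for "a", "b" and "c"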
@@ -41,19 +42,17 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

-function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::DecisionTreeClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)

-    classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)

     tree = DT.build_tree(yplain, Xmatrix,
                          m.n_subfeatures,
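
With the new signature, fit no longer sees the raw categorical target: yplain carries integer codes and classes carries the full pool (both produced by the reformat method defined at the end of this file), so the classes actually observed are recovered by decoding unique(yplain) with MMI.decoder. A hedged sketch, reusing the hypothetical y from the example above:

    yplain = MMI.int(y)                 # integer codes, e.g. UInt32[1, 2, 1]
    d = MMI.decoder(classes(y))         # integer code -> CategoricalValue
    classes_seen = d(unique(yplain))    # CategoricalValues for "a" and "b"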
@@ -70,40 +69,26 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     fitresult = (tree, classes_seen, integers_seen, features)

     cache = nothing
-    report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree),
-              features=features,
-              )
+    report = (
+        classes_seen=classes_seen,
+        print_tree=TreePrinter(tree),
+        features=features,
+    )
     return fitresult, cache, report
 end

-function get_encoding(classes_seen)
-    a_cat_element = classes_seen[1]
-    return Dict(MMI.int(c) => c for c in MMI.classes(a_cat_element))
-end
+get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))

 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
     (tree=fitresult[1],
      encoding=get_encoding(fitresult[2]),
      features=fitresult[4])

-function smooth(scores, smoothing)
-    iszero(smoothing) && return scores
-    threshold = smoothing / size(scores, 2)
-    # clip low values
-    scores[scores .< threshold] .= threshold
-    # normalize
-    return scores ./ sum(scores, dims=2)
-end
-
 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     tree, classes_seen, integers_seen = fitresult
     # retrieve the predicted scores
-    scores = DT.apply_tree_proba(tree, Xmatrix, integers_seen)
-
-    # return vector of UF
-    return MMI.UnivariateFinite(classes_seen, scores)
+    scores = DT.apply_tree_proba(tree, Xnew, integers_seen)
+    MMI.UnivariateFinite(classes_seen, scores)
 end

 MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
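
DT.apply_tree_proba returns an n × c matrix of class probabilities, and MMI.UnivariateFinite(classes_seen, scores) wraps it as a vector of n finite-support distributions, one per row of Xnew. A hedged illustration (assumes MLJBase is loaded, since it supplies the UnivariateFinite implementation):

    yhat = MMI.UnivariateFinite(classes_seen, [0.2 0.8; 0.9 0.1])
    pdf.(yhat, classes_seen[2])    # second-class probabilities: [0.8, 0.1]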
@@ -123,19 +108,17 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

-function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::RandomForestClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)

-    classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)

     forest = DT.build_forest(yplain, Xmatrix,
                              m.n_subfeatures,
@@ -156,10 +139,9 @@ end
 MMI.fitted_params(::RandomForestClassifier, (forest,_)) = (forest=forest,)

 function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     forest, classes_seen, integers_seen = fitresult
-    scores = DT.apply_forest_proba(forest, Xmatrix, integers_seen)
-    return MMI.UnivariateFinite(classes_seen, scores)
+    scores = DT.apply_forest_proba(forest, Xnew, integers_seen)
+    MMI.UnivariateFinite(classes_seen, scores)
 end

 MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
@@ -173,20 +155,17 @@ MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

-function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::AdaBoostStumpClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)

-
-    classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)

     stumps, coefs =
         DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
@@ -201,10 +180,13 @@ MMI.fitted_params(::AdaBoostStumpClassifier, (stumps,coefs,_)) =
     (stumps=stumps,coefs=coefs)

 function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     stumps, coefs, classes_seen, integers_seen = fitresult
-    scores = DT.apply_adaboost_stumps_proba(stumps, coefs,
-                                            Xmatrix, integers_seen)
+    scores = DT.apply_adaboost_stumps_proba(
+        stumps,
+        coefs,
+        Xnew,
+        integers_seen,
+    )
     return MMI.UnivariateFinite(classes_seen, scores)
 end

@@ -225,23 +207,18 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

-function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
+function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)

-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
-
-    tree = DT.build_tree(float(y), Xmatrix,
-                         m.n_subfeatures,
-                         m.max_depth,
-                         m.min_samples_leaf,
-                         m.min_samples_split,
-                         m.min_purity_increase;
-                         rng=m.rng)
+    tree = DT.build_tree(
+        y,
+        Xmatrix,
+        m.n_subfeatures,
+        m.max_depth,
+        m.min_samples_leaf,
+        m.min_samples_split,
+        m.min_purity_increase;
+        rng=m.rng
+    )

     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
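
Note that the float(y) conversion has moved out of fit and into MMI.reformat(::Regressor, ...), defined at the end of this file, so it now runs once per machine rather than on every call to fit. A hedged sketch with hypothetical data:

    X = (x1 = rand(10), x2 = rand(10))    # a column table
    y = rand(1:5, 10)                     # integer-valued target
    Xmatrix, yfloat, features = MMI.reformat(DecisionTreeRegressor(), X, y)
    eltype(yfloat)                        # Float64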
@@ -255,10 +232,7 @@ end
 MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)

-function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
-    return DT.apply_tree(tree, Xmatrix)
-end
+MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)

 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true

@@ -277,25 +251,21 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

-function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
+
+    forest = DT.build_forest(
+        y,
+        Xmatrix,
+        m.n_subfeatures,
+        m.n_trees,
+        m.sampling_fraction,
+        m.max_depth,
+        m.min_samples_leaf,
+        m.min_samples_split,
+        m.min_purity_increase,
+        rng=m.rng
+    )

-    forest = DT.build_forest(float(y), Xmatrix,
-                             m.n_subfeatures,
-                             m.n_trees,
-                             m.sampling_fraction,
-                             m.max_depth,
-                             m.min_samples_leaf,
-                             m.min_samples_split,
-                             m.min_purity_increase,
-                             rng=m.rng)
     cache = nothing
     report = (features=features,)

@@ -304,10 +274,7 @@ end
 MMI.fitted_params(::RandomForestRegressor, forest) = (forest=forest,)

-function MMI.predict(::RandomForestRegressor, forest, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
-    return DT.apply_forest(forest, Xmatrix)
-end
+MMI.predict(::RandomForestRegressor, forest, Xnew) = DT.apply_forest(forest, Xnew)

 MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true

@@ -327,7 +294,39 @@ const IterativeModel = Union{
     AdaBoostStumpClassifier,
 }

-const RandomForestModel = Union{DecisionTreeClassifier, RandomForestClassifier}
+const Classifier = Union{
+    DecisionTreeClassifier,
+    RandomForestClassifier,
+    AdaBoostStumpClassifier,
+}
+
+const Regressor = Union{
+    DecisionTreeRegressor,
+    RandomForestRegressor,
+}
+
+const RandomForestModel = Union{
+    DecisionTreeClassifier,
+    RandomForestClassifier,
+}
+
+
+# # DATA FRONT END
+
+_columnnames(X) = Tables.columnnames(Tables.columns(X)) |> collect
+
+# for fit:
+MMI.reformat(::Classifier, X, y) =
+    (Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y))
+MMI.reformat(::Regressor, X, y) =
+    (Tables.matrix(X), float(y), _columnnames(X))
+MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) =
+    (view(Xmatrix, I, :), view(y, I), meta...)
+
+# for predict:
+MMI.reformat(::TreeModel, X) = (Tables.matrix(X),)
+MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)


 # # FEATURE IMPORTANCES
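
Taken together, these methods mean that during resampling (cross-validation, say) the table-to-matrix conversion and target encoding run once per machine, while each fold only takes cheap row views. A hedged sketch of what MLJBase effectively drives, with illustrative names not from this commit:

    model = DecisionTreeClassifier()
    data = MMI.reformat(model, X, y)    # computed once

    for rows in folds                   # hypothetical iterator of row indices
        train = MMI.selectrows(model, rows, data...)    # views, no re-conversion
        fitresult, _, _ = MMI.fit(model, 0, train...)
    end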
