Skip to content

Commit fc48b45

Browse files
committed
clean up; add some type aliases
1 parent bb7c54b commit fc48b45

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -311,15 +311,38 @@ end
311311

312312
MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
313313

314+
# # ALIASES FOR TYPE UNIONS
314315

315-
# # Feature Importances
316+
const TreeModel = Union{
317+
DecisionTreeClassifier,
318+
RandomForestClassifier,
319+
AdaBoostStumpClassifier,
320+
DecisionTreeRegressor,
321+
RandomForestRegressor,
322+
}
323+
324+
const IterativeModel = Union{
325+
RandomForestClassifier,
326+
RandomForestRegressor,
327+
AdaBoostStumpClassifier,
328+
}
329+
330+
const RandomForestModel = Union{DecisionTreeClassifier, RandomForestClassifier}
331+
332+
# # FEATURE IMPORTANCES
316333

317334
# get actual arguments needed for importance calculation from various fitresults.
318-
get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],)
319-
get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,)
335+
get_fitresult(
336+
m::Union{DecisionTreeClassifier, RandomForestClassifier},
337+
fitresult,
338+
) = (fitresult[1],)
339+
get_fitresult(
340+
m::Union{DecisionTreeRegressor, RandomForestRegressor},
341+
fitresult,
342+
) = (fitresult,)
320343
get_fitresult(m::AdaBoostStumpClassifier, fitresult)= (fitresult[1], fitresult[2])
321344

322-
function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report)
345+
function MMI.feature_importances(m::TreeModel, fitresult, report)
323346
# generate feature importances for report
324347
if m.feature_importance == :impurity
325348
feature_importance_func = DT.impurity_importance
@@ -337,19 +360,8 @@ function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestCl
337360
return fi_pairs
338361
end
339362

340-
341363
# # METADATA (MODEL TRAITS)
342364

343-
# following five lines of code are redundant if using this branch of
344-
# MLJModelInterface:
345-
# https://github.com/JuliaAI/MLJModelInterface.jl/pull/139
346-
347-
# MMI.human_name(::Type{<:DecisionTreeClassifier}) = "CART decision tree classifier"
348-
# MMI.human_name(::Type{<:RandomForestClassifier}) = "CART random forest classifier"
349-
# MMI.human_name(::Type{<:AdaBoostStumpClassifier}) = "Ada-boosted stump classifier"
350-
# MMI.human_name(::Type{<:DecisionTreeRegressor}) = "CART decision tree regressor"
351-
# MMI.human_name(::Type{<:RandomForestRegressor}) = "CART random forest regressor"
352-
353365
MMI.metadata_pkg.(
354366
(DecisionTreeClassifier, DecisionTreeRegressor,
355367
RandomForestClassifier, RandomForestRegressor,

0 commit comments

Comments
 (0)