@@ -311,15 +311,38 @@ end
311311
312312MMI. reports_feature_importances (:: Type{<:RandomForestRegressor} ) = true
313313
314+ # # ALIASES FOR TYPE UNIONS
314315
315- # # Feature Importances
316+ const TreeModel = Union{
317+ DecisionTreeClassifier,
318+ RandomForestClassifier,
319+ AdaBoostStumpClassifier,
320+ DecisionTreeRegressor,
321+ RandomForestRegressor,
322+ }
323+
324+ const IterativeModel = Union{
325+ RandomForestClassifier,
326+ RandomForestRegressor,
327+ AdaBoostStumpClassifier,
328+ }
329+
330+ const RandomForestModel = Union{DecisionTreeClassifier, RandomForestClassifier}
331+
332+ # # FEATURE IMPORTANCES
316333
317334# get actual arguments needed for importance calculation from various fitresults.
318- get_fitresult (m:: Union{DecisionTreeClassifier, RandomForestClassifier} , fitresult) = (fitresult[1 ],)
319- get_fitresult (m:: Union{DecisionTreeRegressor, RandomForestRegressor} , fitresult) = (fitresult,)
335+ get_fitresult (
336+ m:: Union{DecisionTreeClassifier, RandomForestClassifier} ,
337+ fitresult,
338+ ) = (fitresult[1 ],)
339+ get_fitresult (
340+ m:: Union{DecisionTreeRegressor, RandomForestRegressor} ,
341+ fitresult,
342+ ) = (fitresult,)
320343get_fitresult (m:: AdaBoostStumpClassifier , fitresult)= (fitresult[1 ], fitresult[2 ])
321344
322- function MMI. feature_importances (m:: Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor} , fitresult, report)
345+ function MMI. feature_importances (m:: TreeModel , fitresult, report)
323346 # generate feature importances for report
324347 if m. feature_importance == :impurity
325348 feature_importance_func = DT. impurity_importance
@@ -337,19 +360,8 @@ function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestCl
337360 return fi_pairs
338361end
339362
340-
341363# # METADATA (MODEL TRAITS)
342364
343- # following five lines of code are redundant if using this branch of
344- # MLJModelInterface:
345- # https://github.com/JuliaAI/MLJModelInterface.jl/pull/139
346-
347- # MMI.human_name(::Type{<:DecisionTreeClassifier}) = "CART decision tree classifier"
348- # MMI.human_name(::Type{<:RandomForestClassifier}) = "CART random forest classifier"
349- # MMI.human_name(::Type{<:AdaBoostStumpClassifier}) = "Ada-boosted stump classifier"
350- # MMI.human_name(::Type{<:DecisionTreeRegressor}) = "CART decision tree regressor"
351- # MMI.human_name(::Type{<:RandomForestRegressor}) = "CART random forest regressor"
352-
353365MMI. metadata_pkg .(
354366 (DecisionTreeClassifier, DecisionTreeRegressor,
355367 RandomForestClassifier, RandomForestRegressor,
0 commit comments