File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 44from typing import Any
55
66import numpy as np
7- from sklearn .ensemble import RandomForestClassifier
8- from sklearn .tree import DecisionTreeClassifier , DecisionTreeRegressor
97
108from ..core import TreeliteError
119from ..model import Model
@@ -61,6 +59,7 @@ def _export_tree(
6159 # pylint: disable=too-many-locals
6260 try :
6361 from sklearn import __version__ as sklearn_version
62+ from sklearn .tree import DecisionTreeClassifier
6463 from sklearn .tree ._tree import Tree as SKLearnTree
6564 except ImportError as e :
6665 raise TreeliteError ("This function requires scikit-learn package" ) from e
@@ -126,7 +125,7 @@ def _export_tree(
126125 return subestimator
127126
128127
129- def export_model (model : Model ):
128+ def export_model (model : Model ) -> Any :
130129 """
131130 Export a model as a scikit-learn RandomForest.
132131
@@ -153,7 +152,8 @@ def export_model(model: Model):
153152 # pylint: disable=too-many-locals
154153 try :
155154 from sklearn import __version__ as sklearn_version
156- from sklearn .ensemble import RandomForestRegressor
155+ from sklearn .ensemble import RandomForestClassifier , RandomForestRegressor
156+ from sklearn .tree import DecisionTreeClassifier , DecisionTreeRegressor
157157 except ImportError as e :
158158 raise TreeliteError ("This function requires scikit-learn package" ) from e
159159
You can’t perform that action at this time.
0 commit comments