Skip to content

Commit d0eef62

Browse files
authored
Make scikit-learn optional again (#596)
1 parent 0b8ea3e commit d0eef62

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/treelite/sklearn/exporter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from typing import Any
55

66
import numpy as np
7-
from sklearn.ensemble import RandomForestClassifier
8-
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
97

108
from ..core import TreeliteError
119
from ..model import Model
@@ -61,6 +59,7 @@ def _export_tree(
6159
# pylint: disable=too-many-locals
6260
try:
6361
from sklearn import __version__ as sklearn_version
62+
from sklearn.tree import DecisionTreeClassifier
6463
from sklearn.tree._tree import Tree as SKLearnTree
6564
except ImportError as e:
6665
raise TreeliteError("This function requires scikit-learn package") from e
@@ -126,7 +125,7 @@ def _export_tree(
126125
return subestimator
127126

128127

129-
def export_model(model: Model):
128+
def export_model(model: Model) -> Any:
130129
"""
131130
Export a model as a scikit-learn RandomForest.
132131
@@ -153,7 +152,8 @@ def export_model(model: Model):
153152
# pylint: disable=too-many-locals
154153
try:
155154
from sklearn import __version__ as sklearn_version
156-
from sklearn.ensemble import RandomForestRegressor
155+
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
156+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
157157
except ImportError as e:
158158
raise TreeliteError("This function requires scikit-learn package") from e
159159

0 commit comments

Comments
 (0)