Skip to content
This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit 4cc6001

Browse files
authored
Set .classes_, .coef_ on scikit models correctly to enable prediction (#96)
1 parent 9f67a74 commit 4cc6001

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

python/spark_sklearn/converter.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -105,14 +105,16 @@ def toSKLearn(self, model):
105105
:return: scikit-learn model with equivalent predictive behavior.
106106
Currently, parameters or arguments for training are not copied.
107107
"""
108-
if isinstance(model, LogisticRegressionModel) or isinstance(model, LinearRegressionModel):
109-
return self._toSKLGLM(model)
108+
if isinstance(model, LogisticRegressionModel) :
109+
return self._toSKLGLM(model, True)
110+
if isinstance(model, LinearRegressionModel):
111+
return self._toSKLGLM(model, False)
110112
else:
111113
supported_types = map(lambda t: type(t), self._supported_spark_types)
112114
raise ValueError("Converter.toSKLearn cannot convert type: %s. Supported types: %s" %
113115
(type(model), ", ".join(supported_types)))
114116

115-
def _toSKLGLM(self, model):
117+
def _toSKLGLM(self, model, is_classifier):
116118
""" Private method for converting a GLM to a scikit-learn model
117119
TODO: Add model parameters as well.
118120
"""
@@ -122,7 +124,9 @@ def _toSKLGLM(self, model):
122124
weights = model.coefficients
123125
skl = skl_cls()
124126
skl.intercept_ = np.float64(intercept)
125-
skl.coef_ = weights.toArray()
127+
skl.coef_ = weights.toArray().reshape(1, -1)
128+
if is_classifier:
129+
skl.classes_ = np.array([0, 1])
126130
return skl
127131

128132
def toPandas(self, df):

python/spark_sklearn/converter_test.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,13 +45,17 @@ def test_LogisticRegression_spark2skl(self):
4545
self.assertTrue(isinstance(skl_lr, SKL_LogisticRegression),
4646
"Expected sklearn LogisticRegression but found type %s" % type(skl_lr))
4747
self._compare_GLMs(skl_lr, lr)
48+
# Make sure this doesn't throw an error
49+
skl_lr.predict_proba(self.X)
4850

4951
def test_LinearRegression_spark2skl(self):
5052
lr = LinearRegression().fit(self.df)
5153
skl_lr = self.converter.toSKLearn(lr)
5254
self.assertTrue(isinstance(skl_lr, SKL_LinearRegression),
5355
"Expected sklearn LinearRegression but found type %s" % type(skl_lr))
5456
self._compare_GLMs(skl_lr, lr)
57+
# Make sure this doesn't throw an error
58+
skl_lr.predict(self.X)
5559

5660
def ztest_toPandas(self):
5761
data = [(Vectors.dense([0.1, 0.2]),),

0 commit comments

Comments
 (0)