Skip to content

Commit 6ca0df3

Browse files
committed
Merge pull request #42 from lensacom/dev
Fixed dict vectorizer check_rdd; DictRDD transform dtype issue
2 parents 9274ba1 + 3d519bd commit 6ca0df3

File tree

17 files changed

+111
-41
lines changed

17 files changed

+111
-41
lines changed

README.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Sparkit-learn
22
=============
33

4-
|Build Status| |PyPi|
4+
|Build Status| |PyPi| |Gitter|
55

66
**PySpark + Scikit-learn = Sparkit-learn**
77

@@ -448,3 +448,7 @@ Special thanks
448448
:target: https://travis-ci.org/lensacom/sparkit-learn
449449
.. |PyPi| image:: https://img.shields.io/pypi/v/sparkit-learn.svg
450450
:target: https://pypi.python.org/pypi/sparkit-learn
451+
.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg
452+
:alt: Join the chat at https://gitter.im/lensacom/sparkit-learn
453+
:target: https://gitter.im/lensacom/sparkit-learn?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
454+

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import sys
44

5-
from setuptools import setup
6-
from setuptools import find_packages
5+
from setuptools import find_packages, setup
76

87

98
def is_numpy_installed():
@@ -17,7 +16,7 @@ def is_numpy_installed():
1716
def setup_package():
1817
metadata = dict(
1918
name='sparkit-learn',
20-
version="0.2.4",
19+
version='0.2.5',
2120
description='Scikit-learn on PySpark',
2221
author='Krisztian Szucs, Andras Fulop',
2322
author_email='krisztian.szucs@lensa.com, andras.fulop@lensa.com',
@@ -33,7 +32,8 @@ def setup_package():
3332
if is_numpy_installed() is False:
3433
raise ImportError("Numerical Python (NumPy) is not installed.\n"
3534
"sparkit-learn requires NumPy.\n"
36-
"Installation instructions are available on scikit-learn website: "
35+
"Installation instructions are available on "
36+
"scikit-learn website: "
3737
"http://scikit-learn.org/stable/install.html\n")
3838

3939
setup(**metadata)

splearn/feature_extraction/dict_vectorizer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import numpy as np
44
import scipy.sparse as sp
5-
from sklearn.feature_extraction import DictVectorizer
5+
from pyspark import AccumulatorParam
66
from sklearn.externals import six
7+
from sklearn.feature_extraction import DictVectorizer
78

9+
from ..base import SparkBroadcasterMixin
810
from ..rdd import DictRDD
9-
from pyspark import AccumulatorParam
1011
from ..utils.validation import check_rdd
11-
from ..base import SparkBroadcasterMixin
1212

1313

1414
class SparkDictVectorizer(DictVectorizer, SparkBroadcasterMixin):
@@ -87,7 +87,6 @@ def fit(self, Z):
8787
self
8888
"""
8989
X = Z[:, 'X'] if isinstance(Z, DictRDD) else Z
90-
check_rdd(X, (np.ndarray,))
9190

9291
"""Create vocabulary
9392
"""
@@ -142,9 +141,6 @@ def transform(self, Z):
142141
Z : transformed, containing {array, sparse matrix}
143142
Feature vectors; always 2-d.
144143
"""
145-
X = Z[:, 'X'] if isinstance(Z, DictRDD) else Z
146-
check_rdd(X, (np.ndarray, sp.spmatrix))
147-
148144
mapper = self.broadcast(super(SparkDictVectorizer, self).transform,
149145
Z.context)
150146
dtype = sp.spmatrix if self.sparse else np.ndarray

splearn/feature_extraction/tests/test_dict_vectorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import scipy.sparse as sp
21
import numpy as np
2+
import scipy.sparse as sp
33
from sklearn.feature_extraction import DictVectorizer
44
from splearn.feature_extraction import SparkDictVectorizer
55
from splearn.rdd import ArrayRDD

splearn/feature_extraction/tests/test_text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import scipy.sparse as sp
21
import numpy as np
2+
import scipy.sparse as sp
33
from sklearn.feature_extraction.text import (CountVectorizer,
44
HashingVectorizer,
55
TfidfTransformer)
66
from splearn.feature_extraction.text import (SparkCountVectorizer,
77
SparkHashingVectorizer,
88
SparkTfidfTransformer)
99
from splearn.utils.testing import (SplearnTestCase, assert_array_almost_equal,
10-
assert_array_equal, assert_equal, assert_true)
10+
assert_array_equal, assert_equal,
11+
assert_true)
1112
from splearn.utils.validation import check_rdd_dtype
1213

1314

splearn/feature_selection/tests/test_variance_threshold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,6 @@ def test_same_transform_with_treshold(self):
8080
result_dist.toarray())
8181

8282
result_dist = dist.fit_transform(Z_rdd)[:, 'X']
83-
assert_true(check_rdd_dtype(result_dist, (sp.spmatrix,)))
83+
assert_true(check_rdd_dtype(result_dist, (sp.spmatrix,)))
8484
assert_array_almost_equal(result_local.toarray(),
8585
result_dist.toarray())

splearn/feature_selection/variance_threshold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from sklearn.utils.sparsefuncs import mean_variance_axis
66

77
from ..rdd import DictRDD
8-
from .base import SparkSelectorMixin
98
from ..utils.validation import check_rdd
9+
from .base import SparkSelectorMixin
1010

1111

1212
class SparkVarianceThreshold(VarianceThreshold, SparkSelectorMixin):

splearn/linear_model/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# encoding: utf-8
22

3-
import scipy.sparse as sp
43
import numpy as np
5-
4+
import scipy.sparse as sp
65
from sklearn.base import copy
76
from sklearn.linear_model.base import LinearRegression
87

splearn/linear_model/logistic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import scipy.sparse as sp
55
from sklearn.linear_model import LogisticRegression
66

7-
from .base import SparkLinearModelMixin
87
from ..utils.validation import check_rdd
8+
from .base import SparkLinearModelMixin
99

1010

1111
class SparkLogisticRegression(LogisticRegression, SparkLinearModelMixin):

splearn/linear_model/stochastic_gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import scipy.sparse as sp
55
from sklearn.linear_model import SGDClassifier
66

7-
from .base import SparkLinearModelMixin
87
from ..utils.validation import check_rdd
8+
from .base import SparkLinearModelMixin
99

1010

1111
class SparkSGDClassifier(SGDClassifier, SparkLinearModelMixin):

0 commit comments

Comments (0)