Is there any way to broadcast sk-learn model? #71
Open
Description
I have tried some code like this:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.ml import Pipeline
from pyspark.ml.feature import Word2Vec, StringIndexer
from sklearn import svm
from spark_sklearn import GridSearchCV
import cPickle as pickle

session = SparkSession.builder.master("local[2]").appName("test").getOrCreate()
# iris = datasets.load_iris()
# print(iris.target)
documentDF = session.createDataFrame([
    ("Hi I heard about Spark", "spark"),
    ("I wish Java could use case classes", "java"),
    ("Logistic regression models are neat", "mlib"),
    ("Logistic regression models are neat", "spark"),
    ("Logistic regression models are neat", "mlib"),
    ("Logistic regression models are neat", "java"),
    ("Logistic regression models are neat", "spark"),
    ("Logistic regression models are neat", "java"),
    ("Logistic regression models are neat", "mlib")
], ["text", "preds"]).select(f.split("text", "\\s+").alias("new_text"), "preds")
word2vec = Word2Vec(vectorSize=100, minCount=1, inputCol="new_text",
                    outputCol="features")
indexer = StringIndexer(inputCol="preds", outputCol="labels")
pipeline = Pipeline(stages=[word2vec, indexer])
ds = pipeline.fit(documentDF).transform(documentDF)
data = ds.toPandas()
parameters = {'kernel': ('linear', 'rbf')}
svr = svm.SVC()
clf = GridSearchCV(session.sparkContext, svr, parameters)
X = [x.values for x in data.features.values]
y = [int(x) for x in data.labels.values]
model = clf.fit(X, y)
modelB = session.sparkContext.broadcast(pickle.dumps(model))
wow = documentDF.rdd.map(lambda row: pickle.loads(modelB.value).transform(row["features"].values)).collect()
print(wow)

but the code fails because the sk-learn model cannot be pickled. Is there any way to broadcast the model?
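A possible workaround sketch (not verified): assuming the pickling error comes from the spark_sklearn GridSearchCV wrapper itself (it holds a reference to the SparkContext, which cannot be pickled) rather than from the underlying scikit-learn estimator, one could broadcast only the refitted estimator. The snippet below assumes clf.best_estimator_ is populated after fit (as with scikit-learn's GridSearchCV when refit is enabled) and maps over ds, the transformed DataFrame that actually contains the "features" column:

# Untested sketch: broadcast the plain sklearn estimator, not the GridSearchCV wrapper
best = clf.best_estimator_  # assumed to be a fitted sklearn.svm.SVC, picklable on its own
bestB = session.sparkContext.broadcast(pickle.dumps(best))
preds = ds.rdd.map(
    lambda row: pickle.loads(bestB.value).predict([row["features"].values])[0]
).collect()
print(preds)

Broadcasting the pickled bytes and unpickling inside the map keeps the SparkContext out of the closure, which is likely what trips up the original version.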