|
42 | 42 | MinMaxScalerModel,
|
43 | 43 | RobustScaler,
|
44 | 44 | RobustScalerModel,
|
| 45 | + ChiSqSelector, |
| 46 | + ChiSqSelectorModel, |
| 47 | + UnivariateFeatureSelector, |
| 48 | + UnivariateFeatureSelectorModel, |
| 49 | + VarianceThresholdSelector, |
| 50 | + VarianceThresholdSelectorModel, |
45 | 51 | StopWordsRemover,
|
46 | 52 | StringIndexer,
|
47 | 53 | StringIndexerModel,
|
@@ -391,6 +397,102 @@ def test_robust_scaler(self):
|
391 | 397 | self.assertEqual(str(model), str(model2))
|
392 | 398 | self.assertEqual(model2.getOutputCol(), "scaled")
|
393 | 399 |
|
| 400 | + def test_chi_sq_selector(self): |
| 401 | + df = self.spark.createDataFrame( |
| 402 | + [ |
| 403 | + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), |
| 404 | + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), |
| 405 | + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), |
| 406 | + ], |
| 407 | + ["features", "label"], |
| 408 | + ) |
| 409 | + |
| 410 | + selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") |
| 411 | + self.assertEqual(selector.getNumTopFeatures(), 1) |
| 412 | + self.assertEqual(selector.getOutputCol(), "selectedFeatures") |
| 413 | + |
| 414 | + model = selector.fit(df) |
| 415 | + self.assertEqual(model.selectedFeatures, [2]) |
| 416 | + |
| 417 | + output = model.transform(df) |
| 418 | + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) |
| 419 | + self.assertEqual(output.count(), 3) |
| 420 | + |
| 421 | + # save & load |
| 422 | + with tempfile.TemporaryDirectory(prefix="chi_sq_selector") as d: |
| 423 | + selector.write().overwrite().save(d) |
| 424 | + selector2 = ChiSqSelector.load(d) |
| 425 | + self.assertEqual(str(selector), str(selector2)) |
| 426 | + |
| 427 | + model.write().overwrite().save(d) |
| 428 | + model2 = ChiSqSelectorModel.load(d) |
| 429 | + self.assertEqual(str(model), str(model2)) |
| 430 | + |
| 431 | + def test_univariate_selector(self): |
| 432 | + df = self.spark.createDataFrame( |
| 433 | + [ |
| 434 | + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), |
| 435 | + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), |
| 436 | + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), |
| 437 | + ], |
| 438 | + ["features", "label"], |
| 439 | + ) |
| 440 | + |
| 441 | + selector = UnivariateFeatureSelector(outputCol="selectedFeatures") |
| 442 | + selector.setFeatureType("continuous").setLabelType("categorical").setSelectionThreshold(1) |
| 443 | + self.assertEqual(selector.getFeatureType(), "continuous") |
| 444 | + self.assertEqual(selector.getLabelType(), "categorical") |
| 445 | + self.assertEqual(selector.getOutputCol(), "selectedFeatures") |
| 446 | + self.assertEqual(selector.getSelectionThreshold(), 1) |
| 447 | + |
| 448 | + model = selector.fit(df) |
| 449 | + self.assertEqual(model.selectedFeatures, [3]) |
| 450 | + |
| 451 | + output = model.transform(df) |
| 452 | + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) |
| 453 | + self.assertEqual(output.count(), 3) |
| 454 | + |
| 455 | + # save & load |
| 456 | + with tempfile.TemporaryDirectory(prefix="univariate_selector") as d: |
| 457 | + selector.write().overwrite().save(d) |
| 458 | + selector2 = UnivariateFeatureSelector.load(d) |
| 459 | + self.assertEqual(str(selector), str(selector2)) |
| 460 | + |
| 461 | + model.write().overwrite().save(d) |
| 462 | + model2 = UnivariateFeatureSelectorModel.load(d) |
| 463 | + self.assertEqual(str(model), str(model2)) |
| 464 | + |
| 465 | + def test_variance_threshold_selector(self): |
| 466 | + df = self.spark.createDataFrame( |
| 467 | + [ |
| 468 | + (Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), |
| 469 | + (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), |
| 470 | + (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0), |
| 471 | + ], |
| 472 | + ["features", "label"], |
| 473 | + ) |
| 474 | + |
| 475 | + selector = VarianceThresholdSelector(varianceThreshold=2, outputCol="selectedFeatures") |
| 476 | + self.assertEqual(selector.getVarianceThreshold(), 2) |
| 477 | + self.assertEqual(selector.getOutputCol(), "selectedFeatures") |
| 478 | + |
| 479 | + model = selector.fit(df) |
| 480 | + self.assertEqual(model.selectedFeatures, [2]) |
| 481 | + |
| 482 | + output = model.transform(df) |
| 483 | + self.assertEqual(output.columns, ["features", "label", "selectedFeatures"]) |
| 484 | + self.assertEqual(output.count(), 3) |
| 485 | + |
| 486 | + # save & load |
| 487 | + with tempfile.TemporaryDirectory(prefix="variance_threshold_selector") as d: |
| 488 | + selector.write().overwrite().save(d) |
| 489 | + selector2 = VarianceThresholdSelector.load(d) |
| 490 | + self.assertEqual(str(selector), str(selector2)) |
| 491 | + |
| 492 | + model.write().overwrite().save(d) |
| 493 | + model2 = VarianceThresholdSelectorModel.load(d) |
| 494 | + self.assertEqual(str(model), str(model2)) |
| 495 | + |
394 | 496 | def test_word2vec(self):
|
395 | 497 | sent = ("a b " * 100 + "a c " * 10).split(" ")
|
396 | 498 | df = self.spark.createDataFrame([(sent,), (sent,)], ["sentence"]).coalesce(1)
|
|
0 commit comments