|
37 | 37 | Imputer,
|
38 | 38 | ImputerModel,
|
39 | 39 | NGram,
|
| 40 | + Normalizer, |
| 41 | + Interaction, |
40 | 42 | RFormula,
|
41 | 43 | Tokenizer,
|
42 | 44 | SQLTransformer,
|
@@ -923,13 +925,64 @@ def test_idf(self):
|
923 | 925 | self.assertEqual(str(model), str(model2))
|
924 | 926 |
|
925 | 927 | def test_ngram(self):
|
926 |
| - dataset = self.spark.createDataFrame([Row(input=["a", "b", "c", "d", "e"])]) |
927 |
| - ngram0 = NGram(n=4, inputCol="input", outputCol="output") |
928 |
| - self.assertEqual(ngram0.getN(), 4) |
929 |
| - self.assertEqual(ngram0.getInputCol(), "input") |
930 |
| - self.assertEqual(ngram0.getOutputCol(), "output") |
931 |
| - transformedDF = ngram0.transform(dataset) |
932 |
| - self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) |
| 928 | + spark = self.spark |
| 929 | + df = spark.createDataFrame([Row(input=["a", "b", "c", "d", "e"])]) |
| 930 | + |
| 931 | + ngram = NGram(n=4, inputCol="input", outputCol="output") |
| 932 | + self.assertEqual(ngram.getN(), 4) |
| 933 | + self.assertEqual(ngram.getInputCol(), "input") |
| 934 | + self.assertEqual(ngram.getOutputCol(), "output") |
| 935 | + |
| 936 | + output = ngram.transform(df) |
| 937 | + self.assertEqual(output.head().output, ["a b c d", "b c d e"]) |
| 938 | + |
| 939 | + # save & load |
| 940 | + with tempfile.TemporaryDirectory(prefix="ngram") as d: |
| 941 | + ngram.write().overwrite().save(d) |
| 942 | + ngram2 = NGram.load(d) |
| 943 | + self.assertEqual(str(ngram), str(ngram2)) |
| 944 | + |
| 945 | + def test_normalizer(self): |
| 946 | + spark = self.spark |
| 947 | + df = spark.createDataFrame( |
| 948 | + [(Vectors.dense([3.0, -4.0]),), (Vectors.sparse(4, {1: 4.0, 3: 3.0}),)], |
| 949 | + ["input"], |
| 950 | + ) |
| 951 | + |
| 952 | + normalizer = Normalizer(p=2.0, inputCol="input", outputCol="output") |
| 953 | + self.assertEqual(normalizer.getP(), 2.0) |
| 954 | + self.assertEqual(normalizer.getInputCol(), "input") |
| 955 | + self.assertEqual(normalizer.getOutputCol(), "output") |
| 956 | + |
| 957 | + output = normalizer.transform(df) |
| 958 | + self.assertEqual(output.columns, ["input", "output"]) |
| 959 | + self.assertEqual(output.count(), 2) |
| 960 | + |
| 961 | + # save & load |
| 962 | + with tempfile.TemporaryDirectory(prefix="normalizer") as d: |
| 963 | + normalizer.write().overwrite().save(d) |
| 964 | + normalizer2 = Normalizer.load(d) |
| 965 | + self.assertEqual(str(normalizer), str(normalizer2)) |
| 966 | + |
| 967 | + def test_interaction(self): |
| 968 | + spark = self.spark |
| 969 | + df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"]) |
| 970 | + |
| 971 | + interaction = Interaction() |
| 972 | + interaction.setInputCols(["a", "b"]) |
| 973 | + interaction.setOutputCol("ab") |
| 974 | + self.assertEqual(interaction.getInputCols(), ["a", "b"]) |
| 975 | + self.assertEqual(interaction.getOutputCol(), "ab") |
| 976 | + |
| 977 | + output = interaction.transform(df) |
| 978 | + self.assertEqual(output.columns, ["a", "b", "ab"]) |
| 979 | + self.assertEqual(output.count(), 2) |
| 980 | + |
| 981 | + # save & load |
| 982 | + with tempfile.TemporaryDirectory(prefix="interaction") as d: |
| 983 | + interaction.write().overwrite().save(d) |
| 984 | + interaction2 = Interaction.load(d) |
| 985 | + self.assertEqual(str(interaction), str(interaction2)) |
933 | 986 |
|
934 | 987 | def test_count_vectorizer_with_binary(self):
|
935 | 988 | dataset = self.spark.createDataFrame(
|
|
0 commit comments