25
25
from pyspark .ml .feature import (
26
26
DCT ,
27
27
Binarizer ,
28
+ Bucketizer ,
28
29
CountVectorizer ,
29
30
CountVectorizerModel ,
30
31
HashingTF ,
@@ -688,17 +689,17 @@ def test_binarizer(self):
688
689
["v1" , "v2" ],
689
690
)
690
691
691
- bucketizer = Binarizer (threshold = 1.0 , inputCol = "v1" , outputCol = "f1" )
692
- output = bucketizer .transform (df )
692
+ binarizer = Binarizer (threshold = 1.0 , inputCol = "v1" , outputCol = "f1" )
693
+ output = binarizer .transform (df )
693
694
self .assertEqual (output .columns , ["v1" , "v2" , "f1" ])
694
695
self .assertEqual (output .count (), 6 )
695
696
self .assertEqual (
696
697
[r .f1 for r in output .select ("f1" ).collect ()],
697
698
[0.0 , 0.0 , 1.0 , 1.0 , 0.0 , 0.0 ],
698
699
)
699
700
700
- bucketizer = Binarizer (threshold = 1.0 , inputCols = ["v1" , "v2" ], outputCols = ["f1" , "f2" ])
701
- output = bucketizer .transform (df )
701
+ binarizer = Binarizer (threshold = 1.0 , inputCols = ["v1" , "v2" ], outputCols = ["f1" , "f2" ])
702
+ output = binarizer .transform (df )
702
703
self .assertEqual (output .columns , ["v1" , "v2" , "f1" , "f2" ])
703
704
self .assertEqual (output .count (), 6 )
704
705
self .assertEqual (
@@ -712,8 +713,74 @@ def test_binarizer(self):
712
713
713
714
# save & load
714
715
with tempfile .TemporaryDirectory (prefix = "binarizer" ) as d :
716
+ binarizer .write ().overwrite ().save (d )
717
+ binarizer2 = Binarizer .load (d )
718
+ self .assertEqual (str (binarizer ), str (binarizer2 ))
719
+
720
+ def test_bucketizer (self ):
721
+ df = self .spark .createDataFrame (
722
+ [
723
+ (0.1 , 0.0 ),
724
+ (0.4 , 1.0 ),
725
+ (1.2 , 1.3 ),
726
+ (1.5 , float ("nan" )),
727
+ (float ("nan" ), 1.0 ),
728
+ (float ("nan" ), 0.0 ),
729
+ ],
730
+ ["v1" , "v2" ],
731
+ )
732
+
733
+ splits = [- float ("inf" ), 0.5 , 1.4 , float ("inf" )]
734
+ bucketizer = Bucketizer ()
735
+ bucketizer .setSplits (splits )
736
+ bucketizer .setHandleInvalid ("keep" )
737
+ bucketizer .setInputCol ("v1" )
738
+ bucketizer .setOutputCol ("b1" )
739
+
740
+ self .assertEqual (bucketizer .getSplits (), splits )
741
+ self .assertEqual (bucketizer .getHandleInvalid (), "keep" )
742
+ self .assertEqual (bucketizer .getInputCol (), "v1" )
743
+ self .assertEqual (bucketizer .getOutputCol (), "b1" )
744
+
745
+ output = bucketizer .transform (df )
746
+ self .assertEqual (output .columns , ["v1" , "v2" , "b1" ])
747
+ self .assertEqual (output .count (), 6 )
748
+ self .assertEqual (
749
+ [r .b1 for r in output .select ("b1" ).collect ()],
750
+ [0.0 , 0.0 , 1.0 , 2.0 , 3.0 , 3.0 ],
751
+ )
752
+
753
+ splitsArray = [
754
+ [- float ("inf" ), 0.5 , 1.4 , float ("inf" )],
755
+ [- float ("inf" ), 0.5 , float ("inf" )],
756
+ ]
757
+ bucketizer = Bucketizer (
758
+ splitsArray = splitsArray ,
759
+ inputCols = ["v1" , "v2" ],
760
+ outputCols = ["b1" , "b2" ],
761
+ )
762
+ bucketizer .setHandleInvalid ("keep" )
763
+ self .assertEqual (bucketizer .getSplitsArray (), splitsArray )
764
+ self .assertEqual (bucketizer .getHandleInvalid (), "keep" )
765
+ self .assertEqual (bucketizer .getInputCols (), ["v1" , "v2" ])
766
+ self .assertEqual (bucketizer .getOutputCols (), ["b1" , "b2" ])
767
+
768
+ output = bucketizer .transform (df )
769
+ self .assertEqual (output .columns , ["v1" , "v2" , "b1" , "b2" ])
770
+ self .assertEqual (output .count (), 6 )
771
+ self .assertEqual (
772
+ [r .b1 for r in output .select ("b1" ).collect ()],
773
+ [0.0 , 0.0 , 1.0 , 2.0 , 3.0 , 3.0 ],
774
+ )
775
+ self .assertEqual (
776
+ [r .b2 for r in output .select ("b2" ).collect ()],
777
+ [0.0 , 1.0 , 1.0 , 2.0 , 1.0 , 0.0 ],
778
+ )
779
+
780
+ # save & load
781
+ with tempfile .TemporaryDirectory (prefix = "bucketizer" ) as d :
715
782
bucketizer .write ().overwrite ().save (d )
716
- bucketizer2 = Binarizer .load (d )
783
+ bucketizer2 = Bucketizer .load (d )
717
784
self .assertEqual (str (bucketizer ), str (bucketizer2 ))
718
785
719
786
def test_idf (self ):
0 commit comments