11# SPDX-License-Identifier: Apache-2.0
22
3- '''
3+ """
44Mapping and utilities for the names of Params(propeties) that various Spark ML models
55have for their input and output columns
6- '''
6+ """
77from .ops_names import get_sparkml_operator_name
88
99
1010def build_io_name_map ():
11- '''
11+ """
1212 map of spark models to input-output tuples
1313 Each lambda gets the corresponding input or output column name from the model
14- '''
14+ """
1515 map = {
1616 "pyspark.ml.feature.BucketedRandomProjectionLSHModel" : (
1717 lambda model : [model .getOrDefault ("inputCol" )],
18- lambda model : [model .getOrDefault ("outputCol" )]
18+ lambda model : [model .getOrDefault ("outputCol" )],
1919 ),
2020 "pyspark.ml.regression.AFTSurvivalRegressionModel" : (
2121 lambda model : [model .getOrDefault ("featuresCol" )],
22- lambda model : [model .getOrDefault ("predictionCol" )]
22+ lambda model : [model .getOrDefault ("predictionCol" )],
2323 ),
2424 "pyspark.ml.feature.ElementwiseProduct" : (
2525 lambda model : [model .getOrDefault ("inputCol" )],
26- lambda model : [model .getOrDefault ("outputCol" )]
26+ lambda model : [model .getOrDefault ("outputCol" )],
2727 ),
2828 "pyspark.ml.feature.MinHashLSHModel" : (
2929 lambda model : [model .getOrDefault ("inputCol" )],
30- lambda model : [model .getOrDefault ("outputCol" )]
30+ lambda model : [model .getOrDefault ("outputCol" )],
3131 ),
3232 "pyspark.ml.feature.Word2VecModel" : (
3333 lambda model : [model .getOrDefault ("inputCol" )],
34- lambda model : [model .getOrDefault ("outputCol" )]
34+ lambda model : [model .getOrDefault ("outputCol" )],
3535 ),
3636 "pyspark.ml.feature.IndexToString" : (
3737 lambda model : [model .getOrDefault ("inputCol" )],
38- lambda model : [model .getOrDefault ("outputCol" )]
38+ lambda model : [model .getOrDefault ("outputCol" )],
3939 ),
4040 "pyspark.ml.feature.ChiSqSelectorModel" : (
4141 lambda model : [model .getOrDefault ("featuresCol" )],
42- lambda model : [model .getOrDefault ("outputCol" )]
42+ lambda model : [model .getOrDefault ("outputCol" )],
4343 ),
4444 "pyspark.ml.classification.OneVsRestModel" : (
4545 lambda model : [model .getOrDefault ("featuresCol" )],
46- lambda model : [model .getOrDefault ("predictionCol" )]
46+ lambda model : [model .getOrDefault ("predictionCol" )],
4747 ),
4848 "pyspark.ml.regression.GBTRegressionModel" : (
4949 lambda model : [model .getOrDefault ("featuresCol" )],
50- lambda model : [model .getOrDefault ("predictionCol" )]
50+ lambda model : [model .getOrDefault ("predictionCol" )],
5151 ),
5252 "pyspark.ml.classification.GBTClassificationModel" : (
5353 lambda model : [model .getOrDefault ("featuresCol" )],
54- lambda model : [model .getOrDefault ("predictionCol" ), ' probability' ]
54+ lambda model : [model .getOrDefault ("predictionCol" ), " probability" ],
5555 ),
5656 "pyspark.ml.feature.DCT" : (
5757 lambda model : [model .getOrDefault ("inputCol" )],
58- lambda model : [model .getOrDefault ("outputCol" )]
58+ lambda model : [model .getOrDefault ("outputCol" )],
5959 ),
6060 "pyspark.ml.feature.PCAModel" : (
6161 lambda model : [model .getOrDefault ("inputCol" )],
62- lambda model : [model .getOrDefault ("outputCol" )]
62+ lambda model : [model .getOrDefault ("outputCol" )],
6363 ),
6464 "pyspark.ml.feature.PolynomialExpansion" : (
6565 lambda model : [model .getOrDefault ("inputCol" )],
66- lambda model : [model .getOrDefault ("outputCol" )]
66+ lambda model : [model .getOrDefault ("outputCol" )],
6767 ),
6868 "pyspark.ml.feature.Tokenizer" : (
6969 lambda model : [model .getOrDefault ("inputCol" )],
70- lambda model : [model .getOrDefault ("outputCol" )]
70+ lambda model : [model .getOrDefault ("outputCol" )],
7171 ),
7272 "pyspark.ml.classification.NaiveBayesModel" : (
7373 lambda model : [model .getOrDefault ("featuresCol" )],
74- lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )]
74+ lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )],
7575 ),
7676 "pyspark.ml.feature.VectorSlicer" : (
7777 lambda model : [model .getOrDefault ("inputCol" )],
78- lambda model : [model .getOrDefault ("outputCol" )]
78+ lambda model : [model .getOrDefault ("outputCol" )],
7979 ),
8080 "pyspark.ml.feature.StopWordsRemover" : (
8181 lambda model : [model .getOrDefault ("inputCol" )],
82- lambda model : [model .getOrDefault ("outputCol" )]
82+ lambda model : [model .getOrDefault ("outputCol" )],
8383 ),
8484 "pyspark.ml.feature.NGram" : (
8585 lambda model : [model .getOrDefault ("inputCol" )],
86- lambda model : [model .getOrDefault ("outputCol" )]
86+ lambda model : [model .getOrDefault ("outputCol" )],
8787 ),
8888 "pyspark.ml.feature.Bucketizer" : (
8989 lambda model : [model .getOrDefault ("inputCol" )],
90- lambda model : [model .getOrDefault ("outputCol" )]
90+ lambda model : [model .getOrDefault ("outputCol" )],
9191 ),
9292 "pyspark.ml.regression.RandomForestRegressionModel" : (
9393 lambda model : [model .getOrDefault ("featuresCol" )],
94- lambda model : [model .getOrDefault ("predictionCol" )]
94+ lambda model : [model .getOrDefault ("predictionCol" )],
9595 ),
9696 "pyspark.ml.classification.RandomForestClassificationModel" : (
9797 lambda model : [model .getOrDefault ("featuresCol" )],
98- lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )]
98+ lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )],
9999 ),
100100 "pyspark.ml.regression.DecisionTreeRegressionModel" : (
101101 lambda model : [model .getOrDefault ("featuresCol" )],
102- lambda model : [model .getOrDefault ("predictionCol" )]
102+ lambda model : [model .getOrDefault ("predictionCol" )],
103103 ),
104104 "pyspark.ml.classification.DecisionTreeClassificationModel" : (
105105 lambda model : [model .getOrDefault ("featuresCol" )],
106- lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )]
106+ lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )],
107107 ),
108108 "pyspark.ml.feature.VectorIndexerModel" : (
109109 lambda model : [model .getOrDefault ("inputCol" )],
110- lambda model : [model .getOrDefault ("outputCol" )]
110+ lambda model : [model .getOrDefault ("outputCol" )],
111111 ),
112112 "pyspark.ml.regression.GeneralizedLinearRegressionModel" : (
113113 lambda model : [model .getOrDefault ("featuresCol" )],
114- lambda model : [model .getOrDefault ("predictionCol" )]
114+ lambda model : [model .getOrDefault ("predictionCol" )],
115115 ),
116116 "pyspark.ml.regression.LinearRegressionModel" : (
117117 lambda model : [model .getOrDefault ("featuresCol" )],
118- lambda model : [model .getOrDefault ("predictionCol" )]
118+ lambda model : [model .getOrDefault ("predictionCol" )],
119119 ),
120120 "pyspark.ml.feature.ImputerModel" : (
121121 lambda model : model .getOrDefault ("inputCols" ),
122- lambda model : model .getOrDefault ("outputCols" )
122+ lambda model : model .getOrDefault ("outputCols" ),
123123 ),
124124 "pyspark.ml.feature.MaxAbsScalerModel" : (
125125 lambda model : [model .getOrDefault ("inputCol" )],
126- lambda model : [model .getOrDefault ("outputCol" )]
126+ lambda model : [model .getOrDefault ("outputCol" )],
127127 ),
128128 "pyspark.ml.feature.MinMaxScalerModel" : (
129129 lambda model : [model .getOrDefault ("inputCol" )],
130- lambda model : [model .getOrDefault ("outputCol" )]
130+ lambda model : [model .getOrDefault ("outputCol" )],
131131 ),
132132 "pyspark.ml.feature.StandardScalerModel" : (
133133 lambda model : [model .getOrDefault ("inputCol" )],
134- lambda model : [model .getOrDefault ("outputCol" )]
134+ lambda model : [model .getOrDefault ("outputCol" )],
135135 ),
136136 "pyspark.ml.feature.Normalizer" : (
137137 lambda model : [model .getOrDefault ("inputCol" )],
138- lambda model : [model .getOrDefault ("outputCol" )]
138+ lambda model : [model .getOrDefault ("outputCol" )],
139139 ),
140140 "pyspark.ml.feature.Binarizer" : (
141141 lambda model : [model .getOrDefault ("inputCol" )],
142- lambda model : [model .getOrDefault ("outputCol" )]
142+ lambda model : [model .getOrDefault ("outputCol" )],
143143 ),
144144 "pyspark.ml.feature.CountVectorizerModel" : (
145145 lambda model : [model .getOrDefault ("inputCol" )],
146- lambda model : [model .getOrDefault ("outputCol" )]
146+ lambda model : [model .getOrDefault ("outputCol" )],
147147 ),
148148 "pyspark.ml.classification.LinearSVCModel" : (
149149 lambda model : [model .getOrDefault ("featuresCol" )],
150- lambda model : [model .getOrDefault ("predictionCol" )]
150+ lambda model : [model .getOrDefault ("predictionCol" )],
151151 ),
152152 "pyspark.ml.classification.LogisticRegressionModel" : (
153153 lambda model : [model .getOrDefault ("featuresCol" )],
154- lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )]
154+ lambda model : [model .getOrDefault ("predictionCol" ), model .getOrDefault ("probabilityCol" )],
155155 ),
156156 "pyspark.ml.feature.OneHotEncoderModel" : (
157157 lambda model : model .getOrDefault ("inputCols" )
@@ -162,17 +162,21 @@ def build_io_name_map():
162162 else [model .getOrDefault ("outputCol" )],
163163 ),
164164 "pyspark.ml.feature.StringIndexerModel" : (
165- lambda model : [model .getOrDefault ("inputCol" )],
166- lambda model : [model .getOrDefault ("outputCol" )]
165+ lambda model : model .getOrDefault ("inputCols" )
166+ if model .isSet ("inputCols" )
167+ else [model .getOrDefault ("inputCol" )],
168+ lambda model : model .getOrDefault ("outputCols" )
169+ if model .isSet ("outputCols" )
170+ else [model .getOrDefault ("outputCol" )],
167171 ),
168172 "pyspark.ml.feature.VectorAssembler" : (
169173 lambda model : model .getOrDefault ("inputCols" ),
170- lambda model : [model .getOrDefault ("outputCol" )]
174+ lambda model : [model .getOrDefault ("outputCol" )],
171175 ),
172176 "pyspark.ml.clustering.KMeansModel" : (
173177 lambda model : [model .getOrDefault ("featuresCol" )],
174- lambda model : [model .getOrDefault ("predictionCol" )]
175- )
178+ lambda model : [model .getOrDefault ("predictionCol" )],
179+ ),
176180 }
177181 return map
178182
@@ -181,18 +185,18 @@ def build_io_name_map():
181185
182186
183187def get_input_names (model ):
184- '''
188+ """
185189 Returns the name(s) of the input(s) for a SparkML operator
186190 :param model: SparkML Model
187191 :return: list of input names
188- '''
192+ """
189193 return io_name_map [get_sparkml_operator_name (type (model ))][0 ](model )
190194
191195
192196def get_output_names (model ):
193- '''
197+ """
194198 Returns the name(s) of the output(s) for a SparkML operator
195199 :param model: SparkML Model
196200 :return: list of output names
197- '''
201+ """
198202 return io_name_map [get_sparkml_operator_name (type (model ))][1 ](model )
0 commit comments