|
55 | 55 | StringIndexer,
|
56 | 56 | StringIndexerModel,
|
57 | 57 | TargetEncoder,
|
| 58 | + TargetEncoderModel, |
58 | 59 | VectorSizeHint,
|
59 | 60 | VectorAssembler,
|
60 | 61 | PCA,
|
@@ -1113,148 +1114,22 @@ def test_target_encoder_binary(self):
|
1113 | 1114 | targetType="binary",
|
1114 | 1115 | )
|
1115 | 1116 | model = encoder.fit(df)
|
1116 |
| - te = model.transform(df) |
1117 |
| - actual = te.drop("label").collect() |
1118 |
| - expected = [ |
1119 |
| - Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), |
1120 |
| - Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3), |
1121 |
| - Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), |
1122 |
| - Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), |
1123 |
| - Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3), |
1124 |
| - Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), |
1125 |
| - Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0), |
1126 |
| - Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0), |
1127 |
| - Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0), |
1128 |
| - ] |
1129 |
| - self.assertEqual(actual, expected) |
1130 |
| - te = model.setSmoothing(1.0).transform(df) |
1131 |
| - actual = te.drop("label").collect() |
1132 |
| - expected = [ |
1133 |
| - Row( |
1134 |
| - input1=0, |
1135 |
| - input2=3, |
1136 |
| - input3=5.0, |
1137 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1138 |
| - output2=(1 - 5 / 6) * (4 / 9), |
1139 |
| - output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1140 |
| - ), |
1141 |
| - Row( |
1142 |
| - input1=1, |
1143 |
| - input2=4, |
1144 |
| - input3=5.0, |
1145 |
| - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1146 |
| - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), |
1147 |
| - output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1148 |
| - ), |
1149 |
| - Row( |
1150 |
| - input1=2, |
1151 |
| - input2=3, |
1152 |
| - input3=5.0, |
1153 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1154 |
| - output2=(1 - 5 / 6) * (4 / 9), |
1155 |
| - output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1156 |
| - ), |
1157 |
| - Row( |
1158 |
| - input1=0, |
1159 |
| - input2=4, |
1160 |
| - input3=6.0, |
1161 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1162 |
| - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), |
1163 |
| - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1164 |
| - ), |
1165 |
| - Row( |
1166 |
| - input1=1, |
1167 |
| - input2=3, |
1168 |
| - input3=6.0, |
1169 |
| - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1170 |
| - output2=(1 - 5 / 6) * (4 / 9), |
1171 |
| - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1172 |
| - ), |
1173 |
| - Row( |
1174 |
| - input1=2, |
1175 |
| - input2=4, |
1176 |
| - input3=6.0, |
1177 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1178 |
| - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), |
1179 |
| - output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1180 |
| - ), |
1181 |
| - Row( |
1182 |
| - input1=0, |
1183 |
| - input2=3, |
1184 |
| - input3=7.0, |
1185 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1186 |
| - output2=(1 - 5 / 6) * (4 / 9), |
1187 |
| - output3=(1 - 1 / 2) * (4 / 9), |
1188 |
| - ), |
1189 |
| - Row( |
1190 |
| - input1=1, |
1191 |
| - input2=4, |
1192 |
| - input3=8.0, |
1193 |
| - output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), |
1194 |
| - output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), |
1195 |
| - output3=(1 / 2) + (1 - 1 / 2) * (4 / 9), |
1196 |
| - ), |
1197 |
| - Row( |
1198 |
| - input1=2, |
1199 |
| - input2=3, |
1200 |
| - input3=9.0, |
1201 |
| - output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), |
1202 |
| - output2=(1 - 5 / 6) * (4 / 9), |
1203 |
| - output3=(1 - 1 / 2) * (4 / 9), |
1204 |
| - ), |
1205 |
| - ] |
1206 |
| - self.assertEqual(actual, expected) |
1207 |
| - |
1208 |
| - def test_target_encoder_continuous(self): |
1209 |
| - df = self.spark.createDataFrame( |
1210 |
| - [ |
1211 |
| - (0, 3, 5.0, 10.0), |
1212 |
| - (1, 4, 5.0, 20.0), |
1213 |
| - (2, 3, 5.0, 30.0), |
1214 |
| - (0, 4, 6.0, 40.0), |
1215 |
| - (1, 3, 6.0, 50.0), |
1216 |
| - (2, 4, 6.0, 60.0), |
1217 |
| - (0, 3, 7.0, 70.0), |
1218 |
| - (1, 4, 8.0, 80.0), |
1219 |
| - (2, 3, 9.0, 90.0), |
1220 |
| - ], |
1221 |
| - schema="input1 short, input2 int, input3 double, label double", |
1222 |
| - ) |
1223 |
| - encoder = TargetEncoder( |
1224 |
| - inputCols=["input1", "input2", "input3"], |
1225 |
| - outputCols=["output", "output2", "output3"], |
1226 |
| - labelCol="label", |
1227 |
| - targetType="continuous", |
| 1117 | + output = model.transform(df) |
| 1118 | + self.assertEqual( |
| 1119 | + output.columns, |
| 1120 | + ["input1", "input2", "input3", "label", "output", "output2", "output3"], |
1228 | 1121 | )
|
1229 |
| - model = encoder.fit(df) |
1230 |
| - te = model.transform(df) |
1231 |
| - actual = te.drop("label").collect() |
1232 |
| - expected = [ |
1233 |
| - Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0), |
1234 |
| - Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0), |
1235 |
| - Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0), |
1236 |
| - Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0), |
1237 |
| - Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), |
1238 |
| - Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0), |
1239 |
| - Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0), |
1240 |
| - Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0), |
1241 |
| - Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0), |
1242 |
| - ] |
1243 |
| - self.assertEqual(actual, expected) |
1244 |
| - te = model.setSmoothing(1.0).transform(df) |
1245 |
| - actual = te.drop("label").collect() |
1246 |
| - expected = [ |
1247 |
| - Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5), |
1248 |
| - Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5), |
1249 |
| - Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5), |
1250 |
| - Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0), |
1251 |
| - Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), |
1252 |
| - Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0), |
1253 |
| - Row(input1=0, input2=3, input3=7.0, output1=42.5, output2=50.0, output3=60.0), |
1254 |
| - Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0), |
1255 |
| - Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0), |
1256 |
| - ] |
1257 |
| - self.assertEqual(actual, expected) |
| 1122 | + self.assertEqual(output.count(), 9) |
| 1123 | + |
| 1124 | + # save & load |
| 1125 | + with tempfile.TemporaryDirectory(prefix="target_encoder") as d: |
| 1126 | + encoder.write().overwrite().save(d) |
| 1127 | + encoder2 = TargetEncoder.load(d) |
| 1128 | + self.assertEqual(str(encoder), str(encoder2)) |
| 1129 | + |
| 1130 | + model.write().overwrite().save(d) |
| 1131 | + model2 = TargetEncoderModel.load(d) |
| 1132 | + self.assertEqual(str(model), str(model2)) |
1258 | 1133 |
|
1259 | 1134 | def test_vector_size_hint(self):
|
1260 | 1135 | df = self.spark.createDataFrame(
|
|
0 commit comments