Skip to content

Commit 44966c9

Browse files
zhengruifeng authored and dongjoon-hyun committed
[SPARK-50976][ML][PYTHON] Fix the save/load of TargetEncoder
### What changes were proposed in this pull request? 1, Fix the save/load of `TargetEncoder` 2, hide `TargetEncoderModel.stats` ### Why are the changes needed? 1, existing implementation of `save/load` actually does not work 2, in the python side, `TargetEncoderModel.stats` return a `JavaObject` which cannot be used. We should find a better way to expose the model coefficients. ``` In [1]: from pyspark.ml.feature import * ...: ...: df = spark.createDataFrame( ...: [ ...: (0, 3, 5.0, 0.0), ...: (1, 4, 5.0, 1.0), ...: (2, 3, 5.0, 0.0), ...: (0, 4, 6.0, 1.0), ...: (1, 3, 6.0, 0.0), ...: (2, 4, 6.0, 1.0), ...: (0, 3, 7.0, 0.0), ...: (1, 4, 8.0, 1.0), ...: (2, 3, 9.0, 0.0), ...: ], ...: schema="input1 short, input2 int, input3 double, label double", ...: ) ...: encoder = TargetEncoder( ...: inputCols=["input1", "input2", "input3"], ...: outputCols=["output", "output2", "output3"], ...: labelCol="label", ...: targetType="binary", ...: ) ...: model = encoder.fit(df) In [2]: model.stats Out[2]: JavaObject id=o92 In [5]: model.write().overwrite().save("/tmp/ta") In [6]: TargetEncoderModel.load("/tmp/ta") {"ts": "2025-01-24 19:06:54,598", "level": "ERROR", "logger": "DataFrameQueryContextLogger", "msg": "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `encodings` cannot be resolved. Did you mean one of the following? [`stats`]. SQLSTATE: 42703", "context": {"file": ... AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `encodings` cannot be resolved. Did you mean one of the following? [`stats`]. SQLSTATE: 42703; 'Project ['encodings] +- Relation [stats#37] parquet ``` ### Does this PR introduce _any_ user-facing change? No, since this algorithm was 4.0 only ### How was this patch tested? updated test ### Was this patch authored or co-authored using generative AI tooling? no Closes #49649 from zhengruifeng/ml_target_save_load. 
Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 5db31ae commit 44966c9

File tree

3 files changed

+38
-159
lines changed

3 files changed

+38
-159
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] {
282282
*/
283283
@Since("4.0.0")
284284
class TargetEncoderModel private[ml] (
285-
@Since("4.0.0") override val uid: String,
286-
@Since("4.0.0") val stats: Array[Map[Double, (Double, Double)]])
285+
@Since("4.0.0") override val uid: String,
286+
@Since("4.0.0") private[ml] val stats: Array[Map[Double, (Double, Double)]])
287287
extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable {
288288

289289
/** @group setParam */
@@ -403,13 +403,18 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
403403
private[TargetEncoderModel]
404404
class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {
405405

406-
private case class Data(stats: Array[Map[Double, (Double, Double)]])
406+
private case class Data(index: Int, categories: Array[Double],
407+
counts: Array[Double], stats: Array[Double])
407408

408409
override protected def saveImpl(path: String): Unit = {
409410
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
410-
val data = Data(instance.stats)
411+
val datum = instance.stats.iterator.zipWithIndex.map { case (stat, index) =>
412+
val (_categories, _countsAndStats) = stat.toSeq.unzip
413+
val (_counts, _stats) = _countsAndStats.unzip
414+
Data(index, _categories.toArray, _counts.toArray, _stats.toArray)
415+
}.toSeq
411416
val dataPath = new Path(path, "data").toString
412-
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
417+
sparkSession.createDataFrame(datum).write.parquet(dataPath)
413418
}
414419
}
415420

@@ -420,10 +425,18 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
420425
override def load(path: String): TargetEncoderModel = {
421426
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
422427
val dataPath = new Path(path, "data").toString
423-
val data = sparkSession.read.parquet(dataPath)
424-
.select("encodings")
425-
.head()
426-
val stats = data.getAs[Array[Map[Double, (Double, Double)]]](0)
428+
429+
val stats = sparkSession.read.parquet(dataPath)
430+
.select("index", "categories", "counts", "stats")
431+
.collect()
432+
.map { row =>
433+
val index = row.getInt(0)
434+
val categories = row.getAs[Seq[Double]](1).toArray
435+
val counts = row.getAs[Seq[Double]](2).toArray
436+
val stats = row.getAs[Seq[Double]](3).toArray
437+
(index, categories.zip(counts.zip(stats)).toMap)
438+
}.sortBy(_._1).map(_._2)
439+
427440
val model = new TargetEncoderModel(metadata.uid, stats)
428441
metadata.getAndSetParams(model)
429442
model

python/pyspark/ml/feature.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5500,15 +5500,6 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel":
55005500
"""
55015501
return self._set(smoothing=value)
55025502

5503-
@property
5504-
@since("4.0.0")
5505-
def stats(self) -> List[Dict[float, Tuple[float, float]]]:
5506-
"""
5507-
Fitted statistics for each feature to being encoded.
5508-
The list contains a dictionary for each input column.
5509-
"""
5510-
return self._call_java("stats")
5511-
55125503

55135504
@inherit_doc
55145505
class Tokenizer(

python/pyspark/ml/tests/test_feature.py

Lines changed: 16 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
StringIndexer,
5656
StringIndexerModel,
5757
TargetEncoder,
58+
TargetEncoderModel,
5859
VectorSizeHint,
5960
VectorAssembler,
6061
PCA,
@@ -1113,148 +1114,22 @@ def test_target_encoder_binary(self):
11131114
targetType="binary",
11141115
)
11151116
model = encoder.fit(df)
1116-
te = model.transform(df)
1117-
actual = te.drop("label").collect()
1118-
expected = [
1119-
Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
1120-
Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3),
1121-
Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
1122-
Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
1123-
Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3),
1124-
Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
1125-
Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0),
1126-
Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0),
1127-
Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0),
1128-
]
1129-
self.assertEqual(actual, expected)
1130-
te = model.setSmoothing(1.0).transform(df)
1131-
actual = te.drop("label").collect()
1132-
expected = [
1133-
Row(
1134-
input1=0,
1135-
input2=3,
1136-
input3=5.0,
1137-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1138-
output2=(1 - 5 / 6) * (4 / 9),
1139-
output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1140-
),
1141-
Row(
1142-
input1=1,
1143-
input2=4,
1144-
input3=5.0,
1145-
output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1146-
output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
1147-
output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1148-
),
1149-
Row(
1150-
input1=2,
1151-
input2=3,
1152-
input3=5.0,
1153-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1154-
output2=(1 - 5 / 6) * (4 / 9),
1155-
output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1156-
),
1157-
Row(
1158-
input1=0,
1159-
input2=4,
1160-
input3=6.0,
1161-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1162-
output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
1163-
output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1164-
),
1165-
Row(
1166-
input1=1,
1167-
input2=3,
1168-
input3=6.0,
1169-
output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1170-
output2=(1 - 5 / 6) * (4 / 9),
1171-
output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1172-
),
1173-
Row(
1174-
input1=2,
1175-
input2=4,
1176-
input3=6.0,
1177-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1178-
output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
1179-
output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1180-
),
1181-
Row(
1182-
input1=0,
1183-
input2=3,
1184-
input3=7.0,
1185-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1186-
output2=(1 - 5 / 6) * (4 / 9),
1187-
output3=(1 - 1 / 2) * (4 / 9),
1188-
),
1189-
Row(
1190-
input1=1,
1191-
input2=4,
1192-
input3=8.0,
1193-
output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
1194-
output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
1195-
output3=(1 / 2) + (1 - 1 / 2) * (4 / 9),
1196-
),
1197-
Row(
1198-
input1=2,
1199-
input2=3,
1200-
input3=9.0,
1201-
output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
1202-
output2=(1 - 5 / 6) * (4 / 9),
1203-
output3=(1 - 1 / 2) * (4 / 9),
1204-
),
1205-
]
1206-
self.assertEqual(actual, expected)
1207-
1208-
def test_target_encoder_continuous(self):
1209-
df = self.spark.createDataFrame(
1210-
[
1211-
(0, 3, 5.0, 10.0),
1212-
(1, 4, 5.0, 20.0),
1213-
(2, 3, 5.0, 30.0),
1214-
(0, 4, 6.0, 40.0),
1215-
(1, 3, 6.0, 50.0),
1216-
(2, 4, 6.0, 60.0),
1217-
(0, 3, 7.0, 70.0),
1218-
(1, 4, 8.0, 80.0),
1219-
(2, 3, 9.0, 90.0),
1220-
],
1221-
schema="input1 short, input2 int, input3 double, label double",
1222-
)
1223-
encoder = TargetEncoder(
1224-
inputCols=["input1", "input2", "input3"],
1225-
outputCols=["output", "output2", "output3"],
1226-
labelCol="label",
1227-
targetType="continuous",
1117+
output = model.transform(df)
1118+
self.assertEqual(
1119+
output.columns,
1120+
["input1", "input2", "input3", "label", "output", "output2", "output3"],
12281121
)
1229-
model = encoder.fit(df)
1230-
te = model.transform(df)
1231-
actual = te.drop("label").collect()
1232-
expected = [
1233-
Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0),
1234-
Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0),
1235-
Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0),
1236-
Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0),
1237-
Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
1238-
Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0),
1239-
Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0),
1240-
Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0),
1241-
Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0),
1242-
]
1243-
self.assertEqual(actual, expected)
1244-
te = model.setSmoothing(1.0).transform(df)
1245-
actual = te.drop("label").collect()
1246-
expected = [
1247-
Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5),
1248-
Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5),
1249-
Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5),
1250-
Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0),
1251-
Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
1252-
Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0),
1253-
Row(input1=0, input2=3, input3=7.0, output1=42.5, output2=50.0, output3=60.0),
1254-
Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0),
1255-
Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0),
1256-
]
1257-
self.assertEqual(actual, expected)
1122+
self.assertEqual(output.count(), 9)
1123+
1124+
# save & load
1125+
with tempfile.TemporaryDirectory(prefix="target_encoder") as d:
1126+
encoder.write().overwrite().save(d)
1127+
encoder2 = TargetEncoder.load(d)
1128+
self.assertEqual(str(encoder), str(encoder2))
1129+
1130+
model.write().overwrite().save(d)
1131+
model2 = TargetEncoderModel.load(d)
1132+
self.assertEqual(str(model), str(model2))
12581133

12591134
def test_vector_size_hint(self):
12601135
df = self.spark.createDataFrame(

0 commit comments

Comments
 (0)