
Commit b2ee254

HyukjinKwon authored and Robert Kruszewski committed
[SPARK-26818][ML] Make MLEvents JSON ser/de safe
## What changes were proposed in this pull request?

Currently, this does not appear to cause any practical problem (if I did not misread the code). I see one place where JSON-formatted events are used:

https://github.com/apache/spark/blob/ec506bd30c2ca324c12c9ec811764081c2eb8c42/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L148

It is okay because it only logs when the exception is ignorable:

https://github.com/apache/spark/blob/9690eba16efe6d25261934d8b73a221972b684f3/core/src/main/scala/org/apache/spark/util/ListenerBus.scala#L111

Still, it is best to stay safe - I do not want this unstable, experimental feature to break anything in any case. This PR makes ML events JSON ser/de safe, and also disables `logEvent` in `SparkListenerEvent` for the same reason. This also matches the SQL execution events side:

https://github.com/apache/spark/blob/ca545f79410a464ef24e3986fac225f53bb2ef02/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L41-L57

## How was this patch tested?

Manually tested, and unit tests were added.

Closes apache#23728 from HyukjinKwon/SPARK-26818.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
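To make concrete what "JSON ser/de safe" means here: Jackson never sees `@JsonIgnore` members, so an event whose non-serializable references (estimators, transformers, datasets) live only in ignored fields can always be round-tripped. A minimal sketch of the mechanism, assuming jackson-module-scala on the classpath; `DemoEvent` is a hypothetical stand-in for the ML events changed below, not part of this patch:

```scala
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

// The constructor carries only JSON-friendly data; the heavyweight reference
// lives in a @JsonIgnore var and never reaches the serializer.
case class DemoEvent(path: String) {
  @JsonIgnore var payload: AnyRef = _
}

object DemoEventApp extends App {
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  val event = DemoEvent("/tmp/model")
  event.payload = new Object()               // not JSON-serializable on its own
  println(mapper.writeValueAsString(event))  // prints {"path":"/tmp/model"}
}
```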
1 parent 9cbd6f1 commit b2ee254

2 files changed (+155, -38 lines)


mllib/src/main/scala/org/apache/spark/ml/events.scala

Lines changed: 63 additions & 18 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml
 
+import com.fasterxml.jackson.annotation.JsonIgnore
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Unstable
 import org.apache.spark.internal.Logging
@@ -29,53 +31,84 @@ import org.apache.spark.sql.{DataFrame, Dataset}
  * after each operation (the event should document this).
  *
  * @note This is supported via [[Pipeline]] and [[PipelineModel]].
+ * @note This is experimental and unstable. Do not use this unless you fully
+ *       understand what `Unstable` means.
  */
 @Unstable
-sealed trait MLEvent extends SparkListenerEvent
+sealed trait MLEvent extends SparkListenerEvent {
+  // Do not log ML events in event log. It should be revisited to see
+  // how it works with history server.
+  protected[spark] override def logEvent: Boolean = false
+}
 
 /**
  * Event fired before `Transformer.transform`.
  */
 @Unstable
-case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+case class TransformStart() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var input: Dataset[_] = _
+}
+
 /**
  * Event fired after `Transformer.transform`.
  */
 @Unstable
-case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+case class TransformEnd() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var output: Dataset[_] = _
+}
 
 /**
  * Event fired before `Estimator.fit`.
  */
 @Unstable
-case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+case class FitStart[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var dataset: Dataset[_] = _
+}
+
 /**
  * Event fired after `Estimator.fit`.
  */
 @Unstable
-case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+case class FitEnd[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var model: M = _
+}
 
 /**
  * Event fired before `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+case class LoadInstanceStart[T](path: String) extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+}
+
 /**
  * Event fired after `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+case class LoadInstanceEnd[T]() extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+  @JsonIgnore var instance: T = _
+}
 
 /**
  * Event fired before `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceStart(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
+
 /**
  * Event fired after `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceEnd(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
 
 /**
  * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
```
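Every event above follows the same construction idiom, which the call sites in the next hunk implement inline: instantiate the event with JSON-safe constructor arguments only, then attach the heavyweight objects through the `@JsonIgnore` vars. A sketch of the idiom in isolation; `newFitStart` is a hypothetical helper, not part of the patch:

```scala
import org.apache.spark.ml.{Estimator, FitStart, Model}
import org.apache.spark.sql.Dataset

// Hypothetical factory mirroring what withFitEvent (below) does inline:
// build the event empty, then populate the ignored fields for bus listeners.
def newFitStart[M <: Model[M]](est: Estimator[M], data: Dataset[_]): FitStart[M] = {
  val event = FitStart[M]()
  event.estimator = est   // visible on the listener bus, absent from the JSON
  event.dataset = data
  event
}
```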
```diff
@@ -91,46 +124,58 @@ private[ml] trait MLEvents extends Logging {
 
   def withFitEvent[M <: Model[M]](
       estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
-    val startEvent = FitStart(estimator, dataset)
+    val startEvent = FitStart[M]()
+    startEvent.estimator = estimator
+    startEvent.dataset = dataset
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val model: M = func
-    val endEvent = FitEnd(estimator, model)
+    val endEvent = FitEnd[M]()
+    endEvent.estimator = estimator
+    endEvent.model = model
     logEvent(endEvent)
     listenerBus.post(endEvent)
     model
   }
 
   def withTransformEvent(
       transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
-    val startEvent = TransformStart(transformer, input)
+    val startEvent = TransformStart()
+    startEvent.transformer = transformer
+    startEvent.input = input
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val output: DataFrame = func
-    val endEvent = TransformEnd(transformer, output)
+    val endEvent = TransformEnd()
+    endEvent.transformer = transformer
+    endEvent.output = output
     logEvent(endEvent)
     listenerBus.post(endEvent)
     output
   }
 
   def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
-    val startEvent = LoadInstanceStart(reader, path)
+    val startEvent = LoadInstanceStart[T](path)
+    startEvent.reader = reader
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val instance: T = func
-    val endEvent = LoadInstanceEnd(reader, instance)
+    val endEvent = LoadInstanceEnd[T]()
+    endEvent.reader = reader
+    endEvent.instance = instance
     logEvent(endEvent)
     listenerBus.post(endEvent)
     instance
   }
 
   def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
-    listenerBus.post(SaveInstanceEnd(writer, path))
-    val startEvent = SaveInstanceStart(writer, path)
+    val startEvent = SaveInstanceStart(path)
+    startEvent.writer = writer
     logEvent(startEvent)
     listenerBus.post(startEvent)
     func
-    val endEvent = SaveInstanceEnd(writer, path)
+    val endEvent = SaveInstanceEnd(path)
+    endEvent.writer = writer
     logEvent(endEvent)
     listenerBus.post(endEvent)
   }
```
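For context, a sketch of how these events could be observed from user code; the listener below is illustrative, not part of the patch, and assumes a running `SparkContext` named `sc`:

```scala
import org.apache.spark.ml.{FitEnd, FitStart, TransformEnd}
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

// ML events arrive on the listener bus like any other SparkListenerEvent
// and can be pattern-matched in onOtherEvent. The @JsonIgnore vars are
// still populated on the live bus; only the JSON form drops them.
class MLEventLogger extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: FitStart[_]  => println(s"fit started by ${e.estimator.uid}")
    case e: FitEnd[_]    => println(s"fit produced ${e.model.uid}")
    case e: TransformEnd => println(s"transform finished by ${e.transformer.uid}")
    case _               => // not an ML event; ignore
  }
}

// Registration, assuming a running SparkContext `sc`:
// sc.addSparkListener(new MLEventLogger())
```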

mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala

Lines changed: 92 additions & 20 deletions
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWri
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql._
+import org.apache.spark.util.JsonProtocol
 
 
 class MLEventsSuite
@@ -107,20 +108,48 @@ class MLEventsSuite
       .setStages(Array(estimator1, transformer1, estimator2))
     assert(events.isEmpty)
     val pipelineModel = pipeline.fit(dataset1)
-    val expected =
-      FitStart(pipeline, dataset1) ::
-      FitStart(estimator1, dataset1) ::
-      FitEnd(estimator1, model1) ::
-      TransformStart(model1, dataset1) ::
-      TransformEnd(model1, dataset2) ::
-      TransformStart(transformer1, dataset2) ::
-      TransformEnd(transformer1, dataset3) ::
-      FitStart(estimator2, dataset3) ::
-      FitEnd(estimator2, model2) ::
-      FitEnd(pipeline, pipelineModel) :: Nil
+
+    val event0 = FitStart[PipelineModel]()
+    event0.estimator = pipeline
+    event0.dataset = dataset1
+    val event1 = FitStart[MyModel]()
+    event1.estimator = estimator1
+    event1.dataset = dataset1
+    val event2 = FitEnd[MyModel]()
+    event2.estimator = estimator1
+    event2.model = model1
+    val event3 = TransformStart()
+    event3.transformer = model1
+    event3.input = dataset1
+    val event4 = TransformEnd()
+    event4.transformer = model1
+    event4.output = dataset2
+    val event5 = TransformStart()
+    event5.transformer = transformer1
+    event5.input = dataset2
+    val event6 = TransformEnd()
+    event6.transformer = transformer1
+    event6.output = dataset3
+    val event7 = FitStart[MyModel]()
+    event7.estimator = estimator2
+    event7.dataset = dataset3
+    val event8 = FitEnd[MyModel]()
+    event8.estimator = estimator2
+    event8.model = model2
+    val event9 = FitEnd[PipelineModel]()
+    event9.estimator = pipeline
+    event9.model = pipelineModel
+
+    val expected = Seq(
+      event0, event1, event2, event3, event4, event5, event6, event7, event8, event9)
     eventually(timeout(10 seconds), interval(1 second)) {
       assert(events === expected)
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
   }
 
   test("pipeline model transform events") {
```
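One subtlety worth noting about `assert(events === expected)` above (my reading of the Scala semantics, not something the patch states): case-class equality is generated from constructor parameters only, so the `@JsonIgnore` body vars never participate in `==`, and the refactored events still compare cleanly. A tiny sketch:

```scala
// Body vars are excluded from a case class's generated equals/hashCode.
case class Ev(path: String) {
  var extra: AnyRef = _
}

val a = Ev("p")
a.extra = new Object  // populated on one side only
val b = Ev("p")
assert(a == b)        // true: only `path` is compared
```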
```diff
@@ -144,18 +173,41 @@
       "pipeline0", Array(transformer1, model, transformer2))
     assert(events.isEmpty)
     val output = newPipelineModel.transform(dataset1)
-    val expected =
-      TransformStart(newPipelineModel, dataset1) ::
-      TransformStart(transformer1, dataset1) ::
-      TransformEnd(transformer1, dataset2) ::
-      TransformStart(model, dataset2) ::
-      TransformEnd(model, dataset3) ::
-      TransformStart(transformer2, dataset3) ::
-      TransformEnd(transformer2, dataset4) ::
-      TransformEnd(newPipelineModel, output) :: Nil
+
+    val event0 = TransformStart()
+    event0.transformer = newPipelineModel
+    event0.input = dataset1
+    val event1 = TransformStart()
+    event1.transformer = transformer1
+    event1.input = dataset1
+    val event2 = TransformEnd()
+    event2.transformer = transformer1
+    event2.output = dataset2
+    val event3 = TransformStart()
+    event3.transformer = model
+    event3.input = dataset2
+    val event4 = TransformEnd()
+    event4.transformer = model
+    event4.output = dataset3
+    val event5 = TransformStart()
+    event5.transformer = transformer2
+    event5.input = dataset3
+    val event6 = TransformEnd()
+    event6.transformer = transformer2
+    event6.output = dataset4
+    val event7 = TransformEnd()
+    event7.transformer = newPipelineModel
+    event7.output = output
+
+    val expected = Seq(event0, event1, event2, event3, event4, event5, event6, event7)
     eventually(timeout(10 seconds), interval(1 second)) {
       assert(events === expected)
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
   }
 
   test("pipeline read/write events") {
@@ -182,6 +234,11 @@
         case e => fail(s"Unexpected event thrown: $e")
       }
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
 
     events.clear()
     val pipelineReader = Pipeline.read
@@ -202,6 +259,11 @@
         case e => fail(s"Unexpected event thrown: $e")
       }
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
   }
 }
 
@@ -230,6 +292,11 @@
         case e => fail(s"Unexpected event thrown: $e")
       }
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
 
     events.clear()
     val pipelineModelReader = PipelineModel.read
@@ -250,6 +317,11 @@
         case e => fail(s"Unexpected event thrown: $e")
       }
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
     }
   }
 }
```
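The repeated ser/de block in these tests exercises a simple round trip. In isolation it looks like the sketch below; note `JsonProtocol` is `private[spark]`, so code like this must live in an `org.apache.spark` sub-package, as the suite does:

```scala
import org.apache.spark.ml.{MLEvent, SaveInstanceStart}
import org.apache.spark.util.JsonProtocol

// Round-trip one event through Spark's JSON event protocol. The ignored
// writer field never reaches the JSON, so serialization cannot trip on it.
val event = SaveInstanceStart("/tmp/pipeline")
val json = JsonProtocol.sparkEventToJson(event)
assert(JsonProtocol.sparkEventFromJson(json).isInstanceOf[MLEvent])
```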
