@@ -34,6 +34,7 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWri
34
34
import org .apache .spark .mllib .util .MLlibTestSparkContext
35
35
import org .apache .spark .scheduler .{SparkListener , SparkListenerEvent }
36
36
import org .apache .spark .sql ._
37
+ import org .apache .spark .util .JsonProtocol
37
38
38
39
39
40
class MLEventsSuite
@@ -107,20 +108,48 @@ class MLEventsSuite
107
108
.setStages(Array (estimator1, transformer1, estimator2))
108
109
assert(events.isEmpty)
109
110
val pipelineModel = pipeline.fit(dataset1)
110
- val expected =
111
- FitStart (pipeline, dataset1) ::
112
- FitStart (estimator1, dataset1) ::
113
- FitEnd (estimator1, model1) ::
114
- TransformStart (model1, dataset1) ::
115
- TransformEnd (model1, dataset2) ::
116
- TransformStart (transformer1, dataset2) ::
117
- TransformEnd (transformer1, dataset3) ::
118
- FitStart (estimator2, dataset3) ::
119
- FitEnd (estimator2, model2) ::
120
- FitEnd (pipeline, pipelineModel) :: Nil
111
+
112
+ val event0 = FitStart [PipelineModel ]()
113
+ event0.estimator = pipeline
114
+ event0.dataset = dataset1
115
+ val event1 = FitStart [MyModel ]()
116
+ event1.estimator = estimator1
117
+ event1.dataset = dataset1
118
+ val event2 = FitEnd [MyModel ]()
119
+ event2.estimator = estimator1
120
+ event2.model = model1
121
+ val event3 = TransformStart ()
122
+ event3.transformer = model1
123
+ event3.input = dataset1
124
+ val event4 = TransformEnd ()
125
+ event4.transformer = model1
126
+ event4.output = dataset2
127
+ val event5 = TransformStart ()
128
+ event5.transformer = transformer1
129
+ event5.input = dataset2
130
+ val event6 = TransformEnd ()
131
+ event6.transformer = transformer1
132
+ event6.output = dataset3
133
+ val event7 = FitStart [MyModel ]()
134
+ event7.estimator = estimator2
135
+ event7.dataset = dataset3
136
+ val event8 = FitEnd [MyModel ]()
137
+ event8.estimator = estimator2
138
+ event8.model = model2
139
+ val event9 = FitEnd [PipelineModel ]()
140
+ event9.estimator = pipeline
141
+ event9.model = pipelineModel
142
+
143
+ val expected = Seq (
144
+ event0, event1, event2, event3, event4, event5, event6, event7, event8, event9)
121
145
eventually(timeout(10 seconds), interval(1 second)) {
122
146
assert(events === expected)
123
147
}
148
+ // Test if they can be ser/de via JSON protocol.
149
+ assert(events.nonEmpty)
150
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
151
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
152
+ }
124
153
}
125
154
126
155
test(" pipeline model transform events" ) {
@@ -144,18 +173,41 @@ class MLEventsSuite
144
173
" pipeline0" , Array (transformer1, model, transformer2))
145
174
assert(events.isEmpty)
146
175
val output = newPipelineModel.transform(dataset1)
147
- val expected =
148
- TransformStart (newPipelineModel, dataset1) ::
149
- TransformStart (transformer1, dataset1) ::
150
- TransformEnd (transformer1, dataset2) ::
151
- TransformStart (model, dataset2) ::
152
- TransformEnd (model, dataset3) ::
153
- TransformStart (transformer2, dataset3) ::
154
- TransformEnd (transformer2, dataset4) ::
155
- TransformEnd (newPipelineModel, output) :: Nil
176
+
177
+ val event0 = TransformStart ()
178
+ event0.transformer = newPipelineModel
179
+ event0.input = dataset1
180
+ val event1 = TransformStart ()
181
+ event1.transformer = transformer1
182
+ event1.input = dataset1
183
+ val event2 = TransformEnd ()
184
+ event2.transformer = transformer1
185
+ event2.output = dataset2
186
+ val event3 = TransformStart ()
187
+ event3.transformer = model
188
+ event3.input = dataset2
189
+ val event4 = TransformEnd ()
190
+ event4.transformer = model
191
+ event4.output = dataset3
192
+ val event5 = TransformStart ()
193
+ event5.transformer = transformer2
194
+ event5.input = dataset3
195
+ val event6 = TransformEnd ()
196
+ event6.transformer = transformer2
197
+ event6.output = dataset4
198
+ val event7 = TransformEnd ()
199
+ event7.transformer = newPipelineModel
200
+ event7.output = output
201
+
202
+ val expected = Seq (event0, event1, event2, event3, event4, event5, event6, event7)
156
203
eventually(timeout(10 seconds), interval(1 second)) {
157
204
assert(events === expected)
158
205
}
206
+ // Test if they can be ser/de via JSON protocol.
207
+ assert(events.nonEmpty)
208
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
209
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
210
+ }
159
211
}
160
212
161
213
test(" pipeline read/write events" ) {
@@ -182,6 +234,11 @@ class MLEventsSuite
182
234
case e => fail(s " Unexpected event thrown: $e" )
183
235
}
184
236
}
237
+ // Test if they can be ser/de via JSON protocol.
238
+ assert(events.nonEmpty)
239
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
240
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
241
+ }
185
242
186
243
events.clear()
187
244
val pipelineReader = Pipeline .read
@@ -202,6 +259,11 @@ class MLEventsSuite
202
259
case e => fail(s " Unexpected event thrown: $e" )
203
260
}
204
261
}
262
+ // Test if they can be ser/de via JSON protocol.
263
+ assert(events.nonEmpty)
264
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
265
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
266
+ }
205
267
}
206
268
}
207
269
@@ -230,6 +292,11 @@ class MLEventsSuite
230
292
case e => fail(s " Unexpected event thrown: $e" )
231
293
}
232
294
}
295
+ // Test if they can be ser/de via JSON protocol.
296
+ assert(events.nonEmpty)
297
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
298
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
299
+ }
233
300
234
301
events.clear()
235
302
val pipelineModelReader = PipelineModel .read
@@ -250,6 +317,11 @@ class MLEventsSuite
250
317
case e => fail(s " Unexpected event thrown: $e" )
251
318
}
252
319
}
320
+ // Test if they can be ser/de via JSON protocol.
321
+ assert(events.nonEmpty)
322
+ events.map(JsonProtocol .sparkEventToJson).foreach { event =>
323
+ assert(JsonProtocol .sparkEventFromJson(event).isInstanceOf [MLEvent ])
324
+ }
253
325
}
254
326
}
255
327
}
0 commit comments