@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
23
23
import org .apache .spark .sql .catalyst .plans .logical ._
24
24
import org .apache .spark .sql .catalyst .plans .physical .{ClusteredDistribution , Distribution }
25
25
import org .apache .spark .sql .execution ._
26
+ import org .apache .spark .sql .execution .streaming .GroupStateImpl .NO_TIMESTAMP
26
27
import org .apache .spark .sql .execution .streaming .state ._
27
28
import org .apache .spark .sql .streaming .{GroupStateTimeout , OutputMode }
29
+ import org .apache .spark .sql .types .IntegerType
28
30
import org .apache .spark .util .CompletionIterator
29
31
30
32
/**
@@ -60,8 +62,27 @@ case class FlatMapGroupsWithStateExec(
60
62
import GroupStateImpl ._
61
63
62
64
private val isTimeoutEnabled = timeoutConf != NoTimeout
63
- val stateManager = new FlatMapGroupsWithState_StateManager (stateEncoder, isTimeoutEnabled)
64
- val watermarkPresent = child.output.exists {
65
+ private val timestampTimeoutAttribute =
66
+ AttributeReference (" timeoutTimestamp" , dataType = IntegerType , nullable = false )()
67
+ private val stateAttributes : Seq [Attribute ] = {
68
+ val encSchemaAttribs = stateEncoder.schema.toAttributes
69
+ if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
70
+ }
71
+ // Get the serializer for the state, taking into account whether we need to save timestamps
72
+ private val stateSerializer = {
73
+ val encoderSerializer = stateEncoder.namedExpressions
74
+ if (isTimeoutEnabled) {
75
+ encoderSerializer :+ Literal (GroupStateImpl .NO_TIMESTAMP )
76
+ } else {
77
+ encoderSerializer
78
+ }
79
+ }
80
+ // Get the deserializer for the state. Note that this must be done in the driver, as
81
+ // resolving and binding of deserializer expressions to the encoded type can be safely done
82
+ // only in the driver.
83
+ private val stateDeserializer = stateEncoder.resolveAndBind().deserializer
84
+
85
+ private val watermarkPresent = child.output.exists {
65
86
case a : Attribute if a.metadata.contains(EventTimeWatermark .delayKey) => true
66
87
case _ => false
67
88
}
@@ -92,11 +113,11 @@ case class FlatMapGroupsWithStateExec(
92
113
child.execute().mapPartitionsWithStateStore[InternalRow ](
93
114
getStateInfo,
94
115
groupingAttributes.toStructType,
95
- stateManager.stateSchema ,
116
+ stateAttributes.toStructType ,
96
117
indexOrdinal = None ,
97
118
sqlContext.sessionState,
98
119
Some (sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
99
- val processor = new InputProcessor (store)
120
+ val updater = new StateStoreUpdater (store)
100
121
101
122
// If timeout is based on event time, then filter late data based on watermark
102
123
val filteredIter = watermarkPredicateForData match {
@@ -111,7 +132,7 @@ case class FlatMapGroupsWithStateExec(
111
132
// all the data has been processed. This is to ensure that the timeout information of all
112
133
// the keys with data is updated before they are processed for timeouts.
113
134
val outputIterator =
114
- processor.processNewData (filteredIter) ++ processor.processTimedOutState ()
135
+ updater.updateStateForKeysWithData (filteredIter) ++ updater.updateStateForTimedOutKeys ()
115
136
116
137
// Return an iterator of all the rows generated by all the keys, such that when fully
117
138
// consumed, all the state updates will be committed by the state store
@@ -126,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
126
147
}
127
148
128
149
/** Helper class to update the state store */
129
- class InputProcessor (store : StateStore ) {
150
+ class StateStoreUpdater (store : StateStore ) {
130
151
131
152
// Converters for translating input keys, values, output data between rows and Java objects
132
153
private val getKeyObj =
@@ -135,6 +156,14 @@ case class FlatMapGroupsWithStateExec(
135
156
ObjectOperator .deserializeRowToObject(valueDeserializer, dataAttributes)
136
157
private val getOutputRow = ObjectOperator .wrapObjectToRow(outputObjAttr.dataType)
137
158
159
+ // Converters for translating state between rows and Java objects
160
+ private val getStateObjFromRow = ObjectOperator .deserializeRowToObject(
161
+ stateDeserializer, stateAttributes)
162
+ private val getStateRowFromObj = ObjectOperator .serializeObjectToRow(stateSerializer)
163
+
164
+ // Index of the additional metadata fields in the state row
165
+ private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
166
+
138
167
// Metrics
139
168
private val numUpdatedStateRows = longMetric(" numUpdatedStateRows" )
140
169
private val numOutputRows = longMetric(" numOutputRows" )
@@ -143,19 +172,20 @@ case class FlatMapGroupsWithStateExec(
143
172
* For every group, get the key, values and corresponding state and call the function,
144
173
* and return an iterator of rows
145
174
*/
146
- def processNewData (dataIter : Iterator [InternalRow ]): Iterator [InternalRow ] = {
175
+ def updateStateForKeysWithData (dataIter : Iterator [InternalRow ]): Iterator [InternalRow ] = {
147
176
val groupedIter = GroupedIterator (dataIter, groupingAttributes, child.output)
148
177
groupedIter.flatMap { case (keyRow, valueRowIter) =>
149
178
val keyUnsafeRow = keyRow.asInstanceOf [UnsafeRow ]
150
179
callFunctionAndUpdateState(
151
- stateManager.getState(store, keyUnsafeRow) ,
180
+ keyUnsafeRow,
152
181
valueRowIter,
182
+ store.get(keyUnsafeRow),
153
183
hasTimedOut = false )
154
184
}
155
185
}
156
186
157
187
/** Find the groups that have timeout set and are timing out right now, and call the function */
158
- def processTimedOutState (): Iterator [InternalRow ] = {
188
+ def updateStateForTimedOutKeys (): Iterator [InternalRow ] = {
159
189
if (isTimeoutEnabled) {
160
190
val timeoutThreshold = timeoutConf match {
161
191
case ProcessingTimeTimeout => batchTimestampMs.get
@@ -164,11 +194,12 @@ case class FlatMapGroupsWithStateExec(
164
194
throw new IllegalStateException (
165
195
s " Cannot filter timed out keys for $timeoutConf" )
166
196
}
167
- val timingOutKeys = stateManager.getAllState(store).filter { state =>
168
- state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
197
+ val timingOutKeys = store.getRange(None , None ).filter { rowPair =>
198
+ val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
199
+ timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
169
200
}
170
- timingOutKeys.flatMap { stateData =>
171
- callFunctionAndUpdateState(stateData , Iterator .empty, hasTimedOut = true )
201
+ timingOutKeys.flatMap { rowPair =>
202
+ callFunctionAndUpdateState(rowPair.key , Iterator .empty, rowPair.value , hasTimedOut = true )
172
203
}
173
204
} else Iterator .empty
174
205
}
@@ -178,44 +209,73 @@ case class FlatMapGroupsWithStateExec(
178
209
* iterator. Note that the store updating is lazy, that is, the store will be updated only
179
210
* after the returned iterator is fully consumed.
180
211
*
181
- * @param stateData All the data related to the state to be updated
212
+ * @param keyRow Row representing the key, cannot be null
182
213
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
214
+ * @param prevStateRow Row representing the previous state, can be null
183
215
* @param hasTimedOut Whether this function is being called for a key timeout
184
216
*/
185
217
private def callFunctionAndUpdateState (
186
- stateData : FlatMapGroupsWithState_StateData ,
218
+ keyRow : UnsafeRow ,
187
219
valueRowIter : Iterator [InternalRow ],
220
+ prevStateRow : UnsafeRow ,
188
221
hasTimedOut : Boolean ): Iterator [InternalRow ] = {
189
222
190
- val keyObj = getKeyObj(stateData. keyRow) // convert key to objects
223
+ val keyObj = getKeyObj(keyRow) // convert key to objects
191
224
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
192
- val groupState = GroupStateImpl .createForStreaming(
193
- Option (stateData.stateObj),
225
+ val stateObj = getStateObj(prevStateRow)
226
+ val keyedState = GroupStateImpl .createForStreaming(
227
+ Option (stateObj),
194
228
batchTimestampMs.getOrElse(NO_TIMESTAMP ),
195
229
eventTimeWatermark.getOrElse(NO_TIMESTAMP ),
196
230
timeoutConf,
197
231
hasTimedOut,
198
232
watermarkPresent)
199
233
200
234
// Call function, get the returned objects and convert them to rows
201
- val mappedIterator = func(keyObj, valueObjIter, groupState ).map { obj =>
235
+ val mappedIterator = func(keyObj, valueObjIter, keyedState ).map { obj =>
202
236
numOutputRows += 1
203
237
getOutputRow(obj)
204
238
}
205
239
206
240
// When the iterator is consumed, then write changes to state
207
241
def onIteratorCompletion : Unit = {
208
- if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP ) {
209
- stateManager.removeState(store, stateData.keyRow)
242
+
243
+ val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
244
+ // If the state has not yet been set but timeout has been set, then
245
+ // we have to generate a row to save the timeout. However, attempting serialize
246
+ // null using case class encoder throws -
247
+ // java.lang.NullPointerException: Null value appeared in non-nullable field:
248
+ // If the schema is inferred from a Scala tuple / case class, or a Java bean, please
249
+ // try to use scala.Option[_] or other nullable types.
250
+ if (! keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP ) {
251
+ throw new IllegalStateException (
252
+ " Cannot set timeout when state is not defined, that is, state has not been" +
253
+ " initialized or has been removed" )
254
+ }
255
+
256
+ if (keyedState.hasRemoved) {
257
+ store.remove(keyRow)
210
258
numUpdatedStateRows += 1
259
+
211
260
} else {
212
- val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
213
- val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
214
- val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
261
+ val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
262
+ val stateRowToWrite = if (keyedState.hasUpdated) {
263
+ getStateRow(keyedState.get)
264
+ } else {
265
+ prevStateRow
266
+ }
267
+
268
+ val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
269
+ val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
215
270
216
271
if (shouldWriteState) {
217
- val updatedStateObj = if (groupState.exists) groupState.get else null
218
- stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
272
+ if (stateRowToWrite == null ) {
273
+ // This should never happen because checks in GroupStateImpl should avoid cases
274
+ // where empty state would need to be written
275
+ throw new IllegalStateException (" Attempting to write empty state" )
276
+ }
277
+ setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
278
+ store.put(keyRow, stateRowToWrite)
219
279
numUpdatedStateRows += 1
220
280
}
221
281
}
@@ -224,5 +284,28 @@ case class FlatMapGroupsWithStateExec(
224
284
// Return an iterator of rows such that fully consumed, the updated state value will be saved
225
285
CompletionIterator [InternalRow , Iterator [InternalRow ]](mappedIterator, onIteratorCompletion)
226
286
}
287
+
288
+ /** Returns the state as Java object if defined */
289
+ def getStateObj (stateRow : UnsafeRow ): Any = {
290
+ if (stateRow != null ) getStateObjFromRow(stateRow) else null
291
+ }
292
+
293
+ /** Returns the row for an updated state */
294
+ def getStateRow (obj : Any ): UnsafeRow = {
295
+ assert(obj != null )
296
+ getStateRowFromObj(obj)
297
+ }
298
+
299
+ /** Returns the timeout timestamp of a state row is set */
300
+ def getTimeoutTimestamp (stateRow : UnsafeRow ): Long = {
301
+ if (isTimeoutEnabled && stateRow != null ) {
302
+ stateRow.getLong(timeoutTimestampIndex)
303
+ } else NO_TIMESTAMP
304
+ }
305
+
306
+ /** Set the timestamp in a state row */
307
+ def setTimeoutTimestamp (stateRow : UnsafeRow , timeoutTimestamps : Long ): Unit = {
308
+ if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
309
+ }
227
310
}
228
311
}
0 commit comments