Skip to content

Commit b11869b

Browse files
tdaszsxwing
authored andcommitted
[SPARK-22187][SS][REVERT] Revert change in state row format for mapGroupsWithState
## What changes were proposed in this pull request? #19416 changed the format in which rows were encoded in the state store. However, this can break existing streaming queries with the old format in unpredictable ways (potentially crashing the JVM). Hence I am reverting this for now. This will be re-applied in the future after we start saving more metadata in checkpoints to signify which version of state row format the existing streaming query is running. Then we can decode old and new formats accordingly. ## How was this patch tested? Existing tests. Author: Tathagata Das <[email protected]> Closes #19924 from tdas/SPARK-22187-1.
1 parent 0ba8f4b commit b11869b

File tree

3 files changed

+171
-247
lines changed

3 files changed

+171
-247
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala

Lines changed: 109 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
2323
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
2525
import org.apache.spark.sql.execution._
26+
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
2627
import org.apache.spark.sql.execution.streaming.state._
2728
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
29+
import org.apache.spark.sql.types.IntegerType
2830
import org.apache.spark.util.CompletionIterator
2931

3032
/**
@@ -60,8 +62,27 @@ case class FlatMapGroupsWithStateExec(
6062
import GroupStateImpl._
6163

6264
private val isTimeoutEnabled = timeoutConf != NoTimeout
63-
val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled)
64-
val watermarkPresent = child.output.exists {
65+
private val timestampTimeoutAttribute =
66+
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
67+
private val stateAttributes: Seq[Attribute] = {
68+
val encSchemaAttribs = stateEncoder.schema.toAttributes
69+
if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
70+
}
71+
// Get the serializer for the state, taking into account whether we need to save timestamps
72+
private val stateSerializer = {
73+
val encoderSerializer = stateEncoder.namedExpressions
74+
if (isTimeoutEnabled) {
75+
encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
76+
} else {
77+
encoderSerializer
78+
}
79+
}
80+
// Get the deserializer for the state. Note that this must be done in the driver, as
81+
// resolving and binding of deserializer expressions to the encoded type can be safely done
82+
// only in the driver.
83+
private val stateDeserializer = stateEncoder.resolveAndBind().deserializer
84+
85+
private val watermarkPresent = child.output.exists {
6586
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
6687
case _ => false
6788
}
@@ -92,11 +113,11 @@ case class FlatMapGroupsWithStateExec(
92113
child.execute().mapPartitionsWithStateStore[InternalRow](
93114
getStateInfo,
94115
groupingAttributes.toStructType,
95-
stateManager.stateSchema,
116+
stateAttributes.toStructType,
96117
indexOrdinal = None,
97118
sqlContext.sessionState,
98119
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
99-
val processor = new InputProcessor(store)
120+
val updater = new StateStoreUpdater(store)
100121

101122
// If timeout is based on event time, then filter late data based on watermark
102123
val filteredIter = watermarkPredicateForData match {
@@ -111,7 +132,7 @@ case class FlatMapGroupsWithStateExec(
111132
// all the data has been processed. This is to ensure that the timeout information of all
112133
// the keys with data is updated before they are processed for timeouts.
113134
val outputIterator =
114-
processor.processNewData(filteredIter) ++ processor.processTimedOutState()
135+
updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
115136

116137
// Return an iterator of all the rows generated by all the keys, such that when fully
117138
// consumed, all the state updates will be committed by the state store
@@ -126,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
126147
}
127148

128149
/** Helper class to update the state store */
129-
class InputProcessor(store: StateStore) {
150+
class StateStoreUpdater(store: StateStore) {
130151

131152
// Converters for translating input keys, values, output data between rows and Java objects
132153
private val getKeyObj =
@@ -135,6 +156,14 @@ case class FlatMapGroupsWithStateExec(
135156
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
136157
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
137158

159+
// Converters for translating state between rows and Java objects
160+
private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
161+
stateDeserializer, stateAttributes)
162+
private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
163+
164+
// Index of the additional metadata fields in the state row
165+
private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
166+
138167
// Metrics
139168
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
140169
private val numOutputRows = longMetric("numOutputRows")
@@ -143,19 +172,20 @@ case class FlatMapGroupsWithStateExec(
143172
* For every group, get the key, values and corresponding state and call the function,
144173
* and return an iterator of rows
145174
*/
146-
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
175+
def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
147176
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
148177
groupedIter.flatMap { case (keyRow, valueRowIter) =>
149178
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
150179
callFunctionAndUpdateState(
151-
stateManager.getState(store, keyUnsafeRow),
180+
keyUnsafeRow,
152181
valueRowIter,
182+
store.get(keyUnsafeRow),
153183
hasTimedOut = false)
154184
}
155185
}
156186

157187
/** Find the groups that have timeout set and are timing out right now, and call the function */
158-
def processTimedOutState(): Iterator[InternalRow] = {
188+
def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
159189
if (isTimeoutEnabled) {
160190
val timeoutThreshold = timeoutConf match {
161191
case ProcessingTimeTimeout => batchTimestampMs.get
@@ -164,11 +194,12 @@ case class FlatMapGroupsWithStateExec(
164194
throw new IllegalStateException(
165195
s"Cannot filter timed out keys for $timeoutConf")
166196
}
167-
val timingOutKeys = stateManager.getAllState(store).filter { state =>
168-
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
197+
val timingOutKeys = store.getRange(None, None).filter { rowPair =>
198+
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
199+
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
169200
}
170-
timingOutKeys.flatMap { stateData =>
171-
callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
201+
timingOutKeys.flatMap { rowPair =>
202+
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
172203
}
173204
} else Iterator.empty
174205
}
@@ -178,44 +209,73 @@ case class FlatMapGroupsWithStateExec(
178209
* iterator. Note that the store updating is lazy, that is, the store will be updated only
179210
* after the returned iterator is fully consumed.
180211
*
181-
* @param stateData All the data related to the state to be updated
212+
* @param keyRow Row representing the key, cannot be null
182213
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
214+
* @param prevStateRow Row representing the previous state, can be null
183215
* @param hasTimedOut Whether this function is being called for a key timeout
184216
*/
185217
private def callFunctionAndUpdateState(
186-
stateData: FlatMapGroupsWithState_StateData,
218+
keyRow: UnsafeRow,
187219
valueRowIter: Iterator[InternalRow],
220+
prevStateRow: UnsafeRow,
188221
hasTimedOut: Boolean): Iterator[InternalRow] = {
189222

190-
val keyObj = getKeyObj(stateData.keyRow) // convert key to objects
223+
val keyObj = getKeyObj(keyRow) // convert key to objects
191224
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
192-
val groupState = GroupStateImpl.createForStreaming(
193-
Option(stateData.stateObj),
225+
val stateObj = getStateObj(prevStateRow)
226+
val keyedState = GroupStateImpl.createForStreaming(
227+
Option(stateObj),
194228
batchTimestampMs.getOrElse(NO_TIMESTAMP),
195229
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
196230
timeoutConf,
197231
hasTimedOut,
198232
watermarkPresent)
199233

200234
// Call function, get the returned objects and convert them to rows
201-
val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
235+
val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
202236
numOutputRows += 1
203237
getOutputRow(obj)
204238
}
205239

206240
// When the iterator is consumed, then write changes to state
207241
def onIteratorCompletion: Unit = {
208-
if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
209-
stateManager.removeState(store, stateData.keyRow)
242+
243+
val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
244+
// If the state has not yet been set but timeout has been set, then
245+
// we have to generate a row to save the timeout. However, attempting serialize
246+
// null using case class encoder throws -
247+
// java.lang.NullPointerException: Null value appeared in non-nullable field:
248+
// If the schema is inferred from a Scala tuple / case class, or a Java bean, please
249+
// try to use scala.Option[_] or other nullable types.
250+
if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
251+
throw new IllegalStateException(
252+
"Cannot set timeout when state is not defined, that is, state has not been" +
253+
"initialized or has been removed")
254+
}
255+
256+
if (keyedState.hasRemoved) {
257+
store.remove(keyRow)
210258
numUpdatedStateRows += 1
259+
211260
} else {
212-
val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
213-
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
214-
val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
261+
val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
262+
val stateRowToWrite = if (keyedState.hasUpdated) {
263+
getStateRow(keyedState.get)
264+
} else {
265+
prevStateRow
266+
}
267+
268+
val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
269+
val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
215270

216271
if (shouldWriteState) {
217-
val updatedStateObj = if (groupState.exists) groupState.get else null
218-
stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
272+
if (stateRowToWrite == null) {
273+
// This should never happen because checks in GroupStateImpl should avoid cases
274+
// where empty state would need to be written
275+
throw new IllegalStateException("Attempting to write empty state")
276+
}
277+
setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
278+
store.put(keyRow, stateRowToWrite)
219279
numUpdatedStateRows += 1
220280
}
221281
}
@@ -224,5 +284,28 @@ case class FlatMapGroupsWithStateExec(
224284
// Return an iterator of rows such that fully consumed, the updated state value will be saved
225285
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
226286
}
287+
288+
/** Returns the state as Java object if defined */
289+
def getStateObj(stateRow: UnsafeRow): Any = {
290+
if (stateRow != null) getStateObjFromRow(stateRow) else null
291+
}
292+
293+
/** Returns the row for an updated state */
294+
def getStateRow(obj: Any): UnsafeRow = {
295+
assert(obj != null)
296+
getStateRowFromObj(obj)
297+
}
298+
299+
/** Returns the timeout timestamp of a state row is set */
300+
def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
301+
if (isTimeoutEnabled && stateRow != null) {
302+
stateRow.getLong(timeoutTimestampIndex)
303+
} else NO_TIMESTAMP
304+
}
305+
306+
/** Set the timestamp in a state row */
307+
def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
308+
if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
309+
}
227310
}
228311
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)