Skip to content

Commit a2c1977

Browse files
committed
Implement the new state format for stream-stream join
1 parent 62264ab commit a2c1977

File tree

6 files changed

+1372
-382
lines changed

6 files changed

+1372
-382
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ case class StreamingSymmetricHashJoinExec(
433433
}
434434

435435
val initIterFn = { () =>
436-
val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
436+
val removedRowIter = joinerManager.leftSideJoiner.removeAndReturnOldState()
437437
removedRowIter.filterNot { kv =>
438438
stateFormatVersion match {
439439
case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
@@ -459,7 +459,7 @@ case class StreamingSymmetricHashJoinExec(
459459
}
460460

461461
val initIterFn = { () =>
462-
val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
462+
val removedRowIter = joinerManager.rightSideJoiner.removeAndReturnOldState()
463463
removedRowIter.filterNot { kv =>
464464
stateFormatVersion match {
465465
case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
@@ -484,13 +484,13 @@ case class StreamingSymmetricHashJoinExec(
484484
}
485485

486486
val leftSideInitIterFn = { () =>
487-
val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
487+
val removedRowIter = joinerManager.leftSideJoiner.removeAndReturnOldState()
488488
removedRowIter.filterNot(isKeyToValuePairMatched)
489489
.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
490490
}
491491

492492
val rightSideInitIterFn = { () =>
493-
val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
493+
val removedRowIter = joinerManager.rightSideJoiner.removeAndReturnOldState()
494494
removedRowIter.filterNot(isKeyToValuePairMatched)
495495
.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
496496
}
@@ -539,22 +539,19 @@ case class StreamingSymmetricHashJoinExec(
539539
// the outer side (e.g., left side for left outer join) while generating the outer "null"
540540
// outputs. Now, we have to remove unnecessary state rows from the other side (e.g., right
541541
// side for the left outer join) if possible. In all cases, nothing needs to be outputted,
542-
// hence the removal needs to be done greedily by immediately consuming the returned
543-
// iterator.
542+
// hence the removal needs to be done greedily.
544543
//
545544
// For full outer joins, we have already removed unnecessary states from both sides, so
546545
// nothing needs to be outputted here.
547-
val cleanupIter = joinType match {
548-
case Inner | LeftSemi => joinerManager.removeOldState()
549-
case LeftOuter => joinerManager.rightSideJoiner.removeOldState()
550-
case RightOuter => joinerManager.leftSideJoiner.removeOldState()
551-
case FullOuter => Iterator.empty
552-
case _ => throwBadJoinTypeException()
553-
}
554-
while (cleanupIter.hasNext) {
555-
cleanupIter.next()
556-
numRemovedStateRows += 1
557-
}
546+
numRemovedStateRows += (
547+
joinType match {
548+
case Inner | LeftSemi => joinerManager.removeOldState()
549+
case LeftOuter => joinerManager.rightSideJoiner.removeOldState()
550+
case RightOuter => joinerManager.leftSideJoiner.removeOldState()
551+
case FullOuter => 0L
552+
case _ => throwBadJoinTypeException()
553+
}
554+
)
558555
}
559556

560557
// Commit all state changes and update state store metrics
@@ -643,7 +640,7 @@ case class StreamingSymmetricHashJoinExec(
643640
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
644641

645642
private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
646-
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
643+
case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
647644
// inputSchema can be empty as expr should only have BoundReferences and does not require
648645
// the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
649646
Predicate.create(expr, Seq.empty).eval _
@@ -652,7 +649,7 @@ case class StreamingSymmetricHashJoinExec(
652649
}
653650

654651
private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
655-
case Some(JoinStateValueWatermarkPredicate(expr)) =>
652+
case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
656653
Predicate.create(expr, inputAttributes).eval _
657654
case _ =>
658655
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
@@ -792,6 +789,29 @@ case class StreamingSymmetricHashJoinExec(
792789
joinStateManager.get(key)
793790
}
794791

792+
// FIXME: doc!
793+
def removeOldState(): Long = {
794+
stateWatermarkPredicate match {
795+
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
796+
joinStateManager match {
797+
case s: SupportsEvictByCondition =>
798+
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)
799+
800+
case s: SupportsEvictByTimestamp =>
801+
s.evictByTimestamp(stateWatermark)
802+
}
803+
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
804+
joinStateManager match {
805+
case s: SupportsEvictByCondition =>
806+
s.evictByValueCondition(stateValueWatermarkPredicateFunc)
807+
808+
case s: SupportsEvictByTimestamp =>
809+
s.evictByTimestamp(stateWatermark)
810+
}
811+
case _ => 0L
812+
}
813+
}
814+
795815
/**
796816
* Builds an iterator over old state key-value pairs, removing them lazily as they're produced.
797817
*
@@ -802,12 +822,24 @@ case class StreamingSymmetricHashJoinExec(
802822
* We do this to avoid requiring either two passes or full materialization when
803823
* processing the rows for outer join.
804824
*/
805-
def removeOldState(): Iterator[KeyToValuePair] = {
825+
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
806826
stateWatermarkPredicate match {
807-
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
808-
joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
809-
case Some(JoinStateValueWatermarkPredicate(expr)) =>
810-
joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc)
827+
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
828+
joinStateManager match {
829+
case s: SupportsEvictByCondition =>
830+
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)
831+
832+
case s: SupportsEvictByTimestamp =>
833+
s.evictAndReturnByTimestamp(stateWatermark)
834+
}
835+
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
836+
joinStateManager match {
837+
case s: SupportsEvictByCondition =>
838+
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)
839+
840+
case s: SupportsEvictByTimestamp =>
841+
s.evictAndReturnByTimestamp(stateWatermark)
842+
}
811843
case _ => Iterator.empty
812844
}
813845
}
@@ -836,8 +868,12 @@ case class StreamingSymmetricHashJoinExec(
836868
private case class OneSideHashJoinerManager(
837869
leftSideJoiner: OneSideHashJoiner, rightSideJoiner: OneSideHashJoiner) {
838870

839-
def removeOldState(): Iterator[KeyToValuePair] = {
840-
leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
871+
def removeOldState(): Long = {
872+
leftSideJoiner.removeOldState() + rightSideJoiner.removeOldState()
873+
}
874+
875+
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
876+
leftSideJoiner.removeAndReturnOldState() ++ rightSideJoiner.removeAndReturnOldState()
841877
}
842878

843879
def metrics: StateStoreMetrics = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ object StreamingSymmetricHashJoinHelper extends Logging {
4646
override def toString: String = s"$desc: $expr"
4747
}
4848
/** Predicate for watermark on state keys */
49-
case class JoinStateKeyWatermarkPredicate(expr: Expression)
49+
case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
5050
extends JoinStateWatermarkPredicate {
5151
def desc: String = "key predicate"
5252
}
5353
/** Predicate for watermark on state values */
54-
case class JoinStateValueWatermarkPredicate(expr: Expression)
54+
case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
5555
extends JoinStateWatermarkPredicate {
5656
def desc: String = "value predicate"
5757
}
@@ -212,8 +212,11 @@ object StreamingSymmetricHashJoinHelper extends Logging {
212212
oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
213213
oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
214214
val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermarkForEviction)
215-
expr.map(JoinStateKeyWatermarkPredicate.apply _)
216-
215+
expr.map { e =>
216+
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
217+
// is defined
218+
JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
219+
}
217220
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
218221
val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
219222
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
@@ -222,8 +225,11 @@ object StreamingSymmetricHashJoinHelper extends Logging {
222225
eventTimeWatermarkForEviction)
223226
val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
224227
val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
225-
expr.map(JoinStateValueWatermarkPredicate.apply _)
226-
228+
expr.map { e =>
229+
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
230+
// is defined
231+
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
232+
}
227233
} else {
228234
None
229235
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.streaming.operators.stateful.join
19+
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, UnsafeProjection, UnsafeRow}
21+
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.ValueAndMatchPair
22+
import org.apache.spark.sql.types.BooleanType
23+
24+
/**
25+
* Converter between the value row stored in state store and the (actual value, match) pair.
26+
*/
27+
trait StreamingSymmetricHashJoinValueRowConverter {
28+
/** Defines the schema of the value row (the value side of K-V in state store). */
29+
def valueAttributes: Seq[Attribute]
30+
31+
/**
32+
* Convert the value row to (actual value, match) pair.
33+
*
34+
* NOTE: implementations should ensure the result row is NOT reused during execution, so
35+
* that caller can safely read the value in any time.
36+
*/
37+
def convertValue(value: UnsafeRow): ValueAndMatchPair
38+
39+
/**
40+
* Build the value row from (actual value, match) pair. This is expected to be called just
41+
* before storing to the state store.
42+
*
43+
* NOTE: depending on the implementation, the result row "may" be reused during execution
44+
* (to avoid initialization of object), so the caller should ensure that the logic doesn't
45+
* affect by such behavior. Call copy() against the result row if needed.
46+
*/
47+
def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
48+
}
49+
50+
class StreamingSymmetricHashJoinValueRowConverterFormatV1(
51+
inputValueAttributes: Seq[Attribute]) extends StreamingSymmetricHashJoinValueRowConverter {
52+
override val valueAttributes: Seq[Attribute] = inputValueAttributes
53+
54+
override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
55+
if (value != null) ValueAndMatchPair(value, false) else null
56+
}
57+
58+
override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = value
59+
}
60+
61+
class StreamingSymmetricHashJoinValueRowConverterFormatV2(
62+
inputValueAttributes: Seq[Attribute]) extends StreamingSymmetricHashJoinValueRowConverter {
63+
private val valueWithMatchedExprs = inputValueAttributes :+ Literal(true)
64+
private val indexOrdinalInValueWithMatchedRow = inputValueAttributes.size
65+
66+
private val valueWithMatchedRowGenerator = UnsafeProjection.create(valueWithMatchedExprs,
67+
inputValueAttributes)
68+
69+
override val valueAttributes: Seq[Attribute] = inputValueAttributes :+
70+
AttributeReference("matched", BooleanType)()
71+
72+
// Projection to generate key row from (value + matched) row
73+
private val valueRowGenerator = UnsafeProjection.create(
74+
inputValueAttributes, valueAttributes)
75+
76+
override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
77+
if (value != null) {
78+
ValueAndMatchPair(valueRowGenerator(value).copy(),
79+
value.getBoolean(indexOrdinalInValueWithMatchedRow))
80+
} else {
81+
null
82+
}
83+
}
84+
85+
override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = {
86+
val row = valueWithMatchedRowGenerator(value)
87+
row.setBoolean(indexOrdinalInValueWithMatchedRow, matched)
88+
row
89+
}
90+
}
91+
92+
object StreamingSymmetricHashJoinValueRowConverter {
93+
def create(
94+
inputValueAttributes: Seq[Attribute],
95+
stateFormatVersion: Int): StreamingSymmetricHashJoinValueRowConverter = {
96+
stateFormatVersion match {
97+
case 1 => new StreamingSymmetricHashJoinValueRowConverterFormatV1(inputValueAttributes)
98+
case 2 | 3 | 4 =>
99+
new StreamingSymmetricHashJoinValueRowConverterFormatV2(inputValueAttributes)
100+
case _ => throw new IllegalArgumentException ("Incorrect state format version! " +
101+
s"version $stateFormatVersion")
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)