
Commit 1cc66a0

jose-torres authored and tdas committed
[SPARK-23687][SS] Add a memory source for continuous processing.
## What changes were proposed in this pull request?

Add a memory source for continuous processing.

Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers.

## How was this patch tested?

Unit test.

Author: Jose Torres <[email protected]>

Closes apache#20828 from jose-torres/continuousMemory.
1 parent 14844a6 commit 1cc66a0

File tree: 5 files changed, +266 −44 lines
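
As a rough illustration of what this commit enables (not part of the diff: it assumes a StreamTest-based suite that, like ContinuousSuiteBase, defaults to a continuous trigger and a v2-capable memory sink, with the usual implicit SQLContext and encoders from the shared test scaffolding in scope), a migrated ContinuousSuite-style test could look roughly like this:

import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream

test("basic continuous memory source") {
  val input = ContinuousMemoryStream[Int]

  testStream(input.toDF())(
    AddData(input, 0, 1, 2),
    CheckAnswer(0, 1, 2),
    AddData(input, 3, 4, 5),
    CheckAnswer(0, 1, 2, 3, 4, 5))
}

AddData works here because the StreamTest change below widens it to accept any MemoryStreamBase, not just the microbatch MemoryStream.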

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala

Lines changed: 4 additions & 1 deletion
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
 import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
+import org.apache.spark.sql.sources.v2
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
@@ -317,8 +318,10 @@ class ContinuousExecution(
     synchronized {
       if (queryExecutionThread.isAlive) {
         commitLog.add(epoch)
-        val offset = offsetLog.get(epoch).get.offsets(0).get
+        val offset =
+          continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
         committedOffsets ++= Seq(continuousSources(0) -> offset)
+        continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
       } else {
         return
       }
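
The change above is needed because the offset log hands offsets back in serialized form, while the source's commit() expects its own typed offset; deserializeOffset bridges the two before the source is told it may discard committed data. A standalone sketch of that round trip (not part of the diff; it only assumes json4s on the classpath and mirrors the Map[Int, Int] encoding of ContinuousMemoryStreamOffset added later in this commit):

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

object OffsetRoundTripSketch {
  private implicit val formats = Serialization.formats(NoTypeHints)

  def main(args: Array[String]): Unit = {
    val committed = Map(0 -> 3, 1 -> 2)                     // records processed per partition
    val json = Serialization.write(committed)               // what Offset.json() emits
    val restored = Serialization.read[Map[Int, Int]](json)  // what deserializeOffset does
    assert(restored == committed, "offset must survive the log round trip")
  }
}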

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 37 additions & 22 deletions
@@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -47,16 +49,43 @@ object MemoryStream {
     new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
 }
 
+/**
+ * A base class for memory stream implementations. Supports adding data and resetting.
+ */
+abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
+  protected val encoder = encoderFor[A]
+  protected val attributes = encoder.schema.toAttributes
+
+  def toDS(): Dataset[A] = {
+    Dataset[A](sqlContext.sparkSession, logicalPlan)
+  }
+
+  def toDF(): DataFrame = {
+    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
+  }
+
+  def addData(data: A*): Offset = {
+    addData(data.toTraversable)
+  }
+
+  def readSchema(): StructType = encoder.schema
+
+  protected def logicalPlan: LogicalPlan
+
+  def addData(data: TraversableOnce[A]): Offset
+}
+
 /**
  * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]]
  * is intended for use in unit tests as it can only replay data when the object is still
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
-  protected val encoder = encoderFor[A]
-  private val attributes = encoder.schema.toAttributes
-  protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
+    extends MemoryStreamBase[A](sqlContext)
+      with MicroBatchReader with SupportsScanUnsafeRow with Logging {
+
+  protected val logicalPlan: LogicalPlan =
+    StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
   protected val output = logicalPlan.output
 
   /**
@@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   @GuardedBy("this")
-  private var startOffset = new LongOffset(-1)
+  protected var startOffset = new LongOffset(-1)
 
   @GuardedBy("this")
   private var endOffset = new LongOffset(-1)
@@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   @GuardedBy("this")
   protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
 
-  def toDS(): Dataset[A] = {
-    Dataset(sqlContext.sparkSession, logicalPlan)
-  }
-
-  def toDF(): DataFrame = {
-    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
-  }
-
-  def addData(data: A*): Offset = {
-    addData(data.toTraversable)
-  }
-
   def addData(data: TraversableOnce[A]): Offset = {
     val objects = data.toSeq
     val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
@@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     }
   }
 
-  override def readSchema(): StructType = encoder.schema
-
   override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
 
   override def getStartOffset: OffsetV2 = synchronized {
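
A minimal sketch of why the new base class matters (not part of the diff; it assumes a local SparkSession and uses ContinuousMemoryStream from the new file below): helpers can now be written once against MemoryStreamBase and work for both the microbatch and the continuous memory source, which is exactly what the StreamTest change further down relies on.

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamBase}
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream

object MemoryStreamBaseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._
    implicit val sqlContext: SQLContext = spark.sqlContext

    // Works for any MemoryStreamBase: the schema comes from the encoder, data goes in via addData.
    def seed(stream: MemoryStreamBase[Int]): String = {
      val offset = stream.addData(1, 2, 3)
      s"${stream.getClass.getSimpleName}: schema=${stream.readSchema().simpleString}, offset=$offset"
    }

    println(seed(MemoryStream[Int]))
    println(seed(ContinuousMemoryStream[Int]))
    spark.stop()
  }
}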

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.{util => ju}
+import java.util.Optional
+import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.{Encoder, Row, SQLContext}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.RpcUtils
+
+/**
+ * The overall strategy here is:
+ *  * ContinuousMemoryStream maintains a list of records for each partition. addData() will
+ *    distribute records evenly-ish across partitions.
+ *  * RecordEndpoint is set up as an endpoint for executor-side
+ *    ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
+ *    offset within the list, or null if that offset doesn't yet have a record.
+ */
+class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
+  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  private val NUM_PARTITIONS = 2
+
+  protected val logicalPlan =
+    StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
+
+  // ContinuousReader implementation
+
+  @GuardedBy("this")
+  private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
+
+  @GuardedBy("this")
+  private var startOffset: ContinuousMemoryStreamOffset = _
+
+  private val recordEndpoint = new RecordEndpoint()
+  @volatile private var endpointRef: RpcEndpointRef = _
+
+  def addData(data: TraversableOnce[A]): Offset = synchronized {
+    // Distribute data evenly among partition lists.
+    data.toSeq.zipWithIndex.map {
+      case (item, index) => records(index % NUM_PARTITIONS) += item
+    }
+
+    // The new target offset is the offset where all records in all partitions have been processed.
+    ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
+  }
+
+  override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
+    // Inferred initial offset is position 0 in each partition.
+    startOffset = start.orElse {
+      ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap)
+    }.asInstanceOf[ContinuousMemoryStreamOffset]
+  }
+
+  override def getStartOffset: Offset = synchronized {
+    startOffset
+  }
+
+  override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
+    ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
+  }
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = {
+    ContinuousMemoryStreamOffset(
+      offsets.map {
+        case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num)
+      }.toMap
+    )
+  }
+
+  override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = {
+    synchronized {
+      val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
+      endpointRef =
+        recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
+
+      startOffset.partitionNums.map {
+        case (part, index) =>
+          new ContinuousMemoryStreamDataReaderFactory(
+            endpointName, part, index): DataReaderFactory[Row]
+      }.toList.asJava
+    }
+  }
+
+  override def stop(): Unit = {
+    if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
+  }
+
+  override def commit(end: Offset): Unit = {}
+
+  // ContinuousReadSupport implementation
+  // This is necessary because of how StreamTest finds the source for AddDataMemory steps.
+  def createContinuousReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): ContinuousReader = {
+    this
+  }
+
+  /**
+   * Endpoint for executors to poll for records.
+   */
+  private class RecordEndpoint extends ThreadSafeRpcEndpoint {
+    override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+      case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
+        ContinuousMemoryStream.this.synchronized {
+          val buf = records(part)
+          val record = if (buf.size <= index) None else Some(buf(index))
+
+          context.reply(record.map(Row(_)))
+        }
+    }
+  }
+}
+
+object ContinuousMemoryStream {
+  case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
+  protected val memoryStreamId = new AtomicInteger(0)
+
+  def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+}
+
+/**
+ * Data reader factory for continuous memory stream.
+ */
+class ContinuousMemoryStreamDataReaderFactory(
+    driverEndpointName: String,
+    partition: Int,
+    startOffset: Int) extends DataReaderFactory[Row] {
+  override def createDataReader: ContinuousMemoryStreamDataReader =
+    new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset)
+}
+
+/**
+ * Data reader for continuous memory stream.
+ *
+ * Polls the driver endpoint for new records.
+ */
+class ContinuousMemoryStreamDataReader(
+    driverEndpointName: String,
+    partition: Int,
+    startOffset: Int) extends ContinuousDataReader[Row] {
+  private val endpoint = RpcUtils.makeDriverRef(
+    driverEndpointName,
+    SparkEnv.get.conf,
+    SparkEnv.get.rpcEnv)
+
+  private var currentOffset = startOffset
+  private var current: Option[Row] = None
+
+  override def next(): Boolean = {
+    current = None
+    while (current.isEmpty) {
+      Thread.sleep(10)
+      current = endpoint.askSync[Option[Row]](
+        GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset)))
+    }
+    currentOffset += 1
+    true
+  }
+
+  override def get(): Row = current.get
+
+  override def close(): Unit = {}
+
+  override def getOffset: ContinuousMemoryStreamPartitionOffset =
+    ContinuousMemoryStreamPartitionOffset(partition, currentOffset)
+}
+
+case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int])
+  extends Offset {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  override def json(): String = Serialization.write(partitionNums)
+}
+
+case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int)
+  extends PartitionOffset
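
A small usage sketch of the bookkeeping described in the class comment above (not part of the diff; it assumes a local SparkSession so that SparkEnv and the driver RPC environment exist): addData spreads records round-robin over the two partition buffers, and the returned offset records how many items each partition will have processed once it catches up.

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream

object ContinuousMemoryStreamOffsetSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._
    implicit val sqlContext: SQLContext = spark.sqlContext

    val stream = ContinuousMemoryStream[String]
    val offset = stream.addData(Seq("a", "b", "c"))

    // "a" and "c" land in partition 0, "b" in partition 1, so the target offset maps
    // partition 0 -> 2 records and partition 1 -> 1 record, e.g. {"0":2,"1":1} as JSON.
    println(offset.json())
    spark.stop()
  }
}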

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
    * been processed.
    */
   object AddData {
-    def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] =
+    def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] =
       AddDataMemory(source, data)
   }
 
@@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     def runAction(): Unit
   }
 
-  case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
+  case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData {
     override def toString: String = s"AddData to $source: ${data.mkString(",")}"
 
     override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
