
Commit 2cb9763

tdas authored and zsxwing committed
[SPARK-24565][SS] Add API in Structured Streaming for exposing output rows of each microbatch as a DataFrame
## What changes were proposed in this pull request?

Currently, the micro-batches in the MicroBatchExecution are not exposed to the user through any public API. This was because we did not want to expose the micro-batches, so that every API we expose can eventually be supported by the Continuous engine as well. But now that we have a better sense of building a ContinuousExecution, I am considering adding APIs which will run only with the MicroBatchExecution. I have quite a few use cases where exposing the micro-batch output as a DataFrame is useful:

- Pass the output rows of each batch to a library that is designed only for batch jobs (for example, many ML libraries need to collect() while learning).
- Reuse batch data sources for output whose streaming version does not exist (e.g. the Redshift data source).
- Write the output rows to multiple places by writing twice for each batch. This is not the most elegant way to handle multiple-output streaming queries, but it is likely better than running two streaming queries that process the same data twice.

The proposal is to add a method `foreachBatch(f: Dataset[T] => Unit)` to the Scala/Java/Python `DataStreamWriter` (see the usage sketch right after this message).

## How was this patch tested?

New unit tests.

Author: Tathagata Das <[email protected]>

Closes apache#21571 from tdas/foreachBatch.
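For orientation, a minimal Scala sketch of how the proposed API could be used for the multi-output use case above. The rate source, the output paths, the checkpoint location, and the object name are illustrative assumptions, not part of this patch:

import org.apache.spark.sql.{DataFrame, SparkSession}

object ForeachBatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("foreachBatch-sketch").getOrCreate()

    // Any streaming source works; the built-in rate source is used here for brevity.
    val streamingDF: DataFrame = spark.readStream.format("rate").load()

    // Write each micro-batch to two places with ordinary batch writers.
    val writeTwice: (DataFrame, Long) => Unit = (batchDF: DataFrame, batchId: Long) => {
      batchDF.persist()
      batchDF.write.mode("append").parquet("/tmp/foreachBatch-sketch/parquet")
      batchDF.write.mode("append").json("/tmp/foreachBatch-sketch/json")
      batchDF.unpersist()
    }

    val query = streamingDF.writeStream
      .foreachBatch(writeTwice)
      .option("checkpointLocation", "/tmp/foreachBatch-sketch/checkpoint")
      .start()

    query.awaitTermination()
  }
}

The persist/unpersist pair avoids recomputing the micro-batch once per write; the batch identifier is available for bookkeeping even though this sketch does not use it.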
1 parent 13092d7 commit 2cb9763

File tree: 8 files changed (+383, -21 lines)

python/pyspark/java_gateway.py

Lines changed: 24 additions & 1 deletion
@@ -31,7 +31,7 @@
 if sys.version >= '3':
     xrange = range

-from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
+from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.serializers import read_int, write_with_length, UTF8Deserializer

@@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def ensure_callback_server_started(gw):
+    """
+    Start callback server if not already started. The callback server is needed if the Java
+    driver process needs to callback into the Python driver process to execute Python code.
+    """
+
+    # getattr will fallback to JVM, so we cannot test by hasattr()
+    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
+        gw.callback_server_parameters.eager_load = True
+        gw.callback_server_parameters.daemonize = True
+        gw.callback_server_parameters.daemonize_connections = True
+        gw.callback_server_parameters.port = 0
+        gw.start_callback_server(gw.callback_server_parameters)
+        cbport = gw._callback_server.server_socket.getsockname()[1]
+        gw._callback_server.port = cbport
+        # gateway with real port
+        gw._python_proxy_port = gw._callback_server.port
+        # get the GatewayServer object in JVM by ID
+        jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+        # update the port of CallbackClient with real port
+        jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)

python/pyspark/sql/streaming.py

Lines changed: 32 additions & 1 deletion
@@ -24,12 +24,14 @@
 else:
     intlike = (int, long)

+from py4j.java_gateway import java_import
+
 from pyspark import since, keyword_only
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import _to_seq
 from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.types import *
-from pyspark.sql.utils import StreamingQueryException
+from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException

 __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]

@@ -1016,6 +1018,35 @@ def func_with_open_process_close(partition_id, iterator):
         self._jwrite.foreach(jForeachWriter)
         return self

+    @since(2.4)
+    def foreachBatch(self, func):
+        """
+        Sets the output of the streaming query to be processed using the provided
+        function. This is supported only in the micro-batch execution modes (that is,
+        when the trigger is not continuous). The provided function will be called in
+        every micro-batch with (i) the output rows as a DataFrame and (ii) the batch
+        identifier. The batchId can be used to deduplicate and transactionally write
+        the output (that is, the provided DataFrame) to external systems. The output
+        DataFrame is guaranteed to be exactly the same for the same batchId (assuming
+        all operations are deterministic in the query).
+
+        .. note:: Evolving.
+
+        >>> def func(batch_df, batch_id):
+        ...     batch_df.collect()
+        ...
+        >>> writer = sdf.writeStream.foreachBatch(func)
+        """
+
+        from pyspark.java_gateway import ensure_callback_server_started
+        gw = self._spark._sc._gateway
+        java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")
+
+        wrapped_func = ForeachBatchFunction(self._spark, func)
+        gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func)
+        ensure_callback_server_started(gw)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,

python/pyspark/sql/tests.py

Lines changed: 36 additions & 0 deletions
@@ -2126,6 +2126,42 @@ class WriterWithNonCallableClose(WithProcess):
         tester.assert_invalid_writer(WriterWithNonCallableClose(),
                                      "'close' in provided object is not callable")

+    def test_streaming_foreachBatch(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            collected[batch_id] = batch_df.collect()
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
+    def test_streaming_foreachBatch_propagates_python_errors(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        q = None
+
+        def collectBatch(df, id):
+            raise Exception("this should fail the query")
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.fail("Expected a failure")
+        except StreamingQueryException as e:
+            self.assertTrue("this should fail" in str(e))
+        finally:
+            if q:
+                q.stop()
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])

python/pyspark/sql/utils.py

Lines changed: 23 additions & 0 deletions
@@ -150,3 +150,26 @@ def require_minimum_pyarrow_version():
     if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
         raise ImportError("PyArrow >= %s must be installed; however, "
                           "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
+
+
+class ForeachBatchFunction(object):
+    """
+    This is the Python implementation of the Java interface 'PythonForeachBatchFunction'.
+    It wraps the user-defined 'foreachBatch' function such that it can be called from the
+    JVM when the query is active.
+    """
+
+    def __init__(self, sql_ctx, func):
+        self.sql_ctx = sql_ctx
+        self.func = func
+
+    def call(self, jdf, batch_id):
+        from pyspark.sql.dataframe import DataFrame
+        try:
+            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
+        except Exception as e:
+            self.error = e
+            raise e
+
+    class Java:
+        implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']

python/pyspark/streaming/context.py

Lines changed: 2 additions & 16 deletions
@@ -79,22 +79,8 @@ def _ensure_initialized(cls):
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

-        # start callback server
-        # getattr will fallback to JVM, so we cannot test by hasattr()
-        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
-            gw.callback_server_parameters.eager_load = True
-            gw.callback_server_parameters.daemonize = True
-            gw.callback_server_parameters.daemonize_connections = True
-            gw.callback_server_parameters.port = 0
-            gw.start_callback_server(gw.callback_server_parameters)
-            cbport = gw._callback_server.server_socket.getsockname()[1]
-            gw._callback_server.port = cbport
-            # gateway with real port
-            gw._python_proxy_port = gw._callback_server.port
-            # get the GatewayServer object in JVM by ID
-            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
-            # update the port of CallbackClient with real port
-            jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
+        from pyspark.java_gateway import ensure_callback_server_started
+        ensure_callback_server_started(gw)

         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.api.python.PythonException
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.streaming.DataStreamWriter
+
+class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T])
+  extends Sink {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    val resolvedEncoder = encoder.resolveAndBind(
+      data.logicalPlan.output,
+      data.sparkSession.sessionState.analyzer)
+    val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
+    val ds = data.sparkSession.createDataset(rdd)(encoder)
+    batchWriter(ds, batchId)
+  }
+
+  override def toString(): String = "ForeachBatchSink"
+}
+
+
+/**
+ * Interface that is meant to be extended by Python classes via Py4J.
+ * Py4J allows Python classes to implement Java interfaces so that the JVM can call back
+ * Python objects. In this case, this allows the user-defined Python `foreachBatch` function
+ * to be called from JVM when the query is active.
+ */
+trait PythonForeachBatchFunction {
+  /** Call the Python implementation of this function */
+  def call(batchDF: DataFrame, batchId: Long): Unit
+}
+
+object PythonForeachBatchHelper {
+  def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = {
+    dsw.foreachBatch(pythonFunc.call _)
+  }
+}
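For illustration only, a minimal Scala sketch of how this Py4J plumbing fits together: the JVM side sees nothing but a `PythonForeachBatchFunction`, so any object implementing `call` can be wired into a `DataStreamWriter` via `PythonForeachBatchHelper`. In the real flow that object is a Py4J proxy for the Python `ForeachBatchFunction` wrapper; the object and helper names below are made up for this sketch:

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.streaming.sources.{PythonForeachBatchFunction, PythonForeachBatchHelper}
import org.apache.spark.sql.streaming.DataStreamWriter

object CallbackWiringSketch {
  // Hypothetical helper: attaches a JVM-side callback through the same entry point the
  // Python wrapper uses. In the real flow, call() is implemented by a Py4J proxy that
  // runs the user's Python function in the Python driver process.
  def attachCountingCallback(writer: DataStreamWriter[Row]): Unit = {
    val callback = new PythonForeachBatchFunction {
      override def call(batchDF: DataFrame, batchId: Long): Unit = {
        println(s"batch $batchId contained ${batchDF.count()} rows")
      }
    }
    PythonForeachBatchHelper.callForeachBatch(writer, callback)
  }
}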

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala

Lines changed: 60 additions & 3 deletions
@@ -21,14 +21,15 @@ import java.util.Locale

 import scala.collection.JavaConverters._

-import org.apache.spark.annotation.InterfaceStability
-import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.annotation.{InterfaceStability, Since}
+import org.apache.spark.api.java.function.VoidFunction2
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}

 /**
@@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         outputMode,
         useTempCheckpointLocation = true,
         trigger = trigger)
+    } else if (source == "foreachBatch") {
+      assertNotPartitioned("foreachBatch")
+      if (trigger.isInstanceOf[ContinuousTrigger]) {
+        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+      }
+      val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
+      df.sparkSession.sessionState.streamingQueryManager.startQuery(
+        extraOptions.get("queryName"),
+        extraOptions.get("checkpointLocation"),
+        df,
+        extraOptions.toMap,
+        sink,
+        outputMode,
+        useTempCheckpointLocation = true,
+        trigger = trigger)
     } else {
       val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -322,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
     this
   }

+  /**
+   * :: Experimental ::
+   *
+   * (Scala-specific) Sets the output of the streaming query to be processed using the
+   * provided function. This is supported only in the micro-batch execution modes (that is,
+   * when the trigger is not continuous). The provided function will be called in every
+   * micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier.
+   * The batchId can be used to deduplicate and transactionally write the output
+   * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
+   * to be exactly the same for the same batchId (assuming all operations are deterministic
+   * in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
+    this.source = "foreachBatch"
+    if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
+    this.foreachBatchWriter = function
+    this
+  }
+
+  /**
+   * :: Experimental ::
+   *
+   * (Java-specific) Sets the output of the streaming query to be processed using the
+   * provided function. This is supported only in the micro-batch execution modes (that is,
+   * when the trigger is not continuous). The provided function will be called in every
+   * micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier.
+   * The batchId can be used to deduplicate and transactionally write the output
+   * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
+   * to be exactly the same for the same batchId (assuming all operations are deterministic
+   * in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = {
+    foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
+  }
+
   private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
     cols.map(normalize(_, "Partition"))
   }
@@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

   private var foreachWriter: ForeachWriter[T] = null

+  private var foreachBatchWriter: (Dataset[T], Long) => Unit = null
+
   private var partitioningColumns: Option[Seq[String]] = None
 }
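As a follow-up to the batchId contract described in the scaladoc above, here is a minimal sketch (not part of the patch) of using the batch identifier to make writes idempotent; the object name, `alreadyCommitted`, and the paths are illustrative assumptions standing in for a real transactional commit check:

import org.apache.spark.sql.{DataFrame, Dataset, Row}

object IdempotentForeachBatchSketch {
  // Illustrative stand-in for a real check, e.g. a lookup in a commit log that is
  // updated in the same transaction as the data itself.
  def alreadyCommitted(batchId: Long): Boolean = false

  def start(streamingDF: DataFrame): Unit = {
    val writeOnce: (Dataset[Row], Long) => Unit = (batchDF: Dataset[Row], batchId: Long) => {
      if (!alreadyCommitted(batchId)) {
        // Because the same batchId always carries the same rows, replays after a
        // failure can be skipped or overwritten safely.
        batchDF.write.mode("append").parquet("/tmp/idempotent-sketch/output")
      }
    }

    streamingDF.writeStream
      .foreachBatch(writeOnce)
      .option("checkpointLocation", "/tmp/idempotent-sketch/checkpoint")
      .start()
  }
}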
