|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.streaming
|
19 | 19 |
|
20 |
| -import java.io.{File, InterruptedIOException, IOException} |
21 |
| -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} |
| 20 | +import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} |
| 21 | +import java.nio.channels.ClosedByInterruptException |
| 22 | +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} |
22 | 23 |
|
23 | 24 | import scala.reflect.ClassTag
|
24 | 25 | import scala.util.control.ControlThrowable
|
25 | 26 |
|
| 27 | +import com.google.common.util.concurrent.UncheckedExecutionException |
26 | 28 | import org.apache.commons.io.FileUtils
|
27 | 29 | import org.apache.hadoop.conf.Configuration
|
28 | 30 |
|
@@ -691,6 +693,31 @@ class StreamSuite extends StreamTest {
|
691 | 693 | }
|
692 | 694 | }
|
693 | 695 | }
|
| 696 | + |
| 697 | + for (e <- Seq( |
| 698 | + new InterruptedException, |
| 699 | + new InterruptedIOException, |
| 700 | + new ClosedByInterruptException, |
| 701 | + new UncheckedIOException("test", new ClosedByInterruptException), |
| 702 | + new ExecutionException("test", new InterruptedException), |
| 703 | + new UncheckedExecutionException("test", new InterruptedException))) { |
| 704 | + test(s"view ${e.getClass.getSimpleName} as a normal query stop") { |
| 705 | + ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1) |
| 706 | + ThrowingExceptionInCreateSource.exception = e |
| 707 | + val query = spark |
| 708 | + .readStream |
| 709 | + .format(classOf[ThrowingExceptionInCreateSource].getName) |
| 710 | + .load() |
| 711 | + .writeStream |
| 712 | + .format("console") |
| 713 | + .start() |
| 714 | + assert(ThrowingExceptionInCreateSource.createSourceLatch |
| 715 | + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), |
| 716 | + "ThrowingExceptionInCreateSource.createSource wasn't called before timeout") |
| 717 | + query.stop() |
| 718 | + assert(query.exception.isEmpty) |
| 719 | + } |
| 720 | + } |
694 | 721 | }
|
695 | 722 |
|
696 | 723 | abstract class FakeSource extends StreamSourceProvider {
|
@@ -824,3 +851,32 @@ class TestStateStoreProvider extends StateStoreProvider {
|
824 | 851 |
|
825 | 852 | override def getStore(version: Long): StateStore = null
|
826 | 853 | }
|
| 854 | + |
/**
 * A fake streaming source whose `createSource` blocks until its thread is interrupted and
 * then throws the configured `ThrowingExceptionInCreateSource.exception`.
 */
class ThrowingExceptionInCreateSource extends FakeSource {

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    // Let the test know we have reached createSource before it calls query.stop().
    ThrowingExceptionInCreateSource.createSourceLatch.countDown()
    try {
      // Block here; query.stop() interrupts this thread, which aborts the sleep.
      Thread.sleep(30000)
    } catch {
      case _: InterruptedException =>
        // Surface the exception the test configured for this run.
        throw ThrowingExceptionInCreateSource.exception
    }
    // Only reachable if the sleep completed without interruption, i.e. stop() never arrived.
    throw new TimeoutException("sleep was not interrupted in 30 seconds")
  }
}
| 874 | + |
/** Shared mutable state used to coordinate tests with [[ThrowingExceptionInCreateSource]]. */
object ThrowingExceptionInCreateSource {
  /**
   * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is
   * called.
   */
  @volatile var createSourceLatch: CountDownLatch = null
  /** The exception `createSource` throws after its thread is interrupted. Set per test run. */
  @volatile var exception: Exception = null
}
0 commit comments