Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 20cf511

Browse files
zsxwing authored and cloud-fan committed
[SPARK-21253][CORE] Fix a bug that StreamCallback may not be notified if network errors happen
## What changes were proposed in this pull request?

If a network error happens before processing StreamResponse/StreamFailure events, StreamCallback.onFailure won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <[email protected]>

Closes apache#18472 from zsxwing/fix-stream-2.

(cherry picked from commit 4996c53)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 17a04b9 commit 20cf511

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ public void stream(String streamId, StreamCallback callback) {
179179
// written to the socket atomically, so that callbacks are called in the right order
180180
// when responses arrive.
181181
synchronized (this) {
182-
handler.addStreamCallback(callback);
182+
handler.addStreamCallback(streamId, callback);
183183
channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
184184
if (future.isSuccess()) {
185185
long timeTaken = System.currentTimeMillis() - startTime;

common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.concurrent.ConcurrentLinkedQueue;
2525
import java.util.concurrent.atomic.AtomicLong;
2626

27+
import scala.Tuple2;
28+
2729
import com.google.common.annotations.VisibleForTesting;
2830
import io.netty.channel.Channel;
2931
import org.slf4j.Logger;
@@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
5658

5759
private final Map<Long, RpcResponseCallback> outstandingRpcs;
5860

59-
private final Queue<StreamCallback> streamCallbacks;
61+
private final Queue<Tuple2<String, StreamCallback>> streamCallbacks;
6062
private volatile boolean streamActive;
6163

6264
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
@@ -88,9 +90,9 @@ public void removeRpcRequest(long requestId) {
8890
outstandingRpcs.remove(requestId);
8991
}
9092

91-
public void addStreamCallback(StreamCallback callback) {
93+
public void addStreamCallback(String streamId, StreamCallback callback) {
9294
timeOfLastRequestNs.set(System.nanoTime());
93-
streamCallbacks.offer(callback);
95+
streamCallbacks.offer(Tuple2.apply(streamId, callback));
9496
}
9597

9698
@VisibleForTesting
@@ -104,15 +106,31 @@ public void deactivateStream() {
104106
*/
105107
private void failOutstandingRequests(Throwable cause) {
106108
for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
107-
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
109+
try {
110+
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
111+
} catch (Exception e) {
112+
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
113+
}
108114
}
109115
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
110-
entry.getValue().onFailure(cause);
116+
try {
117+
entry.getValue().onFailure(cause);
118+
} catch (Exception e) {
119+
logger.warn("RpcResponseCallback.onFailure throws exception", e);
120+
}
121+
}
122+
for (Tuple2<String, StreamCallback> entry : streamCallbacks) {
123+
try {
124+
entry._2().onFailure(entry._1(), cause);
125+
} catch (Exception e) {
126+
logger.warn("StreamCallback.onFailure throws exception", e);
127+
}
111128
}
112129

113130
// It's OK if new fetches appear, as they will fail immediately.
114131
outstandingFetches.clear();
115132
outstandingRpcs.clear();
133+
streamCallbacks.clear();
116134
}
117135

118136
@Override
@@ -190,8 +208,9 @@ public void handle(ResponseMessage message) throws Exception {
190208
}
191209
} else if (message instanceof StreamResponse) {
192210
StreamResponse resp = (StreamResponse) message;
193-
StreamCallback callback = streamCallbacks.poll();
194-
if (callback != null) {
211+
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
212+
if (entry != null) {
213+
StreamCallback callback = entry._2();
195214
if (resp.byteCount > 0) {
196215
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
197216
callback);
@@ -216,8 +235,9 @@ public void handle(ResponseMessage message) throws Exception {
216235
}
217236
} else if (message instanceof StreamFailure) {
218237
StreamFailure resp = (StreamFailure) message;
219-
StreamCallback callback = streamCallbacks.poll();
220-
if (callback != null) {
238+
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
239+
if (entry != null) {
240+
StreamCallback callback = entry._2();
221241
try {
222242
callback.onFailure(resp.streamId, new RuntimeException(resp.error));
223243
} catch (IOException ioe) {

common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.network;
1919

20+
import java.io.IOException;
2021
import java.nio.ByteBuffer;
2122

2223
import io.netty.channel.Channel;
@@ -127,17 +128,43 @@ public void testActiveStreams() throws Exception {
127128

128129
StreamResponse response = new StreamResponse("stream", 1234L, null);
129130
StreamCallback cb = mock(StreamCallback.class);
130-
handler.addStreamCallback(cb);
131+
handler.addStreamCallback("stream", cb);
131132
assertEquals(1, handler.numOutstandingRequests());
132133
handler.handle(response);
133134
assertEquals(1, handler.numOutstandingRequests());
134135
handler.deactivateStream();
135136
assertEquals(0, handler.numOutstandingRequests());
136137

137138
StreamFailure failure = new StreamFailure("stream", "uh-oh");
138-
handler.addStreamCallback(cb);
139+
handler.addStreamCallback("stream", cb);
139140
assertEquals(1, handler.numOutstandingRequests());
140141
handler.handle(failure);
141142
assertEquals(0, handler.numOutstandingRequests());
142143
}
144+
145+
@Test
146+
public void failOutstandingStreamCallbackOnClose() throws Exception {
147+
Channel c = new LocalChannel();
148+
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
149+
TransportResponseHandler handler = new TransportResponseHandler(c);
150+
151+
StreamCallback cb = mock(StreamCallback.class);
152+
handler.addStreamCallback("stream-1", cb);
153+
handler.channelInactive();
154+
155+
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
156+
}
157+
158+
@Test
159+
public void failOutstandingStreamCallbackOnException() throws Exception {
160+
Channel c = new LocalChannel();
161+
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
162+
TransportResponseHandler handler = new TransportResponseHandler(c);
163+
164+
StreamCallback cb = mock(StreamCallback.class);
165+
handler.addStreamCallback("stream-1", cb);
166+
handler.exceptionCaught(new IOException("Oops!"));
167+
168+
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
169+
}
143170
}

0 commit comments

Comments
 (0)