Skip to content

Commit 93ac2a0

Browse files
authored
Merge pull request #17295: [cherry-pick][release-2.38.0][BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)
[cherry-pick][release-2.38.0][BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)
2 parents 47a9d8f + 4954909 commit 93ac2a0

File tree

2 files changed

+117
-69
lines changed

2 files changed

+117
-69
lines changed

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
package org.apache.beam.fn.harness.state;
1919

2020
import java.io.IOException;
21+
import java.util.HashMap;
2122
import java.util.Map;
2223
import java.util.concurrent.CompletableFuture;
23-
import java.util.concurrent.ConcurrentHashMap;
24-
import java.util.concurrent.ConcurrentMap;
2524
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
2625
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
2726
import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
@@ -45,7 +44,7 @@
4544
public class BeamFnStateGrpcClientCache {
4645
private static final Logger LOG = LoggerFactory.getLogger(BeamFnStateGrpcClientCache.class);
4746

48-
private final ConcurrentMap<ApiServiceDescriptor, BeamFnStateClient> cache;
47+
private final Map<ApiServiceDescriptor, BeamFnStateClient> cache;
4948
private final ManagedChannelFactory channelFactory;
5049
private final OutboundObserverFactory outboundObserverFactory;
5150
private final IdGenerator idGenerator;
@@ -59,74 +58,109 @@ public BeamFnStateGrpcClientCache(
5958
// This showed a 1-2% improvement in the ProcessBundleBenchmark#testState* benchmarks.
6059
this.channelFactory = channelFactory.withDirectExecutor();
6160
this.outboundObserverFactory = outboundObserverFactory;
62-
this.cache = new ConcurrentHashMap<>();
61+
this.cache = new HashMap<>();
6362
}
6463

6564
/**
6665
* Creates or returns an existing {@link BeamFnStateClient} depending on whether the passed in
6766
* {@link ApiServiceDescriptor} currently has a {@link BeamFnStateClient} bound to the same
6867
* channel.
6968
*/
70-
public BeamFnStateClient forApiServiceDescriptor(ApiServiceDescriptor apiServiceDescriptor)
71-
throws IOException {
72-
return cache.computeIfAbsent(apiServiceDescriptor, this::createBeamFnStateClient);
73-
}
74-
75-
private BeamFnStateClient createBeamFnStateClient(ApiServiceDescriptor apiServiceDescriptor) {
76-
return new GrpcStateClient(apiServiceDescriptor);
69+
public synchronized BeamFnStateClient forApiServiceDescriptor(
70+
ApiServiceDescriptor apiServiceDescriptor) throws IOException {
71+
// We specifically are synchronized so that we only create one GrpcStateClient at a time
72+
// preventing a race where multiple GrpcStateClient objects might be constructed at the same time
73+
// for the same ApiServiceDescriptor.
74+
BeamFnStateClient rval;
75+
synchronized (cache) {
76+
rval = cache.get(apiServiceDescriptor);
77+
}
78+
if (rval == null) {
79+
// We can't be synchronized on cache while constructing the GrpcStateClient since if the
80+
// connection fails, onError may be invoked from the gRPC thread which will invoke
81+
// closeAndCleanUp that clears the cache.
82+
rval = new GrpcStateClient(apiServiceDescriptor);
83+
synchronized (cache) {
84+
cache.put(apiServiceDescriptor, rval);
85+
}
86+
}
87+
return rval;
7788
}
7889

7990
/** A {@link BeamFnStateClient} for a given {@link ApiServiceDescriptor}. */
8091
private class GrpcStateClient implements BeamFnStateClient {
92+
private final Object lock = new Object();
8193
private final ApiServiceDescriptor apiServiceDescriptor;
82-
private final ConcurrentMap<String, CompletableFuture<StateResponse>> outstandingRequests;
94+
private final Map<String, CompletableFuture<StateResponse>> outstandingRequests;
8395
private final StreamObserver<StateRequest> outboundObserver;
8496
private final ManagedChannel channel;
85-
private volatile RuntimeException closed;
97+
private RuntimeException closed;
98+
private boolean errorDuringConstruction;
8699

87100
private GrpcStateClient(ApiServiceDescriptor apiServiceDescriptor) {
88101
this.apiServiceDescriptor = apiServiceDescriptor;
89-
this.outstandingRequests = new ConcurrentHashMap<>();
102+
this.outstandingRequests = new HashMap<>();
90103
this.channel = channelFactory.forDescriptor(apiServiceDescriptor);
104+
this.errorDuringConstruction = false;
91105
this.outboundObserver =
92106
outboundObserverFactory.outboundObserverFor(
93107
BeamFnStateGrpc.newStub(channel)::state, new InboundObserver());
108+
// Because the InboundObserver is created and handed off inside this constructor, it may invoke closeAndCleanUp before this
109+
// constructor completes. In that case there is a race where outboundObserver may have not
110+
// been initialized and hence we invoke onCompleted here.
111+
synchronized (lock) {
112+
if (errorDuringConstruction) {
113+
outboundObserver.onCompleted();
114+
}
115+
}
94116
}
95117

96118
@Override
97119
public CompletableFuture<StateResponse> handle(StateRequest.Builder requestBuilder) {
98120
requestBuilder.setId(idGenerator.getId());
99121
StateRequest request = requestBuilder.build();
100122
CompletableFuture<StateResponse> response = new CompletableFuture<>();
101-
outstandingRequests.put(request.getId(), response);
123+
synchronized (lock) {
124+
if (closed != null) {
125+
response.completeExceptionally(closed);
126+
return response;
127+
}
128+
outstandingRequests.put(request.getId(), response);
129+
}
102130

103131
// If the server closes, gRPC will throw an error if onNext is called.
104132
LOG.debug("Sending StateRequest {}", request);
105133
outboundObserver.onNext(request);
106134
return response;
107135
}
108136

109-
private synchronized void closeAndCleanUp(RuntimeException cause) {
110-
if (closed != null) {
111-
return;
112-
}
113-
cache.remove(apiServiceDescriptor);
114-
closed = cause;
115-
116-
// Make a copy of the map to make the view of the outstanding requests consistent.
117-
Map<String, CompletableFuture<StateResponse>> outstandingRequestsCopy =
118-
new ConcurrentHashMap<>(outstandingRequests);
137+
private void closeAndCleanUp(RuntimeException cause) {
138+
synchronized (lock) {
139+
if (closed != null) {
140+
return;
141+
}
142+
closed = cause;
119143

120-
if (outstandingRequestsCopy.isEmpty()) {
121-
outboundObserver.onCompleted();
122-
return;
123-
}
144+
synchronized (cache) {
145+
cache.remove(apiServiceDescriptor);
146+
}
124147

125-
outstandingRequests.clear();
126-
LOG.error("BeamFnState failed, clearing outstanding requests {}", outstandingRequestsCopy);
148+
if (!outstandingRequests.isEmpty()) {
149+
LOG.error("BeamFnState failed, clearing outstanding requests {}", outstandingRequests);
150+
for (CompletableFuture<StateResponse> entry : outstandingRequests.values()) {
151+
entry.completeExceptionally(cause);
152+
}
153+
outstandingRequests.clear();
154+
}
127155

128-
for (CompletableFuture<StateResponse> entry : outstandingRequestsCopy.values()) {
129-
entry.completeExceptionally(cause);
156+
// outboundObserver is assigned inside the constructor, so it may still be null here since InboundObserver may
157+
// call closeAndCleanUp before the GrpcStateClient finishes construction. In this case
158+
// we defer invoking onCompleted to the GrpcStateClient constructor.
159+
if (outboundObserver == null) {
160+
errorDuringConstruction = true;
161+
} else {
162+
outboundObserver.onCompleted();
163+
}
130164
}
131165
}
132166

@@ -143,7 +177,10 @@ private class InboundObserver implements StreamObserver<StateResponse> {
143177
@Override
144178
public void onNext(StateResponse value) {
145179
LOG.debug("Received StateResponse {}", value);
146-
CompletableFuture<StateResponse> responseFuture = outstandingRequests.remove(value.getId());
180+
CompletableFuture<StateResponse> responseFuture;
181+
synchronized (lock) {
182+
responseFuture = outstandingRequests.remove(value.getId());
183+
}
147184
if (responseFuture == null) {
148185
LOG.warn("Dropped unknown StateResponse {}", value);
149186
return;

sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@
2828
import java.util.concurrent.BlockingQueue;
2929
import java.util.concurrent.CompletableFuture;
3030
import java.util.concurrent.ExecutionException;
31+
import java.util.concurrent.Executors;
32+
import java.util.concurrent.Future;
3133
import java.util.concurrent.LinkedBlockingQueue;
3234
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
3335
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
3436
import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
37+
import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc.BeamFnStateImplBase;
3538
import org.apache.beam.model.pipeline.v1.Endpoints;
3639
import org.apache.beam.sdk.fn.IdGenerators;
3740
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
3841
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
42+
import org.apache.beam.sdk.fn.test.TestExecutors;
43+
import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
3944
import org.apache.beam.sdk.fn.test.TestStreams;
4045
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Server;
4146
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Status;
@@ -46,6 +51,7 @@
4651
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
4752
import org.junit.After;
4853
import org.junit.Before;
54+
import org.junit.Rule;
4955
import org.junit.Test;
5056
import org.junit.runner.RunWith;
5157
import org.junit.runners.JUnit4;
@@ -58,6 +64,8 @@ public class BeamFnStateGrpcClientCacheTest {
5864
private static final String TEST_ERROR = "TEST ERROR";
5965
private static final String SERVER_ERROR = "SERVER ERROR";
6066

67+
@Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
68+
6169
private Endpoints.ApiServiceDescriptor apiServiceDescriptor;
6270
private Server testServer;
6371
private BeamFnStateGrpcClientCache clientCache;
@@ -110,18 +118,17 @@ public void testCachingOfClient() throws Exception {
110118
Server testServer2 =
111119
InProcessServerBuilder.forName(otherApiServiceDescriptor.getUrl())
112120
.addService(
113-
new BeamFnStateGrpc.BeamFnStateImplBase() {
121+
new BeamFnStateImplBase() {
114122
@Override
115123
public StreamObserver<StateRequest> state(
116124
StreamObserver<StateResponse> outboundObserver) {
117-
throw new IllegalStateException("Unexpected in test.");
125+
throw new RuntimeException();
118126
}
119127
})
120128
.build();
121129
testServer2.start();
122130

123131
try {
124-
125132
assertSame(
126133
clientCache.forApiServiceDescriptor(apiServiceDescriptor),
127134
clientCache.forApiServiceDescriptor(apiServiceDescriptor));
@@ -162,25 +169,27 @@ public void testRequestResponses() throws Exception {
162169
}
163170

164171
@Test
172+
// The checker erroneously flags that the CompletableFuture is not being resolved since it is the
173+
// result of Executor#submit.
174+
@SuppressWarnings("FutureReturnValueIgnored")
165175
public void testServerErrorCausesPendingAndFutureCallsToFail() throws Exception {
166176
BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor);
167177

168-
CompletableFuture<StateResponse> inflight =
169-
client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
170-
171-
// Wait for the client to connect.
172-
StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
173-
// Send an error from the server.
174-
outboundServerObserver.onError(
175-
new StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
176-
177-
try {
178-
inflight.get();
179-
fail("Expected unsuccessful response due to server error");
180-
} catch (ExecutionException e) {
181-
assertThat(e.toString(), containsString(SERVER_ERROR));
182-
}
183-
178+
Future<CompletableFuture<StateResponse>> stateResponse =
179+
executor.submit(() -> client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
180+
Future<Void> serverResponse =
181+
executor.submit(
182+
() -> {
183+
// Wait for the client to connect.
184+
StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
185+
// Send an error from the server.
186+
outboundServerObserver.onError(
187+
new StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
188+
return null;
189+
});
190+
191+
CompletableFuture<StateResponse> inflight = stateResponse.get();
192+
serverResponse.get();
184193
try {
185194
inflight.get();
186195
fail("Expected unsuccessful response due to server error");
@@ -190,27 +199,29 @@ public void testServerErrorCausesPendingAndFutureCallsToFail() throws Exception
190199
}
191200

192201
@Test
202+
// The checker erroneously flags that the CompletableFuture is not being resolved since it is the
203+
// result of Executor#submit.
204+
@SuppressWarnings("FutureReturnValueIgnored")
193205
public void testServerCompletionCausesPendingAndFutureCallsToFail() throws Exception {
194206
BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor);
195207

196-
CompletableFuture<StateResponse> inflight =
197-
client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
198-
199-
// Wait for the client to connect.
200-
StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
201-
// Send that the server is done.
202-
outboundServerObserver.onCompleted();
203-
208+
Future<CompletableFuture<StateResponse>> stateResponse =
209+
executor.submit(() -> client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
210+
Future<Void> serverResponse =
211+
executor.submit(
212+
() -> {
213+
// Wait for the client to connect.
214+
StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take();
215+
// Send that the server is done.
216+
outboundServerObserver.onCompleted();
217+
return null;
218+
});
219+
220+
CompletableFuture<StateResponse> inflight = stateResponse.get();
221+
serverResponse.get();
204222
try {
205223
inflight.get();
206-
fail("Expected unsuccessful response due to server completion");
207-
} catch (ExecutionException e) {
208-
assertThat(e.toString(), containsString("Server hanged up"));
209-
}
210-
211-
try {
212-
inflight.get();
213-
fail("Expected unsuccessful response due to server completion");
224+
fail("Expected unsuccessful response due to server error");
214225
} catch (ExecutionException e) {
215226
assertThat(e.toString(), containsString("Server hanged up"));
216227
}

0 commit comments

Comments
 (0)