1818package org .apache .beam .fn .harness .state ;
1919
2020import java .io .IOException ;
21+ import java .util .HashMap ;
2122import java .util .Map ;
2223import java .util .concurrent .CompletableFuture ;
23- import java .util .concurrent .ConcurrentHashMap ;
24- import java .util .concurrent .ConcurrentMap ;
2524import org .apache .beam .model .fnexecution .v1 .BeamFnApi .StateRequest ;
2625import org .apache .beam .model .fnexecution .v1 .BeamFnApi .StateResponse ;
2726import org .apache .beam .model .fnexecution .v1 .BeamFnStateGrpc ;
4544public class BeamFnStateGrpcClientCache {
4645 private static final Logger LOG = LoggerFactory .getLogger (BeamFnStateGrpcClientCache .class );
4746
48- private final ConcurrentMap <ApiServiceDescriptor , BeamFnStateClient > cache ;
47+ private final Map <ApiServiceDescriptor , BeamFnStateClient > cache ;
4948 private final ManagedChannelFactory channelFactory ;
5049 private final OutboundObserverFactory outboundObserverFactory ;
5150 private final IdGenerator idGenerator ;
@@ -59,74 +58,109 @@ public BeamFnStateGrpcClientCache(
5958 // This showed a 1-2% improvement in the ProcessBundleBenchmark#testState* benchmarks.
6059 this .channelFactory = channelFactory .withDirectExecutor ();
6160 this .outboundObserverFactory = outboundObserverFactory ;
62- this .cache = new ConcurrentHashMap <>();
61+ this .cache = new HashMap <>();
6362 }
6463
6564 /**
6665 * Creates or returns an existing {@link BeamFnStateClient} depending on whether the passed in
6766 * {@link ApiServiceDescriptor} currently has a {@link BeamFnStateClient} bound to the same
6867 * channel.
6968 */
70- public BeamFnStateClient forApiServiceDescriptor (ApiServiceDescriptor apiServiceDescriptor )
71- throws IOException {
72- return cache .computeIfAbsent (apiServiceDescriptor , this ::createBeamFnStateClient );
73- }
74-
75- private BeamFnStateClient createBeamFnStateClient (ApiServiceDescriptor apiServiceDescriptor ) {
76- return new GrpcStateClient (apiServiceDescriptor );
69+ public synchronized BeamFnStateClient forApiServiceDescriptor (
70+ ApiServiceDescriptor apiServiceDescriptor ) throws IOException {
71+ // We specifically are synchronized so that we only create one GrpcStateClient at a time
72+ // preventing a race where multiple GrpcStateClient objects might be constructed at the same
73+ // for the same ApiServiceDescriptor.
74+ BeamFnStateClient rval ;
75+ synchronized (cache ) {
76+ rval = cache .get (apiServiceDescriptor );
77+ }
78+ if (rval == null ) {
79+ // We can't be synchronized on cache while constructing the GrpcStateClient since if the
80+ // connection fails, onError may be invoked from the gRPC thread which will invoke
81+ // closeAndCleanUp that clears the cache.
82+ rval = new GrpcStateClient (apiServiceDescriptor );
83+ synchronized (cache ) {
84+ cache .put (apiServiceDescriptor , rval );
85+ }
86+ }
87+ return rval ;
7788 }
7889
7990 /** A {@link BeamFnStateClient} for a given {@link ApiServiceDescriptor}. */
8091 private class GrpcStateClient implements BeamFnStateClient {
92+ private final Object lock = new Object ();
8193 private final ApiServiceDescriptor apiServiceDescriptor ;
82- private final ConcurrentMap <String , CompletableFuture <StateResponse >> outstandingRequests ;
94+ private final Map <String , CompletableFuture <StateResponse >> outstandingRequests ;
8395 private final StreamObserver <StateRequest > outboundObserver ;
8496 private final ManagedChannel channel ;
85- private volatile RuntimeException closed ;
97+ private RuntimeException closed ;
98+ private boolean errorDuringConstruction ;
8699
87100 private GrpcStateClient (ApiServiceDescriptor apiServiceDescriptor ) {
88101 this .apiServiceDescriptor = apiServiceDescriptor ;
89- this .outstandingRequests = new ConcurrentHashMap <>();
102+ this .outstandingRequests = new HashMap <>();
90103 this .channel = channelFactory .forDescriptor (apiServiceDescriptor );
104+ this .errorDuringConstruction = false ;
91105 this .outboundObserver =
92106 outboundObserverFactory .outboundObserverFor (
93107 BeamFnStateGrpc .newStub (channel )::state , new InboundObserver ());
108+ // Due to safe object publishing, the InboundObserver may invoke closeAndCleanUp before this
109+ // constructor completes. In that case there is a race where outboundObserver may have not
110+ // been initialized and hence we invoke onCompleted here.
111+ synchronized (lock ) {
112+ if (errorDuringConstruction ) {
113+ outboundObserver .onCompleted ();
114+ }
115+ }
94116 }
95117
96118 @ Override
97119 public CompletableFuture <StateResponse > handle (StateRequest .Builder requestBuilder ) {
98120 requestBuilder .setId (idGenerator .getId ());
99121 StateRequest request = requestBuilder .build ();
100122 CompletableFuture <StateResponse > response = new CompletableFuture <>();
101- outstandingRequests .put (request .getId (), response );
123+ synchronized (lock ) {
124+ if (closed != null ) {
125+ response .completeExceptionally (closed );
126+ return response ;
127+ }
128+ outstandingRequests .put (request .getId (), response );
129+ }
102130
103131 // If the server closes, gRPC will throw an error if onNext is called.
104132 LOG .debug ("Sending StateRequest {}" , request );
105133 outboundObserver .onNext (request );
106134 return response ;
107135 }
108136
109- private synchronized void closeAndCleanUp (RuntimeException cause ) {
110- if (closed != null ) {
111- return ;
112- }
113- cache .remove (apiServiceDescriptor );
114- closed = cause ;
115-
116- // Make a copy of the map to make the view of the outstanding requests consistent.
117- Map <String , CompletableFuture <StateResponse >> outstandingRequestsCopy =
118- new ConcurrentHashMap <>(outstandingRequests );
137+ private void closeAndCleanUp (RuntimeException cause ) {
138+ synchronized (lock ) {
139+ if (closed != null ) {
140+ return ;
141+ }
142+ closed = cause ;
119143
120- if (outstandingRequestsCopy .isEmpty ()) {
121- outboundObserver .onCompleted ();
122- return ;
123- }
144+ synchronized (cache ) {
145+ cache .remove (apiServiceDescriptor );
146+ }
124147
125- outstandingRequests .clear ();
126- LOG .error ("BeamFnState failed, clearing outstanding requests {}" , outstandingRequestsCopy );
148+ if (!outstandingRequests .isEmpty ()) {
149+ LOG .error ("BeamFnState failed, clearing outstanding requests {}" , outstandingRequests );
150+ for (CompletableFuture <StateResponse > entry : outstandingRequests .values ()) {
151+ entry .completeExceptionally (cause );
152+ }
153+ outstandingRequests .clear ();
154+ }
127155
128- for (CompletableFuture <StateResponse > entry : outstandingRequestsCopy .values ()) {
129- entry .completeExceptionally (cause );
156+ // Due to safe object publishing, outboundObserver may be null since InboundObserver may
157+ // call closeAndCleanUp before the GrpcStateClient finishes construction. In this case
158+ // we defer invoking onCompleted to the GrpcStateClient constructor.
159+ if (outboundObserver == null ) {
160+ errorDuringConstruction = true ;
161+ } else {
162+ outboundObserver .onCompleted ();
163+ }
130164 }
131165 }
132166
@@ -143,7 +177,10 @@ private class InboundObserver implements StreamObserver<StateResponse> {
143177 @ Override
144178 public void onNext (StateResponse value ) {
145179 LOG .debug ("Received StateResponse {}" , value );
146- CompletableFuture <StateResponse > responseFuture = outstandingRequests .remove (value .getId ());
180+ CompletableFuture <StateResponse > responseFuture ;
181+ synchronized (lock ) {
182+ responseFuture = outstandingRequests .remove (value .getId ());
183+ }
147184 if (responseFuture == null ) {
148185 LOG .warn ("Dropped unknown StateResponse {}" , value );
149186 return ;
0 commit comments