1919
2020import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList .toImmutableList ;
2121
22- import com .google .auto .value .AutoOneOf ;
2322import java .util .ArrayList ;
2423import java .util .Collections ;
2524import java .util .Comparator ;
25+ import java .util .HashSet ;
2626import java .util .List ;
2727import java .util .concurrent .CountDownLatch ;
28- import java . util . stream . Stream ;
28+ import javax . annotation . Nullable ;
2929import org .apache .beam .runners .dataflow .worker .windmill .Windmill ;
30- import org .apache .beam .runners .dataflow .worker .windmill .Windmill .ComputationGetDataRequest ;
3130import org .apache .beam .runners .dataflow .worker .windmill .Windmill .GlobalDataRequest ;
3231import org .apache .beam .runners .dataflow .worker .windmill .Windmill .KeyedGetDataRequest ;
3332import org .apache .beam .runners .dataflow .worker .windmill .client .WindmillStreamShutdownException ;
33+ import org .apache .beam .sdk .util .Preconditions ;
3434import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList ;
3535import org .slf4j .Logger ;
3636import org .slf4j .LoggerFactory ;
@@ -46,15 +46,42 @@ private static String debugFormat(long value) {
4646 return String .format ("%016x" , value );
4747 }
4848
49+ static class ComputationAndKeyRequest {
50+ private final String computation ;
51+ private final KeyedGetDataRequest request ;
52+
53+ ComputationAndKeyRequest (String computation , KeyedGetDataRequest request ) {
54+ this .computation = computation ;
55+ this .request = request ;
56+ }
57+
58+ String getComputation () {
59+ return computation ;
60+ }
61+
62+ KeyedGetDataRequest getKeyedGetDataRequest () {
63+ return request ;
64+ }
65+ }
66+
4967 static class QueuedRequest {
5068 private final long id ;
51- private final ComputationOrGlobalDataRequest dataRequest ;
69+ private final @ Nullable ComputationAndKeyRequest computationAndKeyRequest ;
70+ private final @ Nullable GlobalDataRequest globalDataRequest ;
5271 private AppendableInputStream responseStream ;
5372
/**
 * Creates a request carrying a {@link GlobalDataRequest}; the keyed-request slot is left {@code
 * null}. A fresh response stream is created with the given deadline.
 */
private QueuedRequest(long id, GlobalDataRequest globalDataRequest, long deadlineSeconds) {
  this.id = id;
  this.globalDataRequest = globalDataRequest;
  this.computationAndKeyRequest = null;
  this.responseStream = new AppendableInputStream(deadlineSeconds);
}
79+
/**
 * Creates a request carrying a {@link ComputationAndKeyRequest}; the global-data slot is left
 * {@code null}. A fresh response stream is created with the given deadline.
 */
private QueuedRequest(
    long id, ComputationAndKeyRequest computationAndKeyRequest, long deadlineSeconds) {
  this.id = id;
  this.globalDataRequest = null;
  this.computationAndKeyRequest = computationAndKeyRequest;
  this.responseStream = new AppendableInputStream(deadlineSeconds);
}
6087
@@ -63,27 +90,19 @@ static QueuedRequest forComputation(
6390 String computation ,
6491 KeyedGetDataRequest keyedGetDataRequest ,
6592 long deadlineSeconds ) {
66- ComputationGetDataRequest computationGetDataRequest =
67- ComputationGetDataRequest .newBuilder ()
68- .setComputationId (computation )
69- .addRequests (keyedGetDataRequest )
70- .build ();
7193 return new QueuedRequest (
72- id ,
73- ComputationOrGlobalDataRequest .computation (computationGetDataRequest ),
74- deadlineSeconds );
94+ id , new ComputationAndKeyRequest (computation , keyedGetDataRequest ), deadlineSeconds );
7595 }
7696
7797 static QueuedRequest global (
7898 long id , GlobalDataRequest globalDataRequest , long deadlineSeconds ) {
79- return new QueuedRequest (
80- id , ComputationOrGlobalDataRequest .global (globalDataRequest ), deadlineSeconds );
99+ return new QueuedRequest (id , globalDataRequest , deadlineSeconds );
81100 }
82101
83102 static Comparator <QueuedRequest > globalRequestsFirst () {
84103 return (QueuedRequest r1 , QueuedRequest r2 ) -> {
85- boolean r1gd = r1 .dataRequest . isGlobal () ;
86- boolean r2gd = r2 .dataRequest . isGlobal () ;
104+ boolean r1gd = r1 .getKind () == Kind . GLOBAL ;
105+ boolean r2gd = r2 .getKind () == Kind . GLOBAL ;
87106 return r1gd == r2gd ? 0 : (r1gd ? -1 : 1 );
88107 };
89108 }
@@ -93,7 +112,13 @@ long id() {
93112 }
94113
95114 long byteSize () {
96- return dataRequest .serializedSize ();
115+ if (globalDataRequest != null ) {
116+ return globalDataRequest .getSerializedSize ();
117+ }
118+ Preconditions .checkStateNotNull (computationAndKeyRequest );
119+ return 10L
120+ + computationAndKeyRequest .request .getSerializedSize ()
121+ + computationAndKeyRequest .getComputation ().length ();
97122 }
98123
99124 AppendableInputStream getResponseStream () {
@@ -104,22 +129,56 @@ void resetResponseStream() {
104129 this .responseStream = new AppendableInputStream (responseStream .getDeadlineSeconds ());
105130 }
106131
107- public ComputationOrGlobalDataRequest getDataRequest () {
108- return dataRequest ;
/** Discriminates which of the two payloads a {@code QueuedRequest} carries. */
enum Kind {
  /** The request holds a {@code ComputationAndKeyRequest}. */
  COMPUTATION_AND_KEY_REQUEST,

  /** The request holds a {@code GlobalDataRequest}. */
  GLOBAL
}
136+
137+ Kind getKind () {
138+ return computationAndKeyRequest != null ? Kind .COMPUTATION_AND_KEY_REQUEST : Kind .GLOBAL ;
139+ }
140+
141+ ComputationAndKeyRequest getComputationAndKeyRequest () {
142+ return Preconditions .checkStateNotNull (computationAndKeyRequest );
143+ }
144+
145+ GlobalDataRequest getGlobalDataRequest () {
146+ return Preconditions .checkStateNotNull (globalDataRequest );
109147 }
110148
111149 void addToStreamingGetDataRequest (Windmill .StreamingGetDataRequest .Builder builder ) {
112150 builder .addRequestId (id );
113- if (dataRequest .isForComputation ()) {
114- builder .addStateRequest (dataRequest .computation ());
115- } else {
116- builder .addGlobalDataRequest (dataRequest .global ());
151+ switch (getKind ()) {
152+ case COMPUTATION_AND_KEY_REQUEST :
153+ ComputationAndKeyRequest request = getComputationAndKeyRequest ();
154+ builder
155+ .addStateRequestBuilder ()
156+ .setComputationId (request .getComputation ())
157+ .addRequests (request .request );
158+ break ;
159+ case GLOBAL :
160+ builder .addGlobalDataRequest (getGlobalDataRequest ());
161+ break ;
117162 }
118163 }
119164
120165 @ Override
121166 public final String toString () {
122- return "QueuedRequest{" + "dataRequest=" + dataRequest + ", id=" + id + '}' ;
167+ StringBuilder result = new StringBuilder ("QueuedRequest{id=" ).append (id ).append (", " );
168+ if (getKind () == Kind .GLOBAL ) {
169+ result .append ("GetSideInput=" ).append (getGlobalDataRequest ());
170+ } else {
171+ KeyedGetDataRequest key = getComputationAndKeyRequest ().request ;
172+ result
173+ .append ("KeyedGetState=[shardingKey=" )
174+ .append (debugFormat (key .getShardingKey ()))
175+ .append ("cacheToken=" )
176+ .append (debugFormat (key .getCacheToken ()))
177+ .append ("workToken" )
178+ .append (debugFormat (key .getWorkToken ()))
179+ .append ("]" );
180+ }
181+ return result .append ('}' ).toString ();
123182 }
124183 }
125184
@@ -128,13 +187,14 @@ public final String toString() {
128187 */
129188 static class QueuedBatch {
130189 private final List <QueuedRequest > requests = new ArrayList <>();
190+ private final HashSet <Long > workTokens = new HashSet <>();
131191 private final CountDownLatch sent = new CountDownLatch (1 );
132192 private long byteSize = 0 ;
133193 private volatile boolean finalized = false ;
134194 private volatile boolean failed = false ;
135195
136196 /** Returns a read-only view of requests. */
137- List <QueuedRequest > requestsReadOnly () {
197+ List <QueuedRequest > requestsView () {
138198 return Collections .unmodifiableList (requests );
139199 }
140200
@@ -155,18 +215,10 @@ Windmill.StreamingGetDataRequest asGetDataRequest() {
155215 return builder .build ();
156216 }
157217
158- boolean isEmpty () {
159- return requests .isEmpty ();
160- }
161-
162218 int requestsCount () {
163219 return requests .size ();
164220 }
165221
166- long byteSize () {
167- return byteSize ;
168- }
169-
170222 boolean isFinalized () {
171223 return finalized ;
172224 }
@@ -176,9 +228,26 @@ void markFinalized() {
176228 }
177229
178230 /** Adds a request to the batch. */
179- void addRequest (QueuedRequest request ) {
231+ boolean tryAddRequest (QueuedRequest request , int countLimit , long byteLimit ) {
232+ if (finalized ) {
233+ return false ;
234+ }
235+ if (requests .size () >= countLimit ) {
236+ return false ;
237+ }
238+ long estimatedBytes = request .byteSize ();
239+ if (byteSize + estimatedBytes >= byteLimit ) {
240+ return false ;
241+ }
242+
243+ if (request .getKind () == QueuedRequest .Kind .COMPUTATION_AND_KEY_REQUEST
244+ && !workTokens .add (request .getComputationAndKeyRequest ().request .getWorkToken ())) {
245+ return false ;
246+ }
247+ // At this point we have added to work items so we must accept the item.
180248 requests .add (request );
181- byteSize += request .byteSize ();
249+ byteSize += estimatedBytes ;
250+ return true ;
182251 }
183252
184253 /**
@@ -227,75 +296,9 @@ void waitForSendOrFailNotification()
227296
228297 private ImmutableList <String > createStreamCancelledErrorMessages () {
229298 return requests .stream ()
230- .flatMap (
231- request -> {
232- switch (request .getDataRequest ().getKind ()) {
233- case GLOBAL :
234- return Stream .of ("GetSideInput=" + request .getDataRequest ().global ());
235- case COMPUTATION :
236- return request .getDataRequest ().computation ().getRequestsList ().stream ()
237- .map (
238- keyedRequest ->
239- "KeyedGetState=["
240- + "shardingKey="
241- + debugFormat (keyedRequest .getShardingKey ())
242- + "cacheToken="
243- + debugFormat (keyedRequest .getCacheToken ())
244- + "workToken"
245- + debugFormat (keyedRequest .getWorkToken ())
246- + "]" );
247- default :
248- // Will never happen switch is exhaustive.
249- throw new IllegalStateException ();
250- }
251- })
299+ .map (QueuedRequest ::toString )
252300 .limit (STREAM_CANCELLED_ERROR_LOG_LIMIT )
253301 .collect (toImmutableList ());
254302 }
255303 }
256-
257- @ AutoOneOf (ComputationOrGlobalDataRequest .Kind .class )
258- abstract static class ComputationOrGlobalDataRequest {
259- static ComputationOrGlobalDataRequest computation (
260- ComputationGetDataRequest computationGetDataRequest ) {
261- return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest .computation (
262- computationGetDataRequest );
263- }
264-
265- static ComputationOrGlobalDataRequest global (GlobalDataRequest globalDataRequest ) {
266- return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest .global (
267- globalDataRequest );
268- }
269-
270- abstract Kind getKind ();
271-
272- abstract ComputationGetDataRequest computation ();
273-
274- abstract GlobalDataRequest global ();
275-
276- boolean isGlobal () {
277- return getKind () == Kind .GLOBAL ;
278- }
279-
280- boolean isForComputation () {
281- return getKind () == Kind .COMPUTATION ;
282- }
283-
284- long serializedSize () {
285- switch (getKind ()) {
286- case GLOBAL :
287- return global ().getSerializedSize ();
288- case COMPUTATION :
289- return computation ().getSerializedSize ();
290- // this will never happen since the switch is exhaustive.
291- default :
292- throw new UnsupportedOperationException ("unknown dataRequest type." );
293- }
294- }
295-
296- enum Kind {
297- COMPUTATION ,
298- GLOBAL
299- }
300- }
301304}
0 commit comments