4545import org .apache .beam .sdk .fn .stream .PrefetchableIterables ;
4646import org .apache .beam .sdk .fn .stream .PrefetchableIterator ;
4747import org .apache .beam .sdk .util .Weighted ;
48+ import org .apache .beam .vendor .grpc .v1p69p0 .com .google .common .math .LongMath ;
4849import org .apache .beam .vendor .grpc .v1p69p0 .com .google .protobuf .ByteString ;
4950import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .annotations .VisibleForTesting ;
5051import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Throwables ;
5152import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .AbstractIterator ;
53+ import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList ;
5254
5355/**
5456 * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -249,15 +251,11 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi
249251
250252 @ Override
251253 public long getWeight () {
252- try {
253- long sum = 8 + blocks .size () * 8L ;
254- for (Block <T > block : blocks ) {
255- sum = Math .addExact (sum , block .getWeight ());
256- }
257- return sum ;
258- } catch (ArithmeticException e ) {
259- return Long .MAX_VALUE ;
254+ long sum = 8 + blocks .size () * 8L ;
255+ for (Block <T > block : blocks ) {
256+ sum = LongMath .saturatedAdd (sum , block .getWeight ());
260257 }
258+ return sum ;
261259 }
262260
263261 BlocksPrefix (List <Block <T >> blocks ) {
@@ -282,8 +280,7 @@ public List<Block<T>> getBlocks() {
282280
283281 @ AutoValue
284282 abstract static class Block <T > implements Weighted {
285- private static final Block <Void > EMPTY =
286- fromValues (WeightedList .of (Collections .emptyList (), 0 ), null );
283+ private static final Block <Void > EMPTY = fromValues (ImmutableList .of (), 0 , null );
287284
288285 @ SuppressWarnings ("unchecked" ) // Based upon as Collections.emptyList()
289286 public static <T > Block <T > emptyBlock () {
@@ -299,21 +296,37 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
299296 }
300297
301298 public static <T > Block <T > fromValues (List <T > values , @ Nullable ByteString nextToken ) {
302- return fromValues (WeightedList .of (values , Caches .weigh (values )), nextToken );
299+ if (values .isEmpty () && nextToken == null ) {
300+ return emptyBlock ();
301+ }
302+ ImmutableList <T > immutableValues = ImmutableList .copyOf (values );
303+ long listWeight = immutableValues .size () * Caches .REFERENCE_SIZE ;
304+ for (T value : immutableValues ) {
305+ listWeight = LongMath .saturatedAdd (listWeight , Caches .weigh (value ));
306+ }
307+ return fromValues (immutableValues , listWeight , nextToken );
303308 }
304309
305310 public static <T > Block <T > fromValues (
306311 WeightedList <T > values , @ Nullable ByteString nextToken ) {
307- long weight = values .getWeight () + 24 ;
312+ if (values .isEmpty () && nextToken == null ) {
313+ return emptyBlock ();
314+ }
315+ return fromValues (ImmutableList .copyOf (values .getBacking ()), values .getWeight (), nextToken );
316+ }
317+
318+ private static <T > Block <T > fromValues (
319+ ImmutableList <T > values , long listWeight , @ Nullable ByteString nextToken ) {
320+ long weight = LongMath .saturatedAdd (listWeight , 24 );
308321 if (nextToken != null ) {
309322 if (nextToken .isEmpty ()) {
310323 nextToken = ByteString .EMPTY ;
311324 } else {
312- weight += Caches .weigh (nextToken );
325+ weight = LongMath . saturatedAdd ( weight , Caches .weigh (nextToken ) );
313326 }
314327 }
315328 return new AutoValue_StateFetchingIterators_CachingStateIterable_Block <>(
316- values . getBacking () , nextToken , weight );
329+ values , nextToken , weight );
317330 }
318331
319332 abstract List <T > getValues ();
@@ -372,10 +385,12 @@ public void remove(Set<Object> toRemoveStructuralValues) {
372385 totalSize += tBlock .getValues ().size ();
373386 }
374387
375- WeightedList <T > allValues = WeightedList .of (new ArrayList <>(totalSize ), 0L );
388+ ImmutableList .Builder <T > allValues = ImmutableList .builderWithExpectedSize (totalSize );
389+ long weight = 0 ;
390+ List <T > blockValuesToKeep = new ArrayList <>();
376391 for (Block <T > block : blocks ) {
392+ blockValuesToKeep .clear ();
377393 boolean valueRemovedFromBlock = false ;
378- List <T > blockValuesToKeep = new ArrayList <>();
379394 for (T value : block .getValues ()) {
380395 if (!toRemoveStructuralValues .contains (valueCoder .structuralValue (value ))) {
381396 blockValuesToKeep .add (value );
@@ -387,13 +402,19 @@ public void remove(Set<Object> toRemoveStructuralValues) {
387402 // If any value was removed from this block, need to estimate the weight again.
388403 // Otherwise, just reuse the block's weight.
389404 if (valueRemovedFromBlock ) {
390- allValues .addAll (blockValuesToKeep , Caches .weigh (block .getValues ()));
405+ allValues .addAll (blockValuesToKeep );
406+ for (T value : blockValuesToKeep ) {
407+ weight = LongMath .saturatedAdd (weight , Caches .weigh (value ));
408+ }
391409 } else {
392- allValues .addAll (block .getValues (), block .getWeight ());
410+ allValues .addAll (block .getValues ());
411+ weight = LongMath .saturatedAdd (weight , block .getWeight ());
393412 }
394413 }
395414
396- cache .put (IterableCacheKey .INSTANCE , new MutatedBlocks <>(Block .mutatedBlock (allValues )));
415+ cache .put (
416+ IterableCacheKey .INSTANCE ,
417+ new MutatedBlocks <>(Block .fromValues (allValues .build (), weight , null )));
397418 }
398419
399420 /**
@@ -484,21 +505,22 @@ private void appendHelper(List<T> newValues, long newWeight) {
484505 for (Block <T > block : blocks ) {
485506 totalSize += block .getValues ().size ();
486507 }
487- WeightedList <T > allValues = WeightedList .of (new ArrayList <>(totalSize ), 0L );
508+ ImmutableList .Builder <T > allValues = ImmutableList .builderWithExpectedSize (totalSize );
509+ long weight = 0 ;
488510 for (Block <T > block : blocks ) {
489- allValues .addAll (block .getValues (), block .getWeight ());
511+ allValues .addAll (block .getValues ());
512+ weight = LongMath .saturatedAdd (weight , block .getWeight ());
490513 }
491514 if (newWeight < 0 ) {
492- if (newValues .size () == 1 ) {
493- // Optimize weighing of the common value state as single single-element bag state.
494- newWeight = Caches .weigh (newValues .get (0 ));
495- } else {
496- newWeight = Caches .weigh (newValues );
515+ newWeight = 0 ;
516+ for (T value : newValues ) {
517+ newWeight = LongMath .saturatedAdd (newWeight , Caches .weigh (value ));
497518 }
498519 }
499- allValues .addAll (newValues , newWeight );
520+ allValues .addAll (newValues );
521+ weight = LongMath .saturatedAdd (weight , newWeight );
500522
501- cache .put (IterableCacheKey .INSTANCE , new MutatedBlocks <>(Block .mutatedBlock (allValues )));
523+ cache .put (IterableCacheKey .INSTANCE , new MutatedBlocks <>(Block .fromValues (allValues . build (), weight , null )));
502524 }
503525
504526 class CachingStateIterator implements PrefetchableIterator <T > {
@@ -580,8 +602,7 @@ public boolean hasNext() {
580602 return false ;
581603 }
582604 // Release the block while we are loading the next one.
583- currentBlock =
584- Block .fromValues (WeightedList .of (Collections .emptyList (), 0L ), ByteString .EMPTY );
605+ currentBlock = Block .emptyBlock ();
585606
586607 @ Nullable Blocks <T > existing = cache .peek (IterableCacheKey .INSTANCE );
587608 boolean isFirstBlock = ByteString .EMPTY .equals (nextToken );
0 commit comments