4949import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .annotations .VisibleForTesting ;
5050import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Throwables ;
5151import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .AbstractIterator ;
52- import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList ;
53- import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .math .LongMath ;
5452
5553/**
5654 * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -251,11 +249,15 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi
251249
252250 @ Override
253251 public long getWeight () {
254- long sum = 8 + blocks .size () * 8L ;
255- for (Block <T > block : blocks ) {
256- sum = LongMath .saturatedAdd (sum , block .getWeight ());
252+ try {
253+ long sum = 8 + blocks .size () * 8L ;
254+ for (Block <T > block : blocks ) {
255+ sum = Math .addExact (sum , block .getWeight ());
256+ }
257+ return sum ;
258+ } catch (ArithmeticException e ) {
259+ return Long .MAX_VALUE ;
257260 }
258- return sum ;
259261 }
260262
261263 BlocksPrefix (List <Block <T >> blocks ) {
@@ -280,7 +282,8 @@ public List<Block<T>> getBlocks() {
280282
281283 @ AutoValue
282284 abstract static class Block <T > implements Weighted {
283- private static final Block <Void > EMPTY = fromValues (ImmutableList .of (), 0 , null );
285+ private static final Block <Void > EMPTY =
286+ fromValues (WeightedList .of (Collections .emptyList (), 0 ), null );
284287
285288 @ SuppressWarnings ("unchecked" ) // Based upon as Collections.emptyList()
286289 public static <T > Block <T > emptyBlock () {
@@ -296,37 +299,21 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
296299 }
297300
298301 public static <T > Block <T > fromValues (List <T > values , @ Nullable ByteString nextToken ) {
299- if (values .isEmpty () && nextToken == null ) {
300- return emptyBlock ();
301- }
302- ImmutableList <T > immutableValues = ImmutableList .copyOf (values );
303- long listWeight = immutableValues .size () * Caches .REFERENCE_SIZE ;
304- for (T value : immutableValues ) {
305- listWeight = LongMath .saturatedAdd (listWeight , Caches .weigh (value ));
306- }
307- return fromValues (immutableValues , listWeight , nextToken );
302+ return fromValues (WeightedList .of (values , Caches .weigh (values )), nextToken );
308303 }
309304
310305 public static <T > Block <T > fromValues (
311306 WeightedList <T > values , @ Nullable ByteString nextToken ) {
312- if (values .isEmpty () && nextToken == null ) {
313- return emptyBlock ();
314- }
315- return fromValues (ImmutableList .copyOf (values .getBacking ()), values .getWeight (), nextToken );
316- }
317-
318- private static <T > Block <T > fromValues (
319- ImmutableList <T > values , long listWeight , @ Nullable ByteString nextToken ) {
320- long weight = LongMath .saturatedAdd (listWeight , 24 );
307+ long weight = values .getWeight () + 24 ;
321308 if (nextToken != null ) {
322309 if (nextToken .isEmpty ()) {
323310 nextToken = ByteString .EMPTY ;
324311 } else {
325- weight = LongMath . saturatedAdd ( weight , Caches .weigh (nextToken ) );
312+ weight += Caches .weigh (nextToken );
326313 }
327314 }
328315 return new AutoValue_StateFetchingIterators_CachingStateIterable_Block <>(
329- values , nextToken , weight );
316+ values . getBacking () , nextToken , weight );
330317 }
331318
332319 abstract List <T > getValues ();
@@ -385,12 +372,10 @@ public void remove(Set<Object> toRemoveStructuralValues) {
385372 totalSize += tBlock .getValues ().size ();
386373 }
387374
388- ImmutableList .Builder <T > allValues = ImmutableList .builderWithExpectedSize (totalSize );
389- long weight = 0 ;
390- List <T > blockValuesToKeep = new ArrayList <>();
375+ WeightedList <T > allValues = WeightedList .of (new ArrayList <>(totalSize ), 0L );
391376 for (Block <T > block : blocks ) {
392- blockValuesToKeep .clear ();
393377 boolean valueRemovedFromBlock = false ;
378+ List <T > blockValuesToKeep = new ArrayList <>();
394379 for (T value : block .getValues ()) {
395380 if (!toRemoveStructuralValues .contains (valueCoder .structuralValue (value ))) {
396381 blockValuesToKeep .add (value );
@@ -402,19 +387,13 @@ public void remove(Set<Object> toRemoveStructuralValues) {
402387 // If any value was removed from this block, need to estimate the weight again.
403388 // Otherwise, just reuse the block's weight.
404389 if (valueRemovedFromBlock ) {
405- allValues .addAll (blockValuesToKeep );
406- for (T value : blockValuesToKeep ) {
407- weight = LongMath .saturatedAdd (weight , Caches .weigh (value ));
408- }
390+ allValues .addAll (blockValuesToKeep , Caches .weigh (block .getValues ()));
409391 } else {
410- allValues .addAll (block .getValues ());
411- weight = LongMath .saturatedAdd (weight , block .getWeight ());
392+ allValues .addAll (block .getValues (), block .getWeight ());
412393 }
413394 }
414395
415- cache .put (
416- IterableCacheKey .INSTANCE ,
417- new MutatedBlocks <>(Block .fromValues (allValues .build (), weight , null )));
396+ cache .put (IterableCacheKey .INSTANCE , new MutatedBlocks <>(Block .mutatedBlock (allValues )));
418397 }
419398
420399 /**
@@ -505,24 +484,21 @@ private void appendHelper(List<T> newValues, long newWeight) {
505484 for (Block <T > block : blocks ) {
506485 totalSize += block .getValues ().size ();
507486 }
508- ImmutableList .Builder <T > allValues = ImmutableList .builderWithExpectedSize (totalSize );
509- long weight = 0 ;
487+ WeightedList <T > allValues = WeightedList .of (new ArrayList <>(totalSize ), 0L );
510488 for (Block <T > block : blocks ) {
511- allValues .addAll (block .getValues ());
512- weight = LongMath .saturatedAdd (weight , block .getWeight ());
489+ allValues .addAll (block .getValues (), block .getWeight ());
513490 }
514491 if (newWeight < 0 ) {
515- newWeight = 0 ;
516- for (T value : newValues ) {
517- newWeight = LongMath .saturatedAdd (newWeight , Caches .weigh (value ));
492+ if (newValues .size () == 1 ) {
493+ // Optimize weighing of the common value state as single single-element bag state.
494+ newWeight = Caches .weigh (newValues .get (0 ));
495+ } else {
496+ newWeight = Caches .weigh (newValues );
518497 }
519498 }
520- allValues .addAll (newValues );
521- weight = LongMath .saturatedAdd (weight , newWeight );
499+ allValues .addAll (newValues , newWeight );
522500
523- cache .put (
524- IterableCacheKey .INSTANCE ,
525- new MutatedBlocks <>(Block .fromValues (allValues .build (), weight , null )));
501+ cache .put (IterableCacheKey .INSTANCE , new MutatedBlocks <>(Block .mutatedBlock (allValues )));
526502 }
527503
528504 class CachingStateIterator implements PrefetchableIterator <T > {
@@ -604,7 +580,8 @@ public boolean hasNext() {
604580 return false ;
605581 }
606582 // Release the block while we are loading the next one.
607- currentBlock = Block .emptyBlock ();
583+ currentBlock =
584+ Block .fromValues (WeightedList .of (Collections .emptyList (), 0L ), ByteString .EMPTY );
608585
609586 @ Nullable Blocks <T > existing = cache .peek (IterableCacheKey .INSTANCE );
610587 boolean isFirstBlock = ByteString .EMPTY .equals (nextToken );
0 commit comments