diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java index 5eb317fc2875..ad5e131cb2d7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.concurrent.atomic.AtomicLong; import org.apache.beam.sdk.util.Weighted; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** Facade for a {@link List} that keeps track of weight, for cache limit reasons. */ public class WeightedList implements Weighted { @@ -72,6 +71,14 @@ public void addAll(List values, long weight) { } public void accumulateWeight(long weight) { - this.weight.accumulateAndGet(weight, LongMath::saturatedAdd); + this.weight.accumulateAndGet( + weight, + (first, second) -> { + try { + return Math.addExact(first, second); + } catch (ArithmeticException e) { + return Long.MAX_VALUE; + } + }); } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 339ddad4061e..1e06c98f2e31 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -49,8 +49,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn @@ -251,11 +249,15 @@ static class BlocksPrefix extends Blocks implements Shrinkable block : blocks) { - sum = LongMath.saturatedAdd(sum, block.getWeight()); + try { + long sum = 8 + blocks.size() * 8L; + for (Block block : blocks) { + sum = Math.addExact(sum, block.getWeight()); + } + return sum; + } catch (ArithmeticException e) { + return Long.MAX_VALUE; } - return sum; } BlocksPrefix(List> blocks) { @@ -280,7 +282,8 @@ public List> getBlocks() { @AutoValue abstract static class Block implements Weighted { - private static final Block EMPTY = fromValues(ImmutableList.of(), 0, null); + private static final Block EMPTY = + fromValues(WeightedList.of(Collections.emptyList(), 0), null); @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() public static Block emptyBlock() { @@ -296,37 +299,21 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - if (values.isEmpty() && nextToken == null) { - return emptyBlock(); - } - ImmutableList immutableValues = ImmutableList.copyOf(values); - long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE; - for (T value : immutableValues) { - listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); - } - return fromValues(immutableValues, listWeight, nextToken); + return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - if (values.isEmpty() && nextToken == null) { - return emptyBlock(); - } - return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken); - } - - private static Block fromValues( - ImmutableList values, long listWeight, @Nullable ByteString nextToken) { - long weight = LongMath.saturatedAdd(listWeight, 24); + long weight = values.getWeight() + 24; if (nextToken != null) { if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; } else { - weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken)); + weight += Caches.weigh(nextToken); } } return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>( - values, nextToken, weight); + values.getBacking(), nextToken, weight); } abstract List getValues(); @@ -385,12 +372,10 @@ public void remove(Set toRemoveStructuralValues) { totalSize += tBlock.getValues().size(); } - ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); - long weight = 0; - List blockValuesToKeep = new ArrayList<>(); + WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); for (Block block : blocks) { - blockValuesToKeep.clear(); boolean valueRemovedFromBlock = false; + List blockValuesToKeep = new ArrayList<>(); for (T value : block.getValues()) { if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) { blockValuesToKeep.add(value); @@ -402,19 +387,13 @@ public void remove(Set toRemoveStructuralValues) { // If any value was removed from this block, need to estimate the weight again. // Otherwise, just reuse the block's weight. if (valueRemovedFromBlock) { - allValues.addAll(blockValuesToKeep); - for (T value : blockValuesToKeep) { - weight = LongMath.saturatedAdd(weight, Caches.weigh(value)); - } + allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues())); } else { - allValues.addAll(block.getValues()); - weight = LongMath.saturatedAdd(weight, block.getWeight()); + allValues.addAll(block.getValues(), block.getWeight()); } } - cache.put( - IterableCacheKey.INSTANCE, - new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); } /** @@ -505,24 +484,21 @@ private void appendHelper(List newValues, long newWeight) { for (Block block : blocks) { totalSize += block.getValues().size(); } - ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); - long weight = 0; + WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); for (Block block : blocks) { - allValues.addAll(block.getValues()); - weight = LongMath.saturatedAdd(weight, block.getWeight()); + allValues.addAll(block.getValues(), block.getWeight()); } if (newWeight < 0) { - newWeight = 0; - for (T value : newValues) { - newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value)); + if (newValues.size() == 1) { + // Optimize weighing of the common value state as single single-element bag state. + newWeight = Caches.weigh(newValues.get(0)); + } else { + newWeight = Caches.weigh(newValues); } } - allValues.addAll(newValues); - weight = LongMath.saturatedAdd(weight, newWeight); + allValues.addAll(newValues, newWeight); - cache.put( - IterableCacheKey.INSTANCE, - new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); } class CachingStateIterator implements PrefetchableIterator { @@ -604,7 +580,8 @@ public boolean hasNext() { return false; } // Release the block while we are loading the next one. - currentBlock = Block.emptyBlock(); + currentBlock = + Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY); @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);