From 5e7438106a839d881ce230beb480e3545a6233da Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Tue, 25 Nov 2025 14:30:36 +0100 Subject: [PATCH] Reduce weighing overhead for caching blocks --- .../apache/beam/sdk/fn/data/WeightedList.java | 11 +-- .../harness/state/StateFetchingIterators.java | 83 ++++++++++++------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java index ad5e131cb2d7..5eb317fc2875 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicLong; import org.apache.beam.sdk.util.Weighted; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** Facade for a {@link List} that keeps track of weight, for cache limit reasons. */ public class WeightedList implements Weighted { @@ -71,14 +72,6 @@ public void addAll(List values, long weight) { } public void accumulateWeight(long weight) { - this.weight.accumulateAndGet( - weight, - (first, second) -> { - try { - return Math.addExact(first, second); - } catch (ArithmeticException e) { - return Long.MAX_VALUE; - } - }); + this.weight.accumulateAndGet(weight, LongMath::saturatedAdd); } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 1e06c98f2e31..339ddad4061e 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -49,6 +49,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn @@ -249,15 +251,11 @@ static class BlocksPrefix extends Blocks implements Shrinkable block : blocks) { - sum = Math.addExact(sum, block.getWeight()); - } - return sum; - } catch (ArithmeticException e) { - return Long.MAX_VALUE; + long sum = 8 + blocks.size() * 8L; + for (Block block : blocks) { + sum = LongMath.saturatedAdd(sum, block.getWeight()); } + return sum; } BlocksPrefix(List> blocks) { @@ -282,8 +280,7 @@ public List> getBlocks() { @AutoValue abstract static class Block implements Weighted { - private static final Block EMPTY = - fromValues(WeightedList.of(Collections.emptyList(), 0), null); + private static final Block EMPTY = fromValues(ImmutableList.of(), 0, null); @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() public static Block emptyBlock() { @@ -299,21 +296,37 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken); + if (values.isEmpty() && nextToken == null) { + return emptyBlock(); + } + ImmutableList immutableValues = ImmutableList.copyOf(values); + long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE; + for (T value : immutableValues) { + listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); + } + return fromValues(immutableValues, listWeight, nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - long weight = values.getWeight() + 24; + if (values.isEmpty() && nextToken == null) { + return emptyBlock(); + } + return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken); + } + + private static Block fromValues( + ImmutableList values, long listWeight, @Nullable ByteString nextToken) { + long weight = LongMath.saturatedAdd(listWeight, 24); if (nextToken != null) { if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; } else { - weight += Caches.weigh(nextToken); + weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken)); } } return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>( - values.getBacking(), nextToken, weight); + values, nextToken, weight); } abstract List getValues(); @@ -372,10 +385,12 @@ public void remove(Set toRemoveStructuralValues) { totalSize += tBlock.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); + long weight = 0; + List blockValuesToKeep = new ArrayList<>(); for (Block block : blocks) { + blockValuesToKeep.clear(); boolean valueRemovedFromBlock = false; - List blockValuesToKeep = new ArrayList<>(); for (T value : block.getValues()) { if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) { blockValuesToKeep.add(value); @@ -387,13 +402,19 @@ public void remove(Set toRemoveStructuralValues) { // If any value was removed from this block, need to estimate the weight again. // Otherwise, just reuse the block's weight. if (valueRemovedFromBlock) { - allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues())); + allValues.addAll(blockValuesToKeep); + for (T value : blockValuesToKeep) { + weight = LongMath.saturatedAdd(weight, Caches.weigh(value)); + } } else { - allValues.addAll(block.getValues(), block.getWeight()); + allValues.addAll(block.getValues()); + weight = LongMath.saturatedAdd(weight, block.getWeight()); } } - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put( + IterableCacheKey.INSTANCE, + new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); } /** @@ -484,21 +505,24 @@ private void appendHelper(List newValues, long newWeight) { for (Block block : blocks) { totalSize += block.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); + long weight = 0; for (Block block : blocks) { - allValues.addAll(block.getValues(), block.getWeight()); + allValues.addAll(block.getValues()); + weight = LongMath.saturatedAdd(weight, block.getWeight()); } if (newWeight < 0) { - if (newValues.size() == 1) { - // Optimize weighing of the common value state as single single-element bag state. - newWeight = Caches.weigh(newValues.get(0)); - } else { - newWeight = Caches.weigh(newValues); + newWeight = 0; + for (T value : newValues) { + newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value)); } } - allValues.addAll(newValues, newWeight); + allValues.addAll(newValues); + weight = LongMath.saturatedAdd(weight, newWeight); - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put( + IterableCacheKey.INSTANCE, + new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); } class CachingStateIterator implements PrefetchableIterator { @@ -580,8 +604,7 @@ public boolean hasNext() { return false; } // Release the block while we are loading the next one. - currentBlock = - Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY); + currentBlock = Block.emptyBlock(); @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);