Skip to content

Commit e904467

Browse files
committed
Reduce weighing overhead for caching blocks
1 parent f41cbde commit e904467

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.util.List;
2121
import java.util.concurrent.atomic.AtomicLong;
22+
23+
import com.google.common.math.LongMath;
2224
import org.apache.beam.sdk.util.Weighted;
2325

2426
/** Facade for a {@link List<T>} that keeps track of weight, for cache limit reasons. */
@@ -73,12 +75,6 @@ public void addAll(List<T> values, long weight) {
7375
public void accumulateWeight(long weight) {
7476
this.weight.accumulateAndGet(
7577
weight,
76-
(first, second) -> {
77-
try {
78-
return Math.addExact(first, second);
79-
} catch (ArithmeticException e) {
80-
return Long.MAX_VALUE;
81-
}
82-
});
78+
LongMath::saturatedAdd);
8379
}
8480
}

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.beam.sdk.options.SdkHarnessOptions;
3131
import org.apache.beam.sdk.util.Weighted;
3232
import org.apache.beam.sdk.util.WeightedValue;
33+
import org.apache.beam.vendor.grpc.v1p69p0.com.google.common.math.LongMath;
3334
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
3435
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
3536
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats;

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@
4545
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
4646
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
4747
import org.apache.beam.sdk.util.Weighted;
48+
import org.apache.beam.vendor.grpc.v1p69p0.com.google.common.math.LongMath;
4849
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
4950
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
5051
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
5152
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
53+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
5254

5355
/**
5456
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -249,15 +251,11 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi
249251

250252
@Override
251253
public long getWeight() {
252-
try {
253-
long sum = 8 + blocks.size() * 8L;
254-
for (Block<T> block : blocks) {
255-
sum = Math.addExact(sum, block.getWeight());
256-
}
257-
return sum;
258-
} catch (ArithmeticException e) {
259-
return Long.MAX_VALUE;
254+
long sum = 8 + blocks.size() * 8L;
255+
for (Block<T> block : blocks) {
256+
sum = LongMath.saturatedAdd(sum, block.getWeight());
260257
}
258+
return sum;
261259
}
262260

263261
BlocksPrefix(List<Block<T>> blocks) {
@@ -282,8 +280,7 @@ public List<Block<T>> getBlocks() {
282280

283281
@AutoValue
284282
abstract static class Block<T> implements Weighted {
285-
private static final Block<Void> EMPTY =
286-
fromValues(WeightedList.of(Collections.emptyList(), 0), null);
283+
private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 0, null);
287284

288285
@SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
289286
public static <T> Block<T> emptyBlock() {
@@ -299,21 +296,37 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
299296
}
300297

301298
public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
302-
return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken);
299+
if (values.isEmpty() && nextToken == null) {
300+
return emptyBlock();
301+
}
302+
ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
303+
long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
304+
for (T value : immutableValues) {
305+
listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
306+
}
307+
return fromValues(immutableValues, listWeight, nextToken);
303308
}
304309

305310
public static <T> Block<T> fromValues(
306311
WeightedList<T> values, @Nullable ByteString nextToken) {
307-
long weight = values.getWeight() + 24;
312+
if (values.isEmpty() && nextToken == null) {
313+
return emptyBlock();
314+
}
315+
return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken);
316+
}
317+
318+
private static <T> Block<T> fromValues(
319+
ImmutableList<T> values, long listWeight, @Nullable ByteString nextToken) {
320+
long weight = LongMath.saturatedAdd(listWeight, 24);
308321
if (nextToken != null) {
309322
if (nextToken.isEmpty()) {
310323
nextToken = ByteString.EMPTY;
311324
} else {
312-
weight += Caches.weigh(nextToken);
325+
weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
313326
}
314327
}
315328
return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
316-
values.getBacking(), nextToken, weight);
329+
values, nextToken, weight);
317330
}
318331

319332
abstract List<T> getValues();
@@ -372,10 +385,12 @@ public void remove(Set<Object> toRemoveStructuralValues) {
372385
totalSize += tBlock.getValues().size();
373386
}
374387

375-
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
388+
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
389+
long weight = 0;
390+
List<T> blockValuesToKeep = new ArrayList<>();
376391
for (Block<T> block : blocks) {
392+
blockValuesToKeep.clear();
377393
boolean valueRemovedFromBlock = false;
378-
List<T> blockValuesToKeep = new ArrayList<>();
379394
for (T value : block.getValues()) {
380395
if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
381396
blockValuesToKeep.add(value);
@@ -387,13 +402,19 @@ public void remove(Set<Object> toRemoveStructuralValues) {
387402
// If any value was removed from this block, need to estimate the weight again.
388403
// Otherwise, just reuse the block's weight.
389404
if (valueRemovedFromBlock) {
390-
allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
405+
allValues.addAll(blockValuesToKeep);
406+
for (T value : blockValuesToKeep) {
407+
weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
408+
}
391409
} else {
392-
allValues.addAll(block.getValues(), block.getWeight());
410+
allValues.addAll(block.getValues());
411+
weight = LongMath.saturatedAdd(weight, block.getWeight());
393412
}
394413
}
395414

396-
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
415+
cache.put(
416+
IterableCacheKey.INSTANCE,
417+
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
397418
}
398419

399420
/**
@@ -484,21 +505,22 @@ private void appendHelper(List<T> newValues, long newWeight) {
484505
for (Block<T> block : blocks) {
485506
totalSize += block.getValues().size();
486507
}
487-
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
508+
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
509+
long weight = 0;
488510
for (Block<T> block : blocks) {
489-
allValues.addAll(block.getValues(), block.getWeight());
511+
allValues.addAll(block.getValues());
512+
weight = LongMath.saturatedAdd(weight, block.getWeight());
490513
}
491514
if (newWeight < 0) {
492-
if (newValues.size() == 1) {
493-
// Optimize weighing of the common value state as single single-element bag state.
494-
newWeight = Caches.weigh(newValues.get(0));
495-
} else {
496-
newWeight = Caches.weigh(newValues);
515+
newWeight = 0;
516+
for (T value : newValues) {
517+
newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
497518
}
498519
}
499-
allValues.addAll(newValues, newWeight);
520+
allValues.addAll(newValues);
521+
weight = LongMath.saturatedAdd(weight, newWeight);
500522

501-
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
523+
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
502524
}
503525

504526
class CachingStateIterator implements PrefetchableIterator<T> {
@@ -580,8 +602,7 @@ public boolean hasNext() {
580602
return false;
581603
}
582604
// Release the block while we are loading the next one.
583-
currentBlock =
584-
Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY);
605+
currentBlock = Block.emptyBlock();
585606

586607
@Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
587608
boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);

0 commit comments

Comments
 (0)