Skip to content

Commit e927f2d

Browse files
authored
Merge pull request #37019 from apache/revert-36897-optimize_weigh
Revert "Reduce weighing overhead for caching blocks"
2 parents f22d2bc + 63177cb commit e927f2d

File tree

2 files changed

+39
-55
lines changed

2 files changed

+39
-55
lines changed

sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.List;
2121
import java.util.concurrent.atomic.AtomicLong;
2222
import org.apache.beam.sdk.util.Weighted;
23-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
2423

2524
/** Facade for a {@link List} that keeps track of weight, for cache limit reasons. */
2625
public class WeightedList<T> implements Weighted {
@@ -72,6 +71,14 @@ public void addAll(List<T> values, long weight) {
7271
}
7372

7473
public void accumulateWeight(long weight) {
75-
this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
74+
this.weight.accumulateAndGet(
75+
weight,
76+
(first, second) -> {
77+
try {
78+
return Math.addExact(first, second);
79+
} catch (ArithmeticException e) {
80+
return Long.MAX_VALUE;
81+
}
82+
});
7683
}
7784
}

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java

Lines changed: 30 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@
4949
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
5050
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
5151
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
52-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
53-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
5452

5553
/**
5654
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -251,11 +249,15 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi
251249

252250
@Override
253251
public long getWeight() {
254-
long sum = 8 + blocks.size() * 8L;
255-
for (Block<T> block : blocks) {
256-
sum = LongMath.saturatedAdd(sum, block.getWeight());
252+
try {
253+
long sum = 8 + blocks.size() * 8L;
254+
for (Block<T> block : blocks) {
255+
sum = Math.addExact(sum, block.getWeight());
256+
}
257+
return sum;
258+
} catch (ArithmeticException e) {
259+
return Long.MAX_VALUE;
257260
}
258-
return sum;
259261
}
260262

261263
BlocksPrefix(List<Block<T>> blocks) {
@@ -280,7 +282,8 @@ public List<Block<T>> getBlocks() {
280282

281283
@AutoValue
282284
abstract static class Block<T> implements Weighted {
283-
private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 0, null);
285+
private static final Block<Void> EMPTY =
286+
fromValues(WeightedList.of(Collections.emptyList(), 0), null);
284287

285288
@SuppressWarnings("unchecked") // Modeled on Collections.emptyList()
286289
public static <T> Block<T> emptyBlock() {
@@ -296,37 +299,21 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
296299
}
297300

298301
public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
299-
if (values.isEmpty() && nextToken == null) {
300-
return emptyBlock();
301-
}
302-
ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
303-
long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
304-
for (T value : immutableValues) {
305-
listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
306-
}
307-
return fromValues(immutableValues, listWeight, nextToken);
302+
return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken);
308303
}
309304

310305
public static <T> Block<T> fromValues(
311306
WeightedList<T> values, @Nullable ByteString nextToken) {
312-
if (values.isEmpty() && nextToken == null) {
313-
return emptyBlock();
314-
}
315-
return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken);
316-
}
317-
318-
private static <T> Block<T> fromValues(
319-
ImmutableList<T> values, long listWeight, @Nullable ByteString nextToken) {
320-
long weight = LongMath.saturatedAdd(listWeight, 24);
307+
long weight = values.getWeight() + 24;
321308
if (nextToken != null) {
322309
if (nextToken.isEmpty()) {
323310
nextToken = ByteString.EMPTY;
324311
} else {
325-
weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
312+
weight += Caches.weigh(nextToken);
326313
}
327314
}
328315
return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
329-
values, nextToken, weight);
316+
values.getBacking(), nextToken, weight);
330317
}
331318

332319
abstract List<T> getValues();
@@ -385,12 +372,10 @@ public void remove(Set<Object> toRemoveStructuralValues) {
385372
totalSize += tBlock.getValues().size();
386373
}
387374

388-
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
389-
long weight = 0;
390-
List<T> blockValuesToKeep = new ArrayList<>();
375+
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
391376
for (Block<T> block : blocks) {
392-
blockValuesToKeep.clear();
393377
boolean valueRemovedFromBlock = false;
378+
List<T> blockValuesToKeep = new ArrayList<>();
394379
for (T value : block.getValues()) {
395380
if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
396381
blockValuesToKeep.add(value);
@@ -402,19 +387,13 @@ public void remove(Set<Object> toRemoveStructuralValues) {
402387
// If any value was removed from this block, need to estimate the weight again.
403388
// Otherwise, just reuse the block's weight.
404389
if (valueRemovedFromBlock) {
405-
allValues.addAll(blockValuesToKeep);
406-
for (T value : blockValuesToKeep) {
407-
weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
408-
}
390+
allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
409391
} else {
410-
allValues.addAll(block.getValues());
411-
weight = LongMath.saturatedAdd(weight, block.getWeight());
392+
allValues.addAll(block.getValues(), block.getWeight());
412393
}
413394
}
414395

415-
cache.put(
416-
IterableCacheKey.INSTANCE,
417-
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
396+
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
418397
}
419398

420399
/**
@@ -505,24 +484,21 @@ private void appendHelper(List<T> newValues, long newWeight) {
505484
for (Block<T> block : blocks) {
506485
totalSize += block.getValues().size();
507486
}
508-
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
509-
long weight = 0;
487+
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
510488
for (Block<T> block : blocks) {
511-
allValues.addAll(block.getValues());
512-
weight = LongMath.saturatedAdd(weight, block.getWeight());
489+
allValues.addAll(block.getValues(), block.getWeight());
513490
}
514491
if (newWeight < 0) {
515-
newWeight = 0;
516-
for (T value : newValues) {
517-
newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
492+
if (newValues.size() == 1) {
493+
// Optimize weighing for the common case of value state represented as a single-element bag state.
494+
newWeight = Caches.weigh(newValues.get(0));
495+
} else {
496+
newWeight = Caches.weigh(newValues);
518497
}
519498
}
520-
allValues.addAll(newValues);
521-
weight = LongMath.saturatedAdd(weight, newWeight);
499+
allValues.addAll(newValues, newWeight);
522500

523-
cache.put(
524-
IterableCacheKey.INSTANCE,
525-
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
501+
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
526502
}
527503

528504
class CachingStateIterator implements PrefetchableIterator<T> {
@@ -604,7 +580,8 @@ public boolean hasNext() {
604580
return false;
605581
}
606582
// Release the block while we are loading the next one.
607-
currentBlock = Block.emptyBlock();
583+
currentBlock =
584+
Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY);
608585

609586
@Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
610587
boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);

0 commit comments

Comments
 (0)