Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;

/** Facade for a {@link List<T>} that keeps track of weight, for cache limit reasons. */
public class WeightedList<T> implements Weighted {
Expand Down Expand Up @@ -72,6 +71,14 @@ public void addAll(List<T> values, long weight) {
}

public void accumulateWeight(long weight) {
this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
this.weight.accumulateAndGet(
weight,
(first, second) -> {
try {
return Math.addExact(first, second);
} catch (ArithmeticException e) {
return Long.MAX_VALUE;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;

/**
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
Expand Down Expand Up @@ -251,11 +249,15 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi

@Override
public long getWeight() {
long sum = 8 + blocks.size() * 8L;
for (Block<T> block : blocks) {
sum = LongMath.saturatedAdd(sum, block.getWeight());
try {
long sum = 8 + blocks.size() * 8L;
for (Block<T> block : blocks) {
sum = Math.addExact(sum, block.getWeight());
}
return sum;
} catch (ArithmeticException e) {
return Long.MAX_VALUE;
}
return sum;
}

BlocksPrefix(List<Block<T>> blocks) {
Expand All @@ -280,7 +282,8 @@ public List<Block<T>> getBlocks() {

@AutoValue
abstract static class Block<T> implements Weighted {
private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 0, null);
private static final Block<Void> EMPTY =
fromValues(WeightedList.of(Collections.emptyList(), 0), null);

@SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
public static <T> Block<T> emptyBlock() {
Expand All @@ -296,37 +299,21 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
}

public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
if (values.isEmpty() && nextToken == null) {
return emptyBlock();
}
ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
for (T value : immutableValues) {
listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
}
return fromValues(immutableValues, listWeight, nextToken);
return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken);
}

public static <T> Block<T> fromValues(
WeightedList<T> values, @Nullable ByteString nextToken) {
if (values.isEmpty() && nextToken == null) {
return emptyBlock();
}
return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken);
}

private static <T> Block<T> fromValues(
ImmutableList<T> values, long listWeight, @Nullable ByteString nextToken) {
long weight = LongMath.saturatedAdd(listWeight, 24);
long weight = values.getWeight() + 24;
if (nextToken != null) {
if (nextToken.isEmpty()) {
nextToken = ByteString.EMPTY;
} else {
weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
weight += Caches.weigh(nextToken);
}
}
return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
values, nextToken, weight);
values.getBacking(), nextToken, weight);
}

abstract List<T> getValues();
Expand Down Expand Up @@ -385,12 +372,10 @@ public void remove(Set<Object> toRemoveStructuralValues) {
totalSize += tBlock.getValues().size();
}

ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
long weight = 0;
List<T> blockValuesToKeep = new ArrayList<>();
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
for (Block<T> block : blocks) {
blockValuesToKeep.clear();
boolean valueRemovedFromBlock = false;
List<T> blockValuesToKeep = new ArrayList<>();
for (T value : block.getValues()) {
if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
blockValuesToKeep.add(value);
Expand All @@ -402,19 +387,13 @@ public void remove(Set<Object> toRemoveStructuralValues) {
// If any value was removed from this block, need to estimate the weight again.
// Otherwise, just reuse the block's weight.
if (valueRemovedFromBlock) {
allValues.addAll(blockValuesToKeep);
for (T value : blockValuesToKeep) {
weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
}
allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
} else {
allValues.addAll(block.getValues());
weight = LongMath.saturatedAdd(weight, block.getWeight());
allValues.addAll(block.getValues(), block.getWeight());
}
}

cache.put(
IterableCacheKey.INSTANCE,
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
}

/**
Expand Down Expand Up @@ -505,24 +484,21 @@ private void appendHelper(List<T> newValues, long newWeight) {
for (Block<T> block : blocks) {
totalSize += block.getValues().size();
}
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
long weight = 0;
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
for (Block<T> block : blocks) {
allValues.addAll(block.getValues());
weight = LongMath.saturatedAdd(weight, block.getWeight());
allValues.addAll(block.getValues(), block.getWeight());
}
if (newWeight < 0) {
newWeight = 0;
for (T value : newValues) {
newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
if (newValues.size() == 1) {
// Optimize weighing of the common value state as single single-element bag state.
newWeight = Caches.weigh(newValues.get(0));
} else {
newWeight = Caches.weigh(newValues);
}
}
allValues.addAll(newValues);
weight = LongMath.saturatedAdd(weight, newWeight);
allValues.addAll(newValues, newWeight);

cache.put(
IterableCacheKey.INSTANCE,
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
}

class CachingStateIterator implements PrefetchableIterator<T> {
Expand Down Expand Up @@ -604,7 +580,8 @@ public boolean hasNext() {
return false;
}
// Release the block while we are loading the next one.
currentBlock = Block.emptyBlock();
currentBlock =
Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY);

@Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);
Expand Down
Loading