Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.IterableLikeCoder;
Expand Down Expand Up @@ -90,8 +89,7 @@ public StateBackedIterable(
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
this.prefix = prefix;
this.suffix =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, stateKey), beamFnStateClient, request, elemCoder);
StateFetchingIterators.readAllAndDecodeStartingFrom(beamFnStateClient, request, elemCoder);
this.elemCoder = elemCoder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
Expand All @@ -46,6 +48,7 @@
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;

/**
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
Expand Down Expand Up @@ -91,6 +94,95 @@ public static <T> CachingStateIterable<T> readAllAndDecodeStartingFrom(
valueCoder);
}

/**
* This adapter handles using the continuation token to provide iteration over all the elements
* returned by the Beam Fn State API using the supplied state client, state request for the first
* chunk of the state stream, and a value decoder, without caching support.
*
* @param beamFnStateClient A client for handling state requests.
* @param stateRequestForFirstChunk A fully populated state request for the first (and possibly
* only) chunk of a state stream. This state request will be populated with a continuation
* token to request further chunks of the stream if required.
* @param valueCoder A coder for decoding the state stream.
*/
public static <T> UncachedStateIterable<T> readAllAndDecodeStartingFrom(
BeamFnStateClient beamFnStateClient,
StateRequest stateRequestForFirstChunk,
Coder<T> valueCoder) {
return new UncachedStateIterable<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
}

private static class UncachedStateIterable<T> extends PrefetchableIterables.Default<T> {
private final BeamFnStateClient beamFnStateClient;
private final StateRequest stateRequestForFirstChunk;
private final Coder<T> valueCoder;

public UncachedStateIterable(
BeamFnStateClient beamFnStateClient,
StateRequest stateRequestForFirstChunk,
Coder<T> valueCoder) {
this.beamFnStateClient = beamFnStateClient;
this.stateRequestForFirstChunk = stateRequestForFirstChunk;
this.valueCoder = valueCoder;
}

@Override
public PrefetchableIterator<T> createIterator() {
return new DecodingIterator<T>(
new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk),
valueCoder);
}

private static class DecodingIterator<T> extends AbstractIterator<T>
implements PrefetchableIterator<T> {
private final PrefetchableIterator<ByteString> chunkIterator;
private final Coder<T> valueCoder;
private InputStream currentChunk;

public DecodingIterator(PrefetchableIterator<ByteString> chunkIterator, Coder<T> valueCoder) {
this.chunkIterator = chunkIterator;
this.valueCoder = valueCoder;
this.currentChunk = ByteString.EMPTY.newInput();
}

@Override
protected T computeNext() {
try {
while (currentChunk.available() == 0) {
if (chunkIterator.hasNext()) {
currentChunk = chunkIterator.next().newInput();
} else {
return endOfData();
}
}
return valueCoder.decode(currentChunk);
} catch (IOException exn) {
// Should never get here as ByteString.newInput() returns InputStreams
// that don't do actual IO operations.
throw new IllegalStateException(exn);
}
}

@Override
public boolean isReady() {
try {
return currentChunk.available() > 0 || chunkIterator.isReady();
} catch (IOException exn) {
// Should never get here as ByteString.newInput() returns InputStreams
// that don't do actual IO operations.
throw new IllegalStateException(exn);
}
}

@Override
public void prefetch() {
if (!isReady()) {
chunkIterator.prefetch();
}
}
}
}

@VisibleForTesting
static class IterableCacheKey implements Weighted {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.StringUtf8Coder;
Expand Down Expand Up @@ -111,75 +110,6 @@ public void testReiteration() throws Exception {
assertEquals(expected, Lists.newArrayList(iterable));
}

@Test
public void testReiterationCached() throws Exception {
FakeBeamFnStateClient fakeBeamFnStateClient =
new FakeBeamFnStateClient(
StringUtf8Coder.of(),
ImmutableMap.of(
key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"),
key("emptySuffix"), asList()));

StateBackedIterable<String> iterable =
new StateBackedIterable<>(
Caches.eternal(),
fakeBeamFnStateClient,
"instruction",
key(suffixKey),
StringUtf8Coder.of(),
prefix);

// Ensure that the load is lazy
assertEquals(0, fakeBeamFnStateClient.getCallCount());
assertEquals(expected, Lists.newArrayList(iterable));
// We expect future reiterations to not perform any loads
int callCount = fakeBeamFnStateClient.getCallCount();
assertEquals(expected, Lists.newArrayList(iterable));
assertEquals(expected, Lists.newArrayList(iterable));
assertEquals(callCount, fakeBeamFnStateClient.getCallCount());
}

@Test
public void testCacheKeyIsUnique() throws Exception {
// Share a cache for multiple iterables leads to distinct keys being used.
Cache cache = Caches.eternal();
FakeBeamFnStateClient fakeBeamFnStateClient =
new FakeBeamFnStateClient(
StringUtf8Coder.of(),
ImmutableMap.of(
key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"),
key("emptySuffix"), asList(),
key("otherIterable"), asList("Z")));

StateBackedIterable<String> otherIterable =
new StateBackedIterable<>(
cache,
fakeBeamFnStateClient,
"instruction",
key("otherIterable"),
StringUtf8Coder.of(),
Collections.emptyList());
// Ensure that the load is lazy
assertEquals(0, fakeBeamFnStateClient.getCallCount());
assertEquals(asList("Z"), Lists.newArrayList(otherIterable));

StateBackedIterable<String> iterable =
new StateBackedIterable<>(
cache,
fakeBeamFnStateClient,
"instruction",
key(suffixKey),
StringUtf8Coder.of(),
prefix);

assertEquals(expected, Lists.newArrayList(iterable));
// We expect future reiterations to not perform any loads
int callCount = fakeBeamFnStateClient.getCallCount();
assertEquals(expected, Lists.newArrayList(iterable));
assertEquals(expected, Lists.newArrayList(iterable));
assertEquals(callCount, fakeBeamFnStateClient.getCallCount());
}

@Test
public void testUsingInterleavedReiteration() throws Exception {
FakeBeamFnStateClient fakeBeamFnStateClient =
Expand Down
Loading