diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java index 7c4b8d492e10..ef8d69bc1ec3 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java @@ -33,7 +33,6 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.Supplier; import org.apache.beam.fn.harness.Cache; -import org.apache.beam.fn.harness.Caches; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.sdk.coders.IterableLikeCoder; @@ -90,8 +89,7 @@ public StateBackedIterable( StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); this.prefix = prefix; this.suffix = - StateFetchingIterators.readAllAndDecodeStartingFrom( - Caches.subCache(cache, stateKey), beamFnStateClient, request, elemCoder); + StateFetchingIterators.readAllAndDecodeStartingFrom(beamFnStateClient, request, elemCoder); this.elemCoder = elemCoder; } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index c34a4a84edb8..e4144bcbb353 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -20,6 +20,8 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -46,6 +48,7 @@ import 
org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; /** * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn @@ -91,6 +94,95 @@ public static CachingStateIterable readAllAndDecodeStartingFrom( valueCoder); } + /** + * This adapter handles using the continuation token to provide iteration over all the elements + * returned by the Beam Fn State API using the supplied state client, state request for the first + * chunk of the state stream, and a value decoder, without caching support. + * + * @param beamFnStateClient A client for handling state requests. + * @param stateRequestForFirstChunk A fully populated state request for the first (and possibly + * only) chunk of a state stream. This state request will be populated with a continuation + * token to request further chunks of the stream if required. + * @param valueCoder A coder for decoding the state stream. 
+ */ + public static UncachedStateIterable readAllAndDecodeStartingFrom( + BeamFnStateClient beamFnStateClient, + StateRequest stateRequestForFirstChunk, + Coder valueCoder) { + return new UncachedStateIterable<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder); + } + + /** Iterable that keeps only the state client, the first-chunk request and the value coder; each {@code createIterator()} call builds a new fetching and decoding iterator from them. */ private static class UncachedStateIterable extends PrefetchableIterables.Default { + private final BeamFnStateClient beamFnStateClient; + private final StateRequest stateRequestForFirstChunk; + private final Coder valueCoder; + + public UncachedStateIterable( + BeamFnStateClient beamFnStateClient, + StateRequest stateRequestForFirstChunk, + Coder valueCoder) { + this.beamFnStateClient = beamFnStateClient; + this.stateRequestForFirstChunk = stateRequestForFirstChunk; + this.valueCoder = valueCoder; + } + + @Override + public PrefetchableIterator createIterator() { + return new DecodingIterator( + new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk), + valueCoder); + } + + /** Decodes values with {@code valueCoder} from the chunks supplied by {@code chunkIterator}, moving to the next chunk once the current one has no bytes available. */ private static class DecodingIterator extends AbstractIterator + implements PrefetchableIterator { + private final PrefetchableIterator chunkIterator; + private final Coder valueCoder; + private InputStream currentChunk; + + public DecodingIterator(PrefetchableIterator chunkIterator, Coder valueCoder) { + this.chunkIterator = chunkIterator; + this.valueCoder = valueCoder; + /* Start on an empty stream so the first computeNext() call has to pull a chunk. */ this.currentChunk = ByteString.EMPTY.newInput(); + } + + @Override + protected T computeNext() { + try { + /* A loop, not a single check, so chunks with no available bytes are skipped. */ while (currentChunk.available() == 0) { + if (chunkIterator.hasNext()) { + currentChunk = chunkIterator.next().newInput(); + } else { + return endOfData(); + } + } + return valueCoder.decode(currentChunk); + } catch (IOException exn) { + // Should never get here as ByteString.newInput() returns InputStreams + // that don't do actual IO operations. 
+ throw new IllegalStateException(exn); + } + } + + /** Ready when the current chunk still has bytes available or the chunk iterator reports ready. */ @Override + public boolean isReady() { + try { + return currentChunk.available() > 0 || chunkIterator.isReady(); + } catch (IOException exn) { + // Should never get here as ByteString.newInput() returns InputStreams + // that don't do actual IO operations. + throw new IllegalStateException(exn); + } + } + + /** Delegates to the chunk iterator's prefetch only when {@code isReady()} is false. */ @Override + public void prefetch() { + if (!isReady()) { + chunkIterator.prefetch(); + } + } + } + } + @VisibleForTesting static class IterableCacheKey implements Weighted { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java index 0e2598cf0784..9b1d51748e68 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java @@ -32,7 +32,6 @@ import java.util.Iterator; import java.util.List; import java.util.Random; -import org.apache.beam.fn.harness.Cache; import org.apache.beam.fn.harness.Caches; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -111,75 +110,6 @@ public void testReiteration() throws Exception { assertEquals(expected, Lists.newArrayList(iterable)); } - @Test - public void testReiterationCached() throws Exception { - FakeBeamFnStateClient fakeBeamFnStateClient = - new FakeBeamFnStateClient( - StringUtf8Coder.of(), - ImmutableMap.of( - key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"), - key("emptySuffix"), asList())); - - StateBackedIterable iterable = - new StateBackedIterable<>( - Caches.eternal(), - fakeBeamFnStateClient, - "instruction", - key(suffixKey), - StringUtf8Coder.of(), - prefix); - - // Ensure that the load is lazy - assertEquals(0, fakeBeamFnStateClient.getCallCount()); - assertEquals(expected, 
Lists.newArrayList(iterable)); - // We expect future reiterations to not perform any loads - int callCount = fakeBeamFnStateClient.getCallCount(); - assertEquals(expected, Lists.newArrayList(iterable)); - assertEquals(expected, Lists.newArrayList(iterable)); - assertEquals(callCount, fakeBeamFnStateClient.getCallCount()); - } - - @Test - public void testCacheKeyIsUnique() throws Exception { - // Share a cache for multiple iterables leads to distinct keys being used. - Cache cache = Caches.eternal(); - FakeBeamFnStateClient fakeBeamFnStateClient = - new FakeBeamFnStateClient( - StringUtf8Coder.of(), - ImmutableMap.of( - key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"), - key("emptySuffix"), asList(), - key("otherIterable"), asList("Z"))); - - StateBackedIterable otherIterable = - new StateBackedIterable<>( - cache, - fakeBeamFnStateClient, - "instruction", - key("otherIterable"), - StringUtf8Coder.of(), - Collections.emptyList()); - // Ensure that the load is lazy - assertEquals(0, fakeBeamFnStateClient.getCallCount()); - assertEquals(asList("Z"), Lists.newArrayList(otherIterable)); - - StateBackedIterable iterable = - new StateBackedIterable<>( - cache, - fakeBeamFnStateClient, - "instruction", - key(suffixKey), - StringUtf8Coder.of(), - prefix); - - assertEquals(expected, Lists.newArrayList(iterable)); - // We expect future reiterations to not perform any loads - int callCount = fakeBeamFnStateClient.getCallCount(); - assertEquals(expected, Lists.newArrayList(iterable)); - assertEquals(expected, Lists.newArrayList(iterable)); - assertEquals(callCount, fakeBeamFnStateClient.getCallCount()); - } - @Test public void testUsingInterleavedReiteration() throws Exception { FakeBeamFnStateClient fakeBeamFnStateClient =