Skip to content

Commit 5b862dd

Browse files
authored
Merge pull request #34746 Streamline non-cached state backed iterable.
2 parents c72b351 + 877f111 commit 5b862dd

File tree

3 files changed

+93
-73
lines changed

3 files changed

+93
-73
lines changed

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import java.util.concurrent.ThreadLocalRandom;
3434
import java.util.function.Supplier;
3535
import org.apache.beam.fn.harness.Cache;
36-
import org.apache.beam.fn.harness.Caches;
3736
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
3837
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
3938
import org.apache.beam.sdk.coders.IterableLikeCoder;
@@ -90,8 +89,7 @@ public StateBackedIterable(
9089
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
9190
this.prefix = prefix;
9291
this.suffix =
93-
StateFetchingIterators.readAllAndDecodeStartingFrom(
94-
Caches.subCache(cache, stateKey), beamFnStateClient, request, elemCoder);
92+
StateFetchingIterators.readAllAndDecodeStartingFrom(beamFnStateClient, request, elemCoder);
9593
this.elemCoder = elemCoder;
9694
}
9795

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
2121

2222
import com.google.auto.value.AutoValue;
23+
import java.io.IOException;
24+
import java.io.InputStream;
2325
import java.util.ArrayList;
2426
import java.util.Collections;
2527
import java.util.Iterator;
@@ -46,6 +48,7 @@
4648
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
4749
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
4850
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
51+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
4952

5053
/**
5154
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -91,6 +94,95 @@ public static <T> CachingStateIterable<T> readAllAndDecodeStartingFrom(
9194
valueCoder);
9295
}
9396

97+
/**
98+
* This adapter handles using the continuation token to provide iteration over all the elements
99+
* returned by the Beam Fn State API using the supplied state client, state request for the first
100+
* chunk of the state stream, and a value decoder, without caching support.
101+
*
102+
* @param beamFnStateClient A client for handling state requests.
103+
* @param stateRequestForFirstChunk A fully populated state request for the first (and possibly
104+
* only) chunk of a state stream. This state request will be populated with a continuation
105+
* token to request further chunks of the stream if required.
106+
* @param valueCoder A coder for decoding the state stream.
107+
*/
108+
public static <T> UncachedStateIterable<T> readAllAndDecodeStartingFrom(
109+
BeamFnStateClient beamFnStateClient,
110+
StateRequest stateRequestForFirstChunk,
111+
Coder<T> valueCoder) {
112+
return new UncachedStateIterable<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
113+
}
114+
115+
private static class UncachedStateIterable<T> extends PrefetchableIterables.Default<T> {
116+
private final BeamFnStateClient beamFnStateClient;
117+
private final StateRequest stateRequestForFirstChunk;
118+
private final Coder<T> valueCoder;
119+
120+
public UncachedStateIterable(
121+
BeamFnStateClient beamFnStateClient,
122+
StateRequest stateRequestForFirstChunk,
123+
Coder<T> valueCoder) {
124+
this.beamFnStateClient = beamFnStateClient;
125+
this.stateRequestForFirstChunk = stateRequestForFirstChunk;
126+
this.valueCoder = valueCoder;
127+
}
128+
129+
@Override
130+
public PrefetchableIterator<T> createIterator() {
131+
return new DecodingIterator<T>(
132+
new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk),
133+
valueCoder);
134+
}
135+
136+
private static class DecodingIterator<T> extends AbstractIterator<T>
137+
implements PrefetchableIterator<T> {
138+
private final PrefetchableIterator<ByteString> chunkIterator;
139+
private final Coder<T> valueCoder;
140+
private InputStream currentChunk;
141+
142+
public DecodingIterator(PrefetchableIterator<ByteString> chunkIterator, Coder<T> valueCoder) {
143+
this.chunkIterator = chunkIterator;
144+
this.valueCoder = valueCoder;
145+
this.currentChunk = ByteString.EMPTY.newInput();
146+
}
147+
148+
@Override
149+
protected T computeNext() {
150+
try {
151+
while (currentChunk.available() == 0) {
152+
if (chunkIterator.hasNext()) {
153+
currentChunk = chunkIterator.next().newInput();
154+
} else {
155+
return endOfData();
156+
}
157+
}
158+
return valueCoder.decode(currentChunk);
159+
} catch (IOException exn) {
160+
// Should never get here as ByteString.newInput() returns InputStreams
161+
// that don't do actual IO operations.
162+
throw new IllegalStateException(exn);
163+
}
164+
}
165+
166+
@Override
167+
public boolean isReady() {
168+
try {
169+
return currentChunk.available() > 0 || chunkIterator.isReady();
170+
} catch (IOException exn) {
171+
// Should never get here as ByteString.newInput() returns InputStreams
172+
// that don't do actual IO operations.
173+
throw new IllegalStateException(exn);
174+
}
175+
}
176+
177+
@Override
178+
public void prefetch() {
179+
if (!isReady()) {
180+
chunkIterator.prefetch();
181+
}
182+
}
183+
}
184+
}
185+
94186
@VisibleForTesting
95187
static class IterableCacheKey implements Weighted {
96188

sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.util.Iterator;
3333
import java.util.List;
3434
import java.util.Random;
35-
import org.apache.beam.fn.harness.Cache;
3635
import org.apache.beam.fn.harness.Caches;
3736
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
3837
import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -111,75 +110,6 @@ public void testReiteration() throws Exception {
111110
assertEquals(expected, Lists.newArrayList(iterable));
112111
}
113112

114-
@Test
115-
public void testReiterationCached() throws Exception {
116-
FakeBeamFnStateClient fakeBeamFnStateClient =
117-
new FakeBeamFnStateClient(
118-
StringUtf8Coder.of(),
119-
ImmutableMap.of(
120-
key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"),
121-
key("emptySuffix"), asList()));
122-
123-
StateBackedIterable<String> iterable =
124-
new StateBackedIterable<>(
125-
Caches.eternal(),
126-
fakeBeamFnStateClient,
127-
"instruction",
128-
key(suffixKey),
129-
StringUtf8Coder.of(),
130-
prefix);
131-
132-
// Ensure that the load is lazy
133-
assertEquals(0, fakeBeamFnStateClient.getCallCount());
134-
assertEquals(expected, Lists.newArrayList(iterable));
135-
// We expect future reiterations to not perform any loads
136-
int callCount = fakeBeamFnStateClient.getCallCount();
137-
assertEquals(expected, Lists.newArrayList(iterable));
138-
assertEquals(expected, Lists.newArrayList(iterable));
139-
assertEquals(callCount, fakeBeamFnStateClient.getCallCount());
140-
}
141-
142-
@Test
143-
public void testCacheKeyIsUnique() throws Exception {
144-
// Share a cache for multiple iterables leads to distinct keys being used.
145-
Cache cache = Caches.eternal();
146-
FakeBeamFnStateClient fakeBeamFnStateClient =
147-
new FakeBeamFnStateClient(
148-
StringUtf8Coder.of(),
149-
ImmutableMap.of(
150-
key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"),
151-
key("emptySuffix"), asList(),
152-
key("otherIterable"), asList("Z")));
153-
154-
StateBackedIterable<String> otherIterable =
155-
new StateBackedIterable<>(
156-
cache,
157-
fakeBeamFnStateClient,
158-
"instruction",
159-
key("otherIterable"),
160-
StringUtf8Coder.of(),
161-
Collections.emptyList());
162-
// Ensure that the load is lazy
163-
assertEquals(0, fakeBeamFnStateClient.getCallCount());
164-
assertEquals(asList("Z"), Lists.newArrayList(otherIterable));
165-
166-
StateBackedIterable<String> iterable =
167-
new StateBackedIterable<>(
168-
cache,
169-
fakeBeamFnStateClient,
170-
"instruction",
171-
key(suffixKey),
172-
StringUtf8Coder.of(),
173-
prefix);
174-
175-
assertEquals(expected, Lists.newArrayList(iterable));
176-
// We expect future reiterations to not perform any loads
177-
int callCount = fakeBeamFnStateClient.getCallCount();
178-
assertEquals(expected, Lists.newArrayList(iterable));
179-
assertEquals(expected, Lists.newArrayList(iterable));
180-
assertEquals(callCount, fakeBeamFnStateClient.getCallCount());
181-
}
182-
183113
@Test
184114
public void testUsingInterleavedReiteration() throws Exception {
185115
FakeBeamFnStateClient fakeBeamFnStateClient =

0 commit comments

Comments
 (0)