Skip to content

Commit a420191

Browse files
committed
Add a utility class that shares state across all deserialized instances of a DoFn, and use it in UnboundedSourceAsSDFWrapperFn to cache UnboundedReaders across DoFn instances.
1 parent 9064743 commit a420191

File tree

4 files changed

+208
-24
lines changed

4 files changed

+208
-24
lines changed

sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
5757
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
5858
import org.apache.beam.sdk.util.NameUtils;
59+
import org.apache.beam.sdk.util.PerSerializationStatic;
5960
import org.apache.beam.sdk.util.SerializableUtils;
6061
import org.apache.beam.sdk.values.PBegin;
6162
import org.apache.beam.sdk.values.PCollection;
@@ -481,12 +482,31 @@ static class UnboundedSourceAsSDFWrapperFn<OutputT, CheckpointT extends Checkpoi
481482
private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceAsSDFWrapperFn.class);
482483
private static final int DEFAULT_BUNDLE_FINALIZATION_LIMIT_MINS = 10;
483484
private final Coder<CheckpointT> checkpointCoder;
484-
private @Nullable Cache<Object, UnboundedReader<OutputT>> cachedReaders;
485+
private final PerSerializationStatic<Cache<Object, UnboundedReader<OutputT>>> cachedReaders;
485486
private @Nullable Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder;
486487

487488
@VisibleForTesting
488489
UnboundedSourceAsSDFWrapperFn(Coder<CheckpointT> checkpointCoder) {
489490
this.checkpointCoder = checkpointCoder;
491+
cachedReaders =
492+
new PerSerializationStatic<>(
493+
() ->
494+
CacheBuilder.newBuilder()
495+
.expireAfterWrite(1, TimeUnit.MINUTES)
496+
.maximumSize(100)
497+
.removalListener(
498+
(RemovalListener<Object, UnboundedReader<OutputT>>)
499+
removalNotification -> {
500+
if (removalNotification.wasEvicted()) {
501+
try {
502+
Preconditions.checkNotNull(removalNotification.getValue())
503+
.close();
504+
} catch (IOException e) {
505+
LOG.warn("Failed to close UnboundedReader.", e);
506+
}
507+
}
508+
})
509+
.build());
490510
}
491511

492512
@GetInitialRestriction
@@ -498,22 +518,6 @@ public UnboundedSourceRestriction<OutputT, CheckpointT> initialRestriction(
498518
@Setup
499519
public void setUp() throws Exception {
500520
restrictionCoder = restrictionCoder();
501-
cachedReaders =
502-
CacheBuilder.newBuilder()
503-
.expireAfterWrite(1, TimeUnit.MINUTES)
504-
.maximumSize(100)
505-
.removalListener(
506-
(RemovalListener<Object, UnboundedReader<OutputT>>)
507-
removalNotification -> {
508-
if (removalNotification.wasEvicted()) {
509-
try {
510-
Preconditions.checkNotNull(removalNotification.getValue()).close();
511-
} catch (IOException e) {
512-
LOG.warn("Failed to close UnboundedReader.", e);
513-
}
514-
}
515-
})
516-
.build();
517521
}
518522

519523
@SplitRestriction
@@ -556,7 +560,8 @@ public void splitRestriction(
556560
PipelineOptions pipelineOptions) {
557561
Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder =
558562
checkStateNotNull(this.restrictionCoder);
559-
Cache<Object, UnboundedReader<OutputT>> cachedReaders = checkStateNotNull(this.cachedReaders);
563+
Cache<Object, UnboundedReader<OutputT>> cachedReaders =
564+
checkStateNotNull(this.cachedReaders.get());
560565
return new UnboundedSourceAsSDFRestrictionTracker<>(
561566
restriction, pipelineOptions, cachedReaders, restrictionCoder);
562567
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.util;
19+
20+
import java.io.Serializable;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
import javax.annotation.Nullable;
24+
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
25+
import org.checkerframework.checker.nullness.qual.NonNull;
26+
27+
/**
28+
* An object that simplifies having a variable that behaves like a static object but which is scoped
29+
* to deserialized instances.
30+
*
31+
* <p>In particular this can be useful for use within a DoFn class to maintain shared state across
32+
* all instances of the DoFn that are the same step in the graph. This differs from a static
33+
* variable which would be shared across all instances of the DoFn and a non-static variable which
34+
* is per instance.
35+
*/
36+
public class PerSerializationStatic<T> implements Serializable {
37+
private static final AtomicInteger idGenerator = new AtomicInteger();
38+
private final int id;
39+
40+
private static final ConcurrentHashMap<Integer, Object> staticCache = new ConcurrentHashMap<>();
41+
private final SerializableSupplier<@NonNull T> supplier;
42+
private transient volatile @MonotonicNonNull T value;
43+
44+
public PerSerializationStatic(SerializableSupplier<@NonNull T> supplier) {
45+
id = idGenerator.incrementAndGet();
46+
this.supplier = supplier;
47+
}
48+
49+
@SuppressWarnings("unchecked")
50+
public T get() {
51+
@Nullable T result = value;
52+
if (result != null) {
53+
return result;
54+
}
55+
@Nullable T mapValue = (T) staticCache.computeIfAbsent(id, ignored -> supplier.get());
56+
return value = Preconditions.checkStateNotNull(mapValue);
57+
}
58+
}

sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,7 @@ public void testUnboundedSdfWrapperCacheStartedReaders() {
179179
// read is default.
180180
ExperimentalOptions.addExperiment(
181181
pipeline.getOptions().as(ExperimentalOptions.class), "use_sdf_read");
182-
// Force the pipeline to run with one thread to ensure the reader will be reused on one DoFn
183-
// instance.
184-
// We are not able to use DirectOptions because of circular dependency.
185-
pipeline
186-
.runWithAdditionalOptionArgs(ImmutableList.of("--targetParallelism=1"))
187-
.waitUntilFinish();
182+
pipeline.run().waitUntilFinish();
188183
}
189184

190185
@Test
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.util;
19+
20+
import static org.junit.Assert.assertEquals;
21+
import static org.junit.Assert.assertNotSame;
22+
import static org.junit.Assert.assertSame;
23+
24+
import java.util.concurrent.ConcurrentHashMap;
25+
import java.util.concurrent.atomic.AtomicInteger;
26+
import org.junit.Test;
27+
import org.junit.runner.RunWith;
28+
import org.junit.runners.JUnit4;
29+
30+
/**
 * Tests for {@link PerSerializationStatic}, checking which holders return the same object from
 * {@code get()} and which return different objects.
 */
@RunWith(JUnit4.class)
31+
public class PerSerializationStaticTest {
32+
33+
// The value obtained from the original holder must be the very same object returned by two
// holders deserialized from the same serialized bytes, and must keep the state set on it.
@SuppressWarnings("unchecked")
34+
@Test
35+
public void testSharedAcrossDeserialize() throws Exception {
36+
PerSerializationStatic<AtomicInteger> instance =
37+
new PerSerializationStatic<>(AtomicInteger::new);
38+
SerializableUtils.ensureSerializable(instance);
39+
40+
AtomicInteger i = instance.get();
41+
i.set(10);
42+
// Repeated get() calls on the same holder return the same object.
assertSame(i, instance.get());
43+
44+
byte[] serialized = SerializableUtils.serializeToByteArray(instance);
45+
PerSerializationStatic<AtomicInteger> deserialized1 =
46+
(PerSerializationStatic<AtomicInteger>)
47+
SerializableUtils.deserializeFromByteArray(serialized, "instance");
48+
assertSame(i, deserialized1.get());
49+
50+
PerSerializationStatic<AtomicInteger> deserialized2 =
51+
(PerSerializationStatic<AtomicInteger>)
52+
SerializableUtils.deserializeFromByteArray(serialized, "instance");
53+
assertSame(i, deserialized2.get());
54+
assertEquals(10, i.get());
55+
}
56+
57+
// Two separately constructed holders must yield distinct values, and a clone of each holder
// must yield the value of the holder it was cloned from.
@Test
58+
public void testDifferentInstancesSeparate() throws Exception {
59+
PerSerializationStatic<AtomicInteger> instance =
60+
new PerSerializationStatic<>(AtomicInteger::new);
61+
SerializableUtils.ensureSerializable(instance);
62+
AtomicInteger i = instance.get();
63+
i.set(10);
64+
assertSame(i, instance.get());
65+
66+
PerSerializationStatic<AtomicInteger> instance2 =
67+
new PerSerializationStatic<>(AtomicInteger::new);
68+
SerializableUtils.ensureSerializable(instance2);
69+
AtomicInteger j = instance2.get();
70+
j.set(20);
71+
assertSame(j, instance2.get());
72+
assertNotSame(j, i);
73+
74+
PerSerializationStatic<AtomicInteger> instance1clone = SerializableUtils.clone(instance);
75+
assertSame(instance1clone.get(), i);
76+
PerSerializationStatic<AtomicInteger> instance2clone = SerializableUtils.clone(instance2);
77+
assertSame(instance2clone.get(), j);
78+
}
79+
80+
// Same sharing and separation checks, but get() is never called on the original holders:
// the first get() happens only on the deserialized or cloned copies.
@SuppressWarnings("unchecked")
81+
@Test
82+
public void testDifferentInstancesSeparateNoGetBeforeSerialization() throws Exception {
83+
PerSerializationStatic<AtomicInteger> instance =
84+
new PerSerializationStatic<>(AtomicInteger::new);
85+
SerializableUtils.ensureSerializable(instance);
86+
87+
PerSerializationStatic<AtomicInteger> instance2 =
88+
new PerSerializationStatic<>(AtomicInteger::new);
89+
SerializableUtils.ensureSerializable(instance2);
90+
91+
byte[] serialized = SerializableUtils.serializeToByteArray(instance);
92+
PerSerializationStatic<AtomicInteger> deserialized1 =
93+
(PerSerializationStatic<AtomicInteger>)
94+
SerializableUtils.deserializeFromByteArray(serialized, "instance");
95+
PerSerializationStatic<AtomicInteger> deserialized2 =
96+
(PerSerializationStatic<AtomicInteger>)
97+
SerializableUtils.deserializeFromByteArray(serialized, "instance");
98+
// Two copies of the same holder share one value.
assertSame(deserialized1.get(), deserialized2.get());
99+
100+
PerSerializationStatic<AtomicInteger> instance2clone = SerializableUtils.clone(instance2);
101+
// A copy of a different holder does not share that value.
assertNotSame(instance2clone.get(), deserialized1.get());
102+
}
103+
104+
// Holders with different value types (AtomicInteger and ConcurrentHashMap) must each hand
// their own value back to their clones.
@Test
105+
public void testDifferentTypes() throws Exception {
106+
PerSerializationStatic<AtomicInteger> instance =
107+
new PerSerializationStatic<>(AtomicInteger::new);
108+
SerializableUtils.ensureSerializable(instance);
109+
AtomicInteger i = instance.get();
110+
i.set(10);
111+
assertSame(i, instance.get());
112+
113+
PerSerializationStatic<ConcurrentHashMap<Integer, Integer>> instance2 =
114+
new PerSerializationStatic<>(ConcurrentHashMap::new);
115+
SerializableUtils.ensureSerializable(instance2);
116+
ConcurrentHashMap<Integer, Integer> j = instance2.get();
117+
j.put(1, 100);
118+
assertSame(j, instance2.get());
119+
120+
PerSerializationStatic<AtomicInteger> instance1clone = SerializableUtils.clone(instance);
121+
assertSame(instance1clone.get(), i);
122+
PerSerializationStatic<ConcurrentHashMap<Integer, Integer>> instance2clone =
123+
SerializableUtils.clone(instance2);
124+
assertSame(instance2clone.get(), j);
125+
}
126+
}

0 commit comments

Comments
 (0)