Skip to content

Commit addc06e

Browse files
authored
[Dataflow Streaming] Reuse ByteStringOutputStream buffers in WindmillBag (#36742)
1 parent 83ebe73 commit addc06e

File tree

4 files changed

+192
-62
lines changed

4 files changed

+192
-62
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.runners.dataflow.worker.util;
19+
20+
import java.lang.ref.SoftReference;
21+
import javax.annotation.concurrent.ThreadSafe;
22+
import org.apache.beam.sdk.annotations.Internal;
23+
import org.apache.beam.sdk.util.ByteStringOutputStream;
24+
import org.apache.beam.sdk.util.Preconditions;
25+
import org.checkerframework.checker.nullness.qual.Nullable;
26+
27+
@Internal
28+
@ThreadSafe
29+
/*
30+
* A utility class for caching a thread-local {@link ByteStringOutputStream}.
31+
*
32+
* Example Usage:
33+
* try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
34+
* ByteStringOutputStream stream = streamHandle.stream();
35+
* stream.write(1);
36+
* ByteString byteString = stream.toByteStringAndReset();
37+
* }
38+
*/
39+
public class ThreadLocalByteStringOutputStream {
40+
41+
private static final ThreadLocal<@Nullable SoftRefHolder> threadLocalSoftRefHolder =
42+
ThreadLocal.withInitial(SoftRefHolder::new);
43+
44+
// Private constructor to prevent instantiations from outside.
45+
private ThreadLocalByteStringOutputStream() {}
46+
47+
/** @return An AutoClosable StreamHandle that holds a cached ByteStringOutputStream. */
48+
public static StreamHandle acquire() {
49+
StreamHandle streamHandle = getStreamHandleFromThreadLocal();
50+
if (streamHandle.inUse) {
51+
// Stream is already in use, create a new uncached one
52+
return new StreamHandle();
53+
}
54+
streamHandle.inUse = true;
55+
return streamHandle; // inUse will be unset when streamHandle closes.
56+
}
57+
58+
/**
59+
* Handle to a thread-local {@link ByteStringOutputStream}. If the thread local stream is already
60+
* in use, a new one is used. The streams are cached and reused across calls. Users should not
61+
* keep a reference to the stream after closing the StreamHandle.
62+
*/
63+
public static class StreamHandle implements AutoCloseable {
64+
65+
private final ByteStringOutputStream stream = new ByteStringOutputStream();
66+
67+
private boolean inUse = false;
68+
69+
/**
70+
* Returns the underlying cached ByteStringOutputStream. Callers should not keep a reference to
71+
* the stream after closing the StreamHandle.
72+
*/
73+
public ByteStringOutputStream stream() {
74+
return stream;
75+
}
76+
77+
@Override
78+
public void close() {
79+
stream.reset();
80+
inUse = false;
81+
}
82+
}
83+
84+
private static class SoftRefHolder {
85+
private @Nullable SoftReference<StreamHandle> softReference;
86+
}
87+
88+
private static StreamHandle getStreamHandleFromThreadLocal() {
89+
// softRefHolder is only set by Threadlocal initializer and should not be null
90+
SoftRefHolder softRefHolder =
91+
Preconditions.checkArgumentNotNull(threadLocalSoftRefHolder.get());
92+
@Nullable StreamHandle streamHandle = null;
93+
@Nullable SoftReference<StreamHandle> softReference = softRefHolder.softReference;
94+
if (softReference != null) {
95+
streamHandle = softReference.get();
96+
}
97+
if (streamHandle == null) {
98+
streamHandle = new StreamHandle();
99+
softRefHolder.softReference = new SoftReference<>(streamHandle);
100+
}
101+
return streamHandle;
102+
}
103+
}

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillBag.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.concurrent.ExecutionException;
2525
import java.util.concurrent.Future;
2626
import org.apache.beam.runners.core.StateNamespace;
27+
import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream;
28+
import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle;
2729
import org.apache.beam.runners.dataflow.worker.util.common.worker.InternedByteString;
2830
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
2931
import org.apache.beam.sdk.coders.Coder;
@@ -165,17 +167,20 @@ public Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyA
165167
if (bagUpdatesBuilder == null) {
166168
bagUpdatesBuilder = commitBuilder.addBagUpdatesBuilder();
167169
}
168-
for (T value : localAdditions) {
169-
ByteStringOutputStream stream = new ByteStringOutputStream();
170-
// Encode the value
171-
elemCoder.encode(value, stream, Coder.Context.OUTER);
172-
ByteString encoded = stream.toByteString();
173-
if (cachedValues != null) {
174-
// We'll capture this value in the cache below.
175-
// Capture the value's size now since we have it.
176-
encodedSize += encoded.size();
170+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
171+
ByteStringOutputStream stream = streamHandle.stream();
172+
for (T value : localAdditions) {
173+
elemCoder.encode(value, stream, Coder.Context.OUTER);
174+
ByteString encoded = stream.toByteStringAndReset();
175+
if (cachedValues != null) {
176+
// We'll capture this value in the cache below.
177+
// Capture the value's size now since we have it.
178+
encodedSize += encoded.size();
179+
}
180+
bagUpdatesBuilder.addValues(encoded);
177181
}
178-
bagUpdatesBuilder.addValues(encoded);
182+
} catch (IOException e) {
183+
throw new RuntimeException(e);
179184
}
180185
}
181186

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateTagUtil.java

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,23 @@
1818
package org.apache.beam.runners.dataflow.worker.windmill.state;
1919

2020
import java.io.IOException;
21-
import java.lang.ref.SoftReference;
2221
import javax.annotation.concurrent.ThreadSafe;
2322
import org.apache.beam.runners.core.StateNamespace;
2423
import org.apache.beam.runners.core.StateTag;
2524
import org.apache.beam.runners.core.TimerInternals.TimerData;
2625
import org.apache.beam.runners.dataflow.worker.WindmillNamespacePrefix;
26+
import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream;
27+
import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle;
2728
import org.apache.beam.runners.dataflow.worker.util.common.worker.InternedByteString;
2829
import org.apache.beam.sdk.annotations.Internal;
2930
import org.apache.beam.sdk.util.ByteStringOutputStream;
3031
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
3132
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
32-
import org.checkerframework.checker.nullness.qual.Nullable;
3333

3434
@Internal
3535
@ThreadSafe
3636
public class WindmillStateTagUtil {
3737

38-
private static final ThreadLocal<@Nullable RefHolder> threadLocalRefHolder = new ThreadLocal<>();
3938
private static final String TIMER_HOLD_PREFIX = "/h";
4039
private static final WindmillStateTagUtil INSTANCE = new WindmillStateTagUtil();
4140

@@ -48,21 +47,10 @@ private WindmillStateTagUtil() {}
4847
*/
4948
@VisibleForTesting
5049
InternedByteString encodeKey(StateNamespace namespace, StateTag<?> address) {
51-
RefHolder refHolder = getRefHolderFromThreadLocal();
52-
// Use ByteStringOutputStream rather than concatenation and String.format. We build these keys
53-
// a lot, and this leads to better performance results. See associated benchmarks.
54-
ByteStringOutputStream stream;
55-
boolean releaseThreadLocal;
56-
if (refHolder.inUse) {
57-
// If the thread local stream is already in use, create a new one
58-
stream = new ByteStringOutputStream();
59-
releaseThreadLocal = false;
60-
} else {
61-
stream = getByteStringOutputStream(refHolder);
62-
refHolder.inUse = true;
63-
releaseThreadLocal = true;
64-
}
65-
try {
50+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
51+
// Use ByteStringOutputStream rather than concatenation and String.format. We build these keys
52+
// a lot, and this leads to better performance results. See associated benchmarks.
53+
ByteStringOutputStream stream = streamHandle.stream();
6654
// stringKey starts and ends with a slash. We separate it from the
6755
// StateTag ID by a '+' (which is guaranteed not to be in the stringKey) because the
6856
// ID comes from the user.
@@ -72,11 +60,6 @@ InternedByteString encodeKey(StateNamespace namespace, StateTag<?> address) {
7260
return InternedByteString.of(stream.toByteStringAndReset());
7361
} catch (IOException e) {
7462
throw new RuntimeException(e);
75-
} finally {
76-
stream.reset();
77-
if (releaseThreadLocal) {
78-
refHolder.inUse = false;
79-
}
8063
}
8164
}
8265

@@ -116,35 +99,6 @@ public ByteString timerHoldTag(WindmillNamespacePrefix prefix, TimerData timerDa
11699
return ByteString.copyFromUtf8(tagString);
117100
}
118101

119-
private static class RefHolder {
120-
121-
public SoftReference<@Nullable ByteStringOutputStream> streamRef =
122-
new SoftReference<>(new ByteStringOutputStream());
123-
124-
// Boolean is true when the thread local stream is already in use by the current thread.
125-
// Used to avoid reusing the same stream from nested calls if any.
126-
public boolean inUse = false;
127-
}
128-
129-
private static RefHolder getRefHolderFromThreadLocal() {
130-
@Nullable RefHolder refHolder = threadLocalRefHolder.get();
131-
if (refHolder == null) {
132-
refHolder = new RefHolder();
133-
threadLocalRefHolder.set(refHolder);
134-
}
135-
return refHolder;
136-
}
137-
138-
private static ByteStringOutputStream getByteStringOutputStream(RefHolder refHolder) {
139-
@Nullable
140-
ByteStringOutputStream stream = refHolder.streamRef == null ? null : refHolder.streamRef.get();
141-
if (stream == null) {
142-
stream = new ByteStringOutputStream();
143-
refHolder.streamRef = new SoftReference<>(stream);
144-
}
145-
return stream;
146-
}
147-
148102
/** @return the singleton WindmillStateTagUtil */
149103
public static WindmillStateTagUtil instance() {
150104
return INSTANCE;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.runners.dataflow.worker.util;
19+
20+
import static org.junit.Assert.*;
21+
22+
import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle;
23+
import org.apache.beam.sdk.util.ByteStringOutputStream;
24+
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
25+
import org.junit.Test;
26+
27+
public class ThreadLocalByteStringOutputStreamTest {
28+
29+
@Test
30+
public void simple() {
31+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
32+
ByteStringOutputStream stream = streamHandle.stream();
33+
stream.write(1);
34+
stream.write(2);
35+
stream.write(3);
36+
assertEquals(ByteString.copyFrom(new byte[] {1, 2, 3}), stream.toByteStringAndReset());
37+
}
38+
}
39+
40+
@Test
41+
public void nested() {
42+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
43+
ByteStringOutputStream stream = streamHandle.stream();
44+
stream.write(1);
45+
try (StreamHandle streamHandle1 = ThreadLocalByteStringOutputStream.acquire()) {
46+
ByteStringOutputStream stream1 = streamHandle1.stream();
47+
stream1.write(2);
48+
assertEquals(ByteString.copyFrom(new byte[] {2}), stream1.toByteStringAndReset());
49+
}
50+
stream.write(3);
51+
assertEquals(ByteString.copyFrom(new byte[] {1, 3}), stream.toByteStringAndReset());
52+
}
53+
}
54+
55+
@Test
56+
public void resetDirtyStream() {
57+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
58+
ByteStringOutputStream stream = streamHandle.stream();
59+
stream.write(1);
60+
// Don't read/reset stream
61+
}
62+
63+
try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) {
64+
ByteStringOutputStream stream = streamHandle.stream();
65+
assertEquals(ByteString.EMPTY, stream.toByteStringAndReset());
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)