diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStream.java new file mode 100644 index 000000000000..8e33be639e43 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStream.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.util; + +import java.lang.ref.SoftReference; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.Preconditions; +import org.checkerframework.checker.nullness.qual.Nullable; + +@Internal +@ThreadSafe +/* + * A utility class for caching a thread-local {@link ByteStringOutputStream}. + * + * Example Usage: + * try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + * ByteStringOutputStream stream = streamHandle.stream(); + * stream.write(1); + * ByteString byteString = stream.toByteStringAndReset(); + * } + */ +public class ThreadLocalByteStringOutputStream { + + private static final ThreadLocal<@Nullable SoftRefHolder> threadLocalSoftRefHolder = + ThreadLocal.withInitial(SoftRefHolder::new); + + // Private constructor to prevent instantiations from outside. + private ThreadLocalByteStringOutputStream() {} + + /** @return An AutoClosable StreamHandle that holds a cached ByteStringOutputStream. */ + public static StreamHandle acquire() { + StreamHandle streamHandle = getStreamHandleFromThreadLocal(); + if (streamHandle.inUse) { + // Stream is already in use, create a new uncached one + return new StreamHandle(); + } + streamHandle.inUse = true; + return streamHandle; // inUse will be unset when streamHandle closes. + } + + /** + * Handle to a thread-local {@link ByteStringOutputStream}. If the thread local stream is already + * in use, a new one is used. The streams are cached and reused across calls. Users should not + * keep a reference to the stream after closing the StreamHandle. + */ + public static class StreamHandle implements AutoCloseable { + + private final ByteStringOutputStream stream = new ByteStringOutputStream(); + + private boolean inUse = false; + + /** + * Returns the underlying cached ByteStringOutputStream. Callers should not keep a reference to + * the stream after closing the StreamHandle. + */ + public ByteStringOutputStream stream() { + return stream; + } + + @Override + public void close() { + stream.reset(); + inUse = false; + } + } + + private static class SoftRefHolder { + private @Nullable SoftReference softReference; + } + + private static StreamHandle getStreamHandleFromThreadLocal() { + // softRefHolder is only set by Threadlocal initializer and should not be null + SoftRefHolder softRefHolder = + Preconditions.checkArgumentNotNull(threadLocalSoftRefHolder.get()); + @Nullable StreamHandle streamHandle = null; + @Nullable SoftReference softReference = softRefHolder.softReference; + if (softReference != null) { + streamHandle = softReference.get(); + } + if (streamHandle == null) { + streamHandle = new StreamHandle(); + softRefHolder.softReference = new SoftReference<>(streamHandle); + } + return streamHandle; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillBag.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillBag.java index b15064ff81e0..db1f3e7a6dec 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillBag.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillBag.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream; +import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle; import org.apache.beam.runners.dataflow.worker.util.common.worker.InternedByteString; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.Coder; @@ -165,17 +167,20 @@ public Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyA if (bagUpdatesBuilder == null) { bagUpdatesBuilder = commitBuilder.addBagUpdatesBuilder(); } - for (T value : localAdditions) { - ByteStringOutputStream stream = new ByteStringOutputStream(); - // Encode the value - elemCoder.encode(value, stream, Coder.Context.OUTER); - ByteString encoded = stream.toByteString(); - if (cachedValues != null) { - // We'll capture this value in the cache below. - // Capture the value's size now since we have it. - encodedSize += encoded.size(); + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream = streamHandle.stream(); + for (T value : localAdditions) { + elemCoder.encode(value, stream, Coder.Context.OUTER); + ByteString encoded = stream.toByteStringAndReset(); + if (cachedValues != null) { + // We'll capture this value in the cache below. + // Capture the value's size now since we have it. + encodedSize += encoded.size(); + } + bagUpdatesBuilder.addValues(encoded); } - bagUpdatesBuilder.addValues(encoded); + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateTagUtil.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateTagUtil.java index dbb5f57f8a52..12b4001d530f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateTagUtil.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateTagUtil.java @@ -18,24 +18,23 @@ package org.apache.beam.runners.dataflow.worker.windmill.state; import java.io.IOException; -import java.lang.ref.SoftReference; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.TimerInternals.TimerData; import org.apache.beam.runners.dataflow.worker.WindmillNamespacePrefix; +import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream; +import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle; import org.apache.beam.runners.dataflow.worker.util.common.worker.InternedByteString; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.checkerframework.checker.nullness.qual.Nullable; @Internal @ThreadSafe public class WindmillStateTagUtil { - private static final ThreadLocal<@Nullable RefHolder> threadLocalRefHolder = new ThreadLocal<>(); private static final String TIMER_HOLD_PREFIX = "/h"; private static final WindmillStateTagUtil INSTANCE = new WindmillStateTagUtil(); @@ -48,21 +47,10 @@ private WindmillStateTagUtil() {} */ @VisibleForTesting InternedByteString encodeKey(StateNamespace namespace, StateTag address) { - RefHolder refHolder = getRefHolderFromThreadLocal(); - // Use ByteStringOutputStream rather than concatenation and String.format. We build these keys - // a lot, and this leads to better performance results. See associated benchmarks. - ByteStringOutputStream stream; - boolean releaseThreadLocal; - if (refHolder.inUse) { - // If the thread local stream is already in use, create a new one - stream = new ByteStringOutputStream(); - releaseThreadLocal = false; - } else { - stream = getByteStringOutputStream(refHolder); - refHolder.inUse = true; - releaseThreadLocal = true; - } - try { + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + // Use ByteStringOutputStream rather than concatenation and String.format. We build these keys + // a lot, and this leads to better performance results. See associated benchmarks. + ByteStringOutputStream stream = streamHandle.stream(); // stringKey starts and ends with a slash. We separate it from the // StateTag ID by a '+' (which is guaranteed not to be in the stringKey) because the // ID comes from the user. @@ -72,11 +60,6 @@ InternedByteString encodeKey(StateNamespace namespace, StateTag address) { return InternedByteString.of(stream.toByteStringAndReset()); } catch (IOException e) { throw new RuntimeException(e); - } finally { - stream.reset(); - if (releaseThreadLocal) { - refHolder.inUse = false; - } } } @@ -116,35 +99,6 @@ public ByteString timerHoldTag(WindmillNamespacePrefix prefix, TimerData timerDa return ByteString.copyFromUtf8(tagString); } - private static class RefHolder { - - public SoftReference<@Nullable ByteStringOutputStream> streamRef = - new SoftReference<>(new ByteStringOutputStream()); - - // Boolean is true when the thread local stream is already in use by the current thread. - // Used to avoid reusing the same stream from nested calls if any. - public boolean inUse = false; - } - - private static RefHolder getRefHolderFromThreadLocal() { - @Nullable RefHolder refHolder = threadLocalRefHolder.get(); - if (refHolder == null) { - refHolder = new RefHolder(); - threadLocalRefHolder.set(refHolder); - } - return refHolder; - } - - private static ByteStringOutputStream getByteStringOutputStream(RefHolder refHolder) { - @Nullable - ByteStringOutputStream stream = refHolder.streamRef == null ? null : refHolder.streamRef.get(); - if (stream == null) { - stream = new ByteStringOutputStream(); - refHolder.streamRef = new SoftReference<>(stream); - } - return stream; - } - /** @return the singleton WindmillStateTagUtil */ public static WindmillStateTagUtil instance() { return INSTANCE; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStreamTest.java new file mode 100644 index 000000000000..ef167203a96f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/ThreadLocalByteStringOutputStreamTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.junit.Assert.*; + +import org.apache.beam.runners.dataflow.worker.util.ThreadLocalByteStringOutputStream.StreamHandle; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.junit.Test; + +public class ThreadLocalByteStringOutputStreamTest { + + @Test + public void simple() { + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream = streamHandle.stream(); + stream.write(1); + stream.write(2); + stream.write(3); + assertEquals(ByteString.copyFrom(new byte[] {1, 2, 3}), stream.toByteStringAndReset()); + } + } + + @Test + public void nested() { + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream = streamHandle.stream(); + stream.write(1); + try (StreamHandle streamHandle1 = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream1 = streamHandle1.stream(); + stream1.write(2); + assertEquals(ByteString.copyFrom(new byte[] {2}), stream1.toByteStringAndReset()); + } + stream.write(3); + assertEquals(ByteString.copyFrom(new byte[] {1, 3}), stream.toByteStringAndReset()); + } + } + + @Test + public void resetDirtyStream() { + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream = streamHandle.stream(); + stream.write(1); + // Don't read/reset stream + } + + try (StreamHandle streamHandle = ThreadLocalByteStringOutputStream.acquire()) { + ByteStringOutputStream stream = streamHandle.stream(); + assertEquals(ByteString.EMPTY, stream.toByteStringAndReset()); + } + } +}