diff --git a/src/workerd/api/streams/compression.c++ b/src/workerd/api/streams/compression.c++ index 6f5e43f16a8..b4ac31cab26 100644 --- a/src/workerd/api/streams/compression.c++ +++ b/src/workerd/api/streams/compression.c++ @@ -8,7 +8,9 @@ #include #include +#include #include +#include namespace workerd::api { CompressionAllocator::CompressionAllocator( @@ -235,94 +237,74 @@ class LazyBuffer { size_t valid_size_; }; -// Uncompressed data goes in. Compressed data comes out. +// Because we have to use an autogate to switch things over to the new state manager, we need +// to separate out a common base class for the compression stream internal state and separate +// two separate impls that differ only in how they manage state. Once the autogate is removed, +// we can delete the first impl class and merge everything back together. template -class CompressionStreamImpl: public kj::Refcounted, +class CompressionStreamBase: public kj::Refcounted, public kj::AsyncInputStream, public capnp::ExplicitEndOutputStream { public: - explicit CompressionStreamImpl(kj::String format, + explicit CompressionStreamBase(kj::String format, Context::ContextFlags flags, kj::Arc&& externalMemoryTarget) : context(mode, format, flags, kj::mv(externalMemoryTarget)) {} // WritableStreamSink implementation --------------------------------------------------- - kj::Promise write(kj::ArrayPtr buffer) override { - KJ_SWITCH_ONEOF(state) { - KJ_CASE_ONEOF(ended, Ended) { - JSG_FAIL_REQUIRE(Error, "Write after close"); - } - KJ_CASE_ONEOF(exception, kj::Exception) { - kj::throwFatalException(kj::cp(exception)); - } - KJ_CASE_ONEOF(open, Open) { - context.setInput(buffer.begin(), buffer.size()); - writeInternal(Z_NO_FLUSH); - co_return; - } - } - KJ_UNREACHABLE; + kj::Promise write(kj::ArrayPtr buffer) override final { + requireActive("Write after close"); + context.setInput(buffer.begin(), buffer.size()); + writeInternal(Z_NO_FLUSH); + co_return; } - kj::Promise write(kj::ArrayPtr> pieces) override { - // We check for Ended, Exception here so that we catch - // these even if pieces is empty. - KJ_SWITCH_ONEOF(state) { - KJ_CASE_ONEOF(ended, Ended) { - JSG_FAIL_REQUIRE(Error, "Write after close"); - } - KJ_CASE_ONEOF(exception, kj::Exception) { - kj::throwFatalException(kj::cp(exception)); - } - KJ_CASE_ONEOF(open, Open) { - for (auto piece: pieces) { - co_await write(piece); - } - co_return; - } + kj::Promise write(kj::ArrayPtr> pieces) override final { + // We check state here so that we catch errors even if pieces is empty. + requireActive("Write after close"); + for (auto piece: pieces) { + co_await write(piece); } - KJ_UNREACHABLE; + co_return; } - kj::Promise end() override { - state = Ended(); + kj::Promise end() override final { + transitionToEnded(); writeInternal(Z_FINISH); co_return; } - kj::Promise whenWriteDisconnected() override { + kj::Promise whenWriteDisconnected() override final { return kj::NEVER_DONE; } - void abortWrite(kj::Exception&& reason) override { + void abortWrite(kj::Exception&& reason) override final { cancelInternal(kj::mv(reason)); } // AsyncInputStream implementation ----------------------------------------------------- - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override final { KJ_ASSERT(minBytes <= maxBytes); - KJ_SWITCH_ONEOF(state) { - KJ_CASE_ONEOF(ended, Ended) { - // There might still be data in the output buffer remaining to read. - if (output.empty()) { - co_return static_cast(0); - } - co_return co_await tryReadInternal( - kj::arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes); - } - KJ_CASE_ONEOF(exception, kj::Exception) { - kj::throwFatalException(kj::cp(exception)); - } - KJ_CASE_ONEOF(open, Open) { - co_return co_await tryReadInternal( - kj::arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes); - } + // Re-throw any stored exception + throwIfException(); + // If stream has ended normally and no buffered data, return EOF + if (isInTerminalState() && output.empty()) { + co_return static_cast(0); } - KJ_UNREACHABLE; + // Active or terminal with data remaining + co_return co_await tryReadInternal( + kj::arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes); } + protected: + virtual void requireActive(kj::StringPtr errorMessage) = 0; + virtual void transitionToEnded() = 0; + virtual void transitionToErrored(kj::Exception&& reason) = 0; + virtual void throwIfException() = 0; + virtual bool isInTerminalState() = 0; + private: struct PendingRead { kj::ArrayPtr buffer; @@ -343,7 +325,8 @@ class CompressionStreamImpl: public kj::Refcounted, } canceler.cancel(kj::cp(reason)); - state = kj::mv(reason); + transitionToErrored(kj::mv(reason)); + //state = kj::mv(reason); } kj::Promise tryReadInternal(kj::ArrayPtr dest, size_t minBytes) { @@ -357,9 +340,9 @@ class CompressionStreamImpl: public kj::Refcounted, // If the output currently contains >= minBytes, then we'll fulfill // the read immediately, removing as many bytes as possible from the // output queue. - // If we reached the end, resolve the read immediately as well, since no - // new data is expected. - if (output.size() >= minBytes || state.template is()) { + // If we reached the end (terminal state), resolve the read immediately + // as well, since no new data is expected. + if (output.size() >= minBytes || isInTerminalState()) { co_return copyIntoBuffer(dest); } @@ -385,7 +368,7 @@ class CompressionStreamImpl: public kj::Refcounted, void writeInternal(int flush) { // TODO(later): This does not yet implement any backpressure. A caller can keep calling // write without reading, which will continue to fill the internal buffer. - KJ_ASSERT(flush == Z_FINISH || state.template is()); + KJ_ASSERT(flush == Z_FINISH || !isInTerminalState()); Context::Result result; while (true) { @@ -460,7 +443,7 @@ class CompressionStreamImpl: public kj::Refcounted, KJ_ASSERT(output.empty()); } - if (state.template is() && !pendingReads.empty()) { + if (isInTerminalState() && !pendingReads.empty()) { // We are ended and we have pending reads. Because of the loop above, // one of either pendingReads or output must be empty, so if we got this // far, output.empty() must be true. Let's check. @@ -477,10 +460,6 @@ class CompressionStreamImpl: public kj::Refcounted, } } - struct Ended {}; - struct Open {}; - - kj::OneOf state = Open(); Context context; kj::Canceler canceler; @@ -488,6 +467,121 @@ class CompressionStreamImpl: public kj::Refcounted, RingBuffer pendingReads; }; +// Uncompressed data goes in. Compressed data comes out. +// TODO(cleanup): Once the autogate is removed, delete this class and merge CompressionStreamBase +// and CompressionStreamImplV2 back into a single class. +template +class CompressionStreamImpl final: public CompressionStreamBase { + public: + explicit CompressionStreamImpl(kj::String format, + Context::ContextFlags flags, + kj::Arc&& externalMemoryTarget) + : CompressionStreamBase(kj::mv(format), flags, kj::mv(externalMemoryTarget)) {} + + protected: + void requireActive(kj::StringPtr errorMessage) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(ended, Ended) { + JSG_FAIL_REQUIRE(Error, errorMessage); + } + KJ_CASE_ONEOF(exception, kj::Exception) { + kj::throwFatalException(kj::cp(exception)); + } + KJ_CASE_ONEOF(open, Open) { + return; + } + } + KJ_UNREACHABLE; + } + + void transitionToEnded() override { + state = Ended(); + } + + void transitionToErrored(kj::Exception&& reason) override { + state = kj::mv(reason); + } + + void throwIfException() override { + KJ_IF_SOME(exception, state.template tryGet()) { + kj::throwFatalException(kj::cp(exception)); + } + } + + virtual bool isInTerminalState() override { + // Ended or Exception are both terminal states. + return state.template is() || state.template is(); + } + + private: + struct Ended {}; + struct Open {}; + + kj::OneOf state = Open(); +}; + +template +class CompressionStreamImplV2 final: public CompressionStreamBase { + public: + explicit CompressionStreamImplV2(kj::String format, + Context::ContextFlags flags, + kj::Arc&& externalMemoryTarget) + : CompressionStreamBase(kj::mv(format), flags, kj::mv(externalMemoryTarget)), + state(decltype(state)::template create()) {} + + protected: + void requireActive(kj::StringPtr errorMessage) override { + KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) { + kj::throwFatalException(kj::cp(exception)); + } + // isActive() returns true only if in Open state (the ActiveState) + JSG_REQUIRE(state.isActive(), Error, errorMessage); + } + + void transitionToEnded() override { + // Use transitionFromTo to ensure we're in Open state before ending. + // This provides a clearer error if end() is called twice. + auto result = state.template transitionFromTo(); + KJ_REQUIRE(result != kj::none, "Stream already ended or errored"); + } + + void transitionToErrored(kj::Exception&& reason) override { + // Use forceTransitionTo because cancelInternal may be called when already + // in an error state (e.g., from writeInternal error handling). + state.template forceTransitionTo(kj::mv(reason)); + } + + void throwIfException() override { + KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) { + kj::throwFatalException(kj::cp(exception)); + } + } + + virtual bool isInTerminalState() override { + return state.isTerminal(); + } + + private: + struct Ended { + static constexpr kj::StringPtr NAME KJ_UNUSED = "ended"_kj; + }; + struct Open { + static constexpr kj::StringPtr NAME KJ_UNUSED = "open"_kj; + }; + + // State machine for tracking compression stream lifecycle: + // Open -> Ended (normal close via end()) + // Open -> kj::Exception (error via abortWrite()) + // Ended is terminal, kj::Exception is implicitly terminal via ErrorState. + StateMachine, + ErrorState, + ActiveState, + Open, + Ended, + kj::Exception> + state; +}; + // Adapter to bridge CompressionStreamImpl (which implements AsyncInputStream and // ExplicitEndOutputStream) to the ReadableStreamSource/WritableStreamSink interfaces. // TODO(soon): This class is intended to be replaced by the new ReadableSource/WritableSink @@ -500,7 +594,7 @@ class CompressionStreamAdapter final: public kj::Refcounted, public ReadableStreamSource, public WritableStreamSink { public: - explicit CompressionStreamAdapter(kj::Rc> impl) + explicit CompressionStreamAdapter(kj::Rc> impl) : impl(kj::mv(impl)), ioContext(IoContext::current()) {} @@ -532,17 +626,44 @@ class CompressionStreamAdapter final: public kj::Refcounted, } private: - kj::Rc> impl; + kj::Rc> impl; IoContext& ioContext; }; +kj::Rc> createCompressionStreamImpl( + kj::String format, + Context::ContextFlags flags, + kj::Arc&& externalMemoryTarget) { + // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl + if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) { + return kj::rc>( + kj::mv(format), flags, kj::mv(externalMemoryTarget)); + } + return kj::rc>( + kj::mv(format), flags, kj::mv(externalMemoryTarget)); +} + +kj::Rc> createDecompressionStreamImpl( + kj::String format, + Context::ContextFlags flags, + kj::Arc&& externalMemoryTarget) { + // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl + if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) { + return kj::rc>( + kj::mv(format), flags, kj::mv(externalMemoryTarget)); + } + return kj::rc>( + kj::mv(format), flags, kj::mv(externalMemoryTarget)); +} + } // namespace jsg::Ref CompressionStream::constructor(jsg::Lock& js, kj::String format) { JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError, "The compression format must be either 'deflate', 'deflate-raw' or 'gzip'."); - auto impl = kj::rc>( + // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl + kj::Rc> impl = createCompressionStreamImpl( kj::mv(format), Context::ContextFlags::NONE, js.getExternalMemoryTarget()); auto& ioContext = IoContext::current(); @@ -561,10 +682,11 @@ jsg::Ref DecompressionStream::constructor(jsg::Lock& js, kj JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError, "The compression format must be either 'deflate', 'deflate-raw' or 'gzip'."); - auto impl = kj::rc>(kj::mv(format), - FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT - : Context::ContextFlags::NONE, - js.getExternalMemoryTarget()); + kj::Rc> impl = + createDecompressionStreamImpl(kj::mv(format), + FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT + : Context::ContextFlags::NONE, + js.getExternalMemoryTarget()); auto& ioContext = IoContext::current(); diff --git a/src/workerd/api/tests/BUILD.bazel b/src/workerd/api/tests/BUILD.bazel index e7b0e3ee22b..bd6b8c8a533 100644 --- a/src/workerd/api/tests/BUILD.bazel +++ b/src/workerd/api/tests/BUILD.bazel @@ -578,3 +578,11 @@ wd_test( args = ["--experimental"], data = ["headers-immutable-prototype-test.js"], ) + +# TODO(cleanup): This is a copy of an existing test in streams-test. Once the autogate is remvoed, +# this separate test can be deleted. +wd_test( + src = "compression-streams-test.wd-test", + args = ["--experimental"], + data = ["compression-streams-test.js"], +) diff --git a/src/workerd/api/tests/compression-streams-test.js b/src/workerd/api/tests/compression-streams-test.js new file mode 100644 index 00000000000..ea0e153d183 --- /dev/null +++ b/src/workerd/api/tests/compression-streams-test.js @@ -0,0 +1,22 @@ +import { strictEqual } from 'node:assert'; + +// TODO(cleanup): This is a copy of an existing test in streams-test. Once the autogate is remvoed, +// this separate test can be deleted. +export const test = { + async test() { + const cs = new CompressionStream('gzip'); + const cw = cs.writable.getWriter(); + await cw.write(new TextEncoder().encode('0123456789'.repeat(1000))); + await cw.close(); + const data = await new Response(cs.readable).arrayBuffer(); + strictEqual(66, data.byteLength); + + const ds = new DecompressionStream('gzip'); + const dw = ds.writable.getWriter(); + await dw.write(data); + await dw.close(); + + const read = await new Response(ds.readable).arrayBuffer(); + strictEqual(10_000, read.byteLength); + }, +}; diff --git a/src/workerd/api/tests/compression-streams-test.wd-test b/src/workerd/api/tests/compression-streams-test.wd-test new file mode 100644 index 00000000000..86dd6821b41 --- /dev/null +++ b/src/workerd/api/tests/compression-streams-test.wd-test @@ -0,0 +1,18 @@ +using Workerd = import "/workerd/workerd.capnp"; + +# TODO(cleanup): This is a copy of an existing test in streams-test. Once the autogate is remvoed, +# this separate test can be deleted. +const unitTests :Workerd.Config = ( + autogates = ["workerd-autogate-compression-stream-use-state-machine"], + services = [ + ( name = "compression-streams-test", + worker = ( + modules = [ + (name = "worker", esModule = embed "compression-streams-test.js") + ], + compatibilityDate = "2025-12-15", + compatibilityFlags = ["nodejs_compat"], + ) + ), + ], +); diff --git a/src/workerd/util/autogate.c++ b/src/workerd/util/autogate.c++ index aac95cbe387..61e99e636d9 100644 --- a/src/workerd/util/autogate.c++ +++ b/src/workerd/util/autogate.c++ @@ -31,6 +31,8 @@ kj::StringPtr KJ_STRINGIFY(AutogateKey key) { return "fetch-request-memory-adjustment"_kj; case AutogateKey::RUST_BACKED_NODE_DNS: return "rust-backed-node-dns"_kj; + case AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE: + return "compression-stream-use-state-machine"_kj; case AutogateKey::NumOfKeys: KJ_FAIL_ASSERT("NumOfKeys should not be used in getName"); } diff --git a/src/workerd/util/autogate.h b/src/workerd/util/autogate.h index 53fc4c08760..23f5109e21e 100644 --- a/src/workerd/util/autogate.h +++ b/src/workerd/util/autogate.h @@ -26,6 +26,8 @@ enum class AutogateKey { FETCH_REQUEST_MEMORY_ADJUSTMENT, // Enable Rust-backed Node.js DNS implementation RUST_BACKED_NODE_DNS, + // Switch the CompressionStream to use the new state machine-based impl + COMPRESSION_STREAM_USE_STATE_MACHINE, NumOfKeys // Reserved for iteration. };