88
99#include < workerd/api/system-streams.h>
1010#include < workerd/io/features.h>
11+ #include < workerd/util/autogate.h>
1112#include < workerd/util/ring-buffer.h>
13+ #include < workerd/util/state-machine.h>
1214
1315namespace workerd ::api {
1416CompressionAllocator::CompressionAllocator (
@@ -235,94 +237,74 @@ class LazyBuffer {
235237 size_t valid_size_;
236238};
237239
238- // Uncompressed data goes in. Compressed data comes out.
240+ // Because we have to use an autogate to switch things over to the new state manager, we need
241+ // to separate out a common base class for the compression stream internal state and separate
242+ // two separate impls that differ only in how they manage state. Once the autogate is removed,
243+ // we can delete the first impl class and merge everything back together.
239244template <Context::Mode mode>
240- class CompressionStreamImpl : public kj ::Refcounted,
245+ class CompressionStreamBase : public kj ::Refcounted,
241246 public kj::AsyncInputStream,
242247 public capnp::ExplicitEndOutputStream {
243248 public:
244- explicit CompressionStreamImpl (kj::String format,
249+ explicit CompressionStreamBase (kj::String format,
245250 Context::ContextFlags flags,
246251 kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
247252 : context(mode, format, flags, kj::mv(externalMemoryTarget)) {}
248253
249254 // WritableStreamSink implementation ---------------------------------------------------
250255
251- kj::Promise<void > write (kj::ArrayPtr<const byte> buffer) override {
252- KJ_SWITCH_ONEOF (state) {
253- KJ_CASE_ONEOF (ended, Ended) {
254- JSG_FAIL_REQUIRE (Error, " Write after close" );
255- }
256- KJ_CASE_ONEOF (exception, kj::Exception) {
257- kj::throwFatalException (kj::cp (exception));
258- }
259- KJ_CASE_ONEOF (open, Open) {
260- context.setInput (buffer.begin (), buffer.size ());
261- writeInternal (Z_NO_FLUSH);
262- co_return ;
263- }
264- }
265- KJ_UNREACHABLE;
256+ kj::Promise<void > write (kj::ArrayPtr<const byte> buffer) override final {
257+ requireActive (" Write after close" );
258+ context.setInput (buffer.begin (), buffer.size ());
259+ writeInternal (Z_NO_FLUSH);
260+ co_return ;
266261 }
267262
268- kj::Promise<void > write (kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override {
269- // We check for Ended, Exception here so that we catch
270- // these even if pieces is empty.
271- KJ_SWITCH_ONEOF (state) {
272- KJ_CASE_ONEOF (ended, Ended) {
273- JSG_FAIL_REQUIRE (Error, " Write after close" );
274- }
275- KJ_CASE_ONEOF (exception, kj::Exception) {
276- kj::throwFatalException (kj::cp (exception));
277- }
278- KJ_CASE_ONEOF (open, Open) {
279- for (auto piece: pieces) {
280- co_await write (piece);
281- }
282- co_return ;
283- }
263+ kj::Promise<void > write (kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override final {
264+ // We check state here so that we catch errors even if pieces is empty.
265+ requireActive (" Write after close" );
266+ for (auto piece: pieces) {
267+ co_await write (piece);
284268 }
285- KJ_UNREACHABLE ;
269+ co_return ;
286270 }
287271
288- kj::Promise<void > end () override {
289- state = Ended ();
272+ kj::Promise<void > end () override final {
273+ transitionToEnded ();
290274 writeInternal (Z_FINISH);
291275 co_return ;
292276 }
293277
294- kj::Promise<void > whenWriteDisconnected () override {
278+ kj::Promise<void > whenWriteDisconnected () override final {
295279 return kj::NEVER_DONE;
296280 }
297281
298- void abortWrite (kj::Exception&& reason) override {
282+ void abortWrite (kj::Exception&& reason) override final {
299283 cancelInternal (kj::mv (reason));
300284 }
301285
302286 // AsyncInputStream implementation -----------------------------------------------------
303287
304- kj::Promise<size_t > tryRead (void * buffer, size_t minBytes, size_t maxBytes) override {
288+ kj::Promise<size_t > tryRead (void * buffer, size_t minBytes, size_t maxBytes) override final {
305289 KJ_ASSERT (minBytes <= maxBytes);
306- KJ_SWITCH_ONEOF (state) {
307- KJ_CASE_ONEOF (ended, Ended) {
308- // There might still be data in the output buffer remaining to read.
309- if (output.empty ()) {
310- co_return static_cast <size_t >(0 );
311- }
312- co_return co_await tryReadInternal (
313- kj::arrayPtr (reinterpret_cast <kj::byte*>(buffer), maxBytes), minBytes);
314- }
315- KJ_CASE_ONEOF (exception, kj::Exception) {
316- kj::throwFatalException (kj::cp (exception));
317- }
318- KJ_CASE_ONEOF (open, Open) {
319- co_return co_await tryReadInternal (
320- kj::arrayPtr (reinterpret_cast <kj::byte*>(buffer), maxBytes), minBytes);
321- }
290+ // Re-throw any stored exception
291+ throwIfException ();
292+ // If stream has ended normally and no buffered data, return EOF
293+ if (isInTerminalState () && output.empty ()) {
294+ co_return static_cast <size_t >(0 );
322295 }
323- KJ_UNREACHABLE;
296+ // Active or terminal with data remaining
297+ co_return co_await tryReadInternal (
298+ kj::arrayPtr (reinterpret_cast <kj::byte*>(buffer), maxBytes), minBytes);
324299 }
325300
301+ protected:
302+ virtual void requireActive (kj::StringPtr errorMessage) = 0;
303+ virtual void transitionToEnded () = 0;
304+ virtual void transitionToErrored (kj::Exception&& reason) = 0;
305+ virtual void throwIfException () = 0;
306+ virtual bool isInTerminalState () = 0;
307+
326308 private:
327309 struct PendingRead {
328310 kj::ArrayPtr<kj::byte> buffer;
@@ -343,7 +325,8 @@ class CompressionStreamImpl: public kj::Refcounted,
343325 }
344326
345327 canceler.cancel (kj::cp (reason));
346- state = kj::mv (reason);
328+ transitionToErrored (kj::mv (reason));
329+ // state = kj::mv(reason);
347330 }
348331
349332 kj::Promise<size_t > tryReadInternal (kj::ArrayPtr<kj::byte> dest, size_t minBytes) {
@@ -357,9 +340,9 @@ class CompressionStreamImpl: public kj::Refcounted,
357340 // If the output currently contains >= minBytes, then we'll fulfill
358341 // the read immediately, removing as many bytes as possible from the
359342 // output queue.
360- // If we reached the end, resolve the read immediately as well, since no
361- // new data is expected.
362- if (output.size () >= minBytes || state. template is <Ended> ()) {
343+ // If we reached the end (terminal state) , resolve the read immediately
344+ // as well, since no new data is expected.
345+ if (output.size () >= minBytes || isInTerminalState ()) {
363346 co_return copyIntoBuffer (dest);
364347 }
365348
@@ -385,7 +368,7 @@ class CompressionStreamImpl: public kj::Refcounted,
385368 void writeInternal (int flush) {
386369 // TODO(later): This does not yet implement any backpressure. A caller can keep calling
387370 // write without reading, which will continue to fill the internal buffer.
388- KJ_ASSERT (flush == Z_FINISH || state. template is <Open> ());
371+ KJ_ASSERT (flush == Z_FINISH || ! isInTerminalState ());
389372 Context::Result result;
390373
391374 while (true ) {
@@ -460,7 +443,7 @@ class CompressionStreamImpl: public kj::Refcounted,
460443 KJ_ASSERT (output.empty ());
461444 }
462445
463- if (state. template is <Ended> () && !pendingReads.empty ()) {
446+ if (isInTerminalState () && !pendingReads.empty ()) {
464447 // We are ended and we have pending reads. Because of the loop above,
465448 // one of either pendingReads or output must be empty, so if we got this
466449 // far, output.empty() must be true. Let's check.
@@ -477,17 +460,128 @@ class CompressionStreamImpl: public kj::Refcounted,
477460 }
478461 }
479462
480- struct Ended {};
481- struct Open {};
482-
483- kj::OneOf<Open, Ended, kj::Exception> state = Open();
484463 Context context;
485464
486465 kj::Canceler canceler;
487466 LazyBuffer output;
488467 RingBuffer<PendingRead, 8 > pendingReads;
489468};
490469
470+ // Uncompressed data goes in. Compressed data comes out.
471+ // TODO(cleanup): Once the autogate is removed, delete this class and merge CompressionStreamBase
472+ // and CompressionStreamImplV2 back into a single class.
473+ template <Context::Mode mode>
474+ class CompressionStreamImpl final : public CompressionStreamBase<mode> {
475+ public:
476+ explicit CompressionStreamImpl (kj::String format,
477+ Context::ContextFlags flags,
478+ kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
479+ : CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)) {}
480+
481+ protected:
482+ void requireActive (kj::StringPtr errorMessage) override {
483+ KJ_SWITCH_ONEOF (state) {
484+ KJ_CASE_ONEOF (ended, Ended) {
485+ JSG_FAIL_REQUIRE (Error, errorMessage);
486+ }
487+ KJ_CASE_ONEOF (exception, kj::Exception) {
488+ kj::throwFatalException (kj::cp (exception));
489+ }
490+ KJ_CASE_ONEOF (open, Open) {
491+ return ;
492+ }
493+ }
494+ KJ_UNREACHABLE;
495+ }
496+
497+ void transitionToEnded () override {
498+ state = Ended ();
499+ }
500+
501+ void transitionToErrored (kj::Exception&& reason) override {
502+ state = kj::mv (reason);
503+ }
504+
505+ void throwIfException () override {
506+ KJ_IF_SOME (exception, state.template tryGet <kj::Exception>()) {
507+ kj::throwFatalException (kj::cp (exception));
508+ }
509+ }
510+
511+ virtual bool isInTerminalState () override {
512+ // Ended or Exception are both terminal states.
513+ return state.template is <Ended>() || state.template is <kj::Exception>();
514+ }
515+
516+ private:
517+ struct Ended {};
518+ struct Open {};
519+
520+ kj::OneOf<Open, Ended, kj::Exception> state = Open();
521+ };
522+
523+ template <Context::Mode mode>
524+ class CompressionStreamImplV2 final : public CompressionStreamBase<mode> {
525+ public:
526+ explicit CompressionStreamImplV2 (kj::String format,
527+ Context::ContextFlags flags,
528+ kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
529+ : CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)),
530+ state(decltype (state)::template create<Open>()) {}
531+
532+ protected:
533+ void requireActive (kj::StringPtr errorMessage) override {
534+ KJ_IF_SOME (exception, state.tryGetErrorUnsafe ()) {
535+ kj::throwFatalException (kj::cp (exception));
536+ }
537+ // isActive() returns true only if in Open state (the ActiveState)
538+ JSG_REQUIRE (state.isActive (), Error, errorMessage);
539+ }
540+
541+ void transitionToEnded () override {
542+ // Use transitionFromTo to ensure we're in Open state before ending.
543+ // This provides a clearer error if end() is called twice.
544+ auto result = state.template transitionFromTo <Open, Ended>();
545+ KJ_REQUIRE (result != kj::none, " Stream already ended or errored" );
546+ }
547+
548+ void transitionToErrored (kj::Exception&& reason) override {
549+ // Use forceTransitionTo because cancelInternal may be called when already
550+ // in an error state (e.g., from writeInternal error handling).
551+ state.template forceTransitionTo <kj::Exception>(kj::mv (reason));
552+ }
553+
554+ void throwIfException () override {
555+ KJ_IF_SOME (exception, state.tryGetErrorUnsafe ()) {
556+ kj::throwFatalException (kj::cp (exception));
557+ }
558+ }
559+
560+ virtual bool isInTerminalState () override {
561+ return state.isTerminal ();
562+ }
563+
564+ private:
565+ struct Ended {
566+ static constexpr kj::StringPtr NAME KJ_UNUSED = " ended" _kj;
567+ };
568+ struct Open {
569+ static constexpr kj::StringPtr NAME KJ_UNUSED = " open" _kj;
570+ };
571+
572+ // State machine for tracking compression stream lifecycle:
573+ // Open -> Ended (normal close via end())
574+ // Open -> kj::Exception (error via abortWrite())
575+ // Ended is terminal, kj::Exception is implicitly terminal via ErrorState.
576+ StateMachine<TerminalStates<Ended>,
577+ ErrorState<kj::Exception>,
578+ ActiveState<Open>,
579+ Open,
580+ Ended,
581+ kj::Exception>
582+ state;
583+ };
584+
491585// Adapter to bridge CompressionStreamImpl (which implements AsyncInputStream and
492586// ExplicitEndOutputStream) to the ReadableStreamSource/WritableStreamSink interfaces.
493587// TODO(soon): This class is intended to be replaced by the new ReadableSource/WritableSink
@@ -500,7 +594,7 @@ class CompressionStreamAdapter final: public kj::Refcounted,
500594 public ReadableStreamSource,
501595 public WritableStreamSink {
502596 public:
503- explicit CompressionStreamAdapter (kj::Rc<CompressionStreamImpl <mode>> impl)
597+ explicit CompressionStreamAdapter (kj::Rc<CompressionStreamBase <mode>> impl)
504598 : impl(kj::mv(impl)),
505599 ioContext(IoContext::current()) {}
506600
@@ -532,17 +626,44 @@ class CompressionStreamAdapter final: public kj::Refcounted,
532626 }
533627
534628 private:
535- kj::Rc<CompressionStreamImpl <mode>> impl;
629+ kj::Rc<CompressionStreamBase <mode>> impl;
536630 IoContext& ioContext;
537631};
538632
633+ kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> createCompressionStreamImpl (
634+ kj::String format,
635+ Context::ContextFlags flags,
636+ kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
637+ // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
638+ if (util::Autogate::isEnabled (util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
639+ return kj::rc<CompressionStreamImplV2<Context::Mode::COMPRESS>>(
640+ kj::mv (format), flags, kj::mv (externalMemoryTarget));
641+ }
642+ return kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
643+ kj::mv (format), flags, kj::mv (externalMemoryTarget));
644+ }
645+
646+ kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> createDecompressionStreamImpl (
647+ kj::String format,
648+ Context::ContextFlags flags,
649+ kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
650+ // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
651+ if (util::Autogate::isEnabled (util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
652+ return kj::rc<CompressionStreamImplV2<Context::Mode::DECOMPRESS>>(
653+ kj::mv (format), flags, kj::mv (externalMemoryTarget));
654+ }
655+ return kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(
656+ kj::mv (format), flags, kj::mv (externalMemoryTarget));
657+ }
658+
539659} // namespace
540660
541661jsg::Ref<CompressionStream> CompressionStream::constructor (jsg::Lock& js, kj::String format) {
542662 JSG_REQUIRE (format == " deflate" || format == " gzip" || format == " deflate-raw" , TypeError,
543663 " The compression format must be either 'deflate', 'deflate-raw' or 'gzip'." );
544664
545- auto impl = kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
665+ // TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
666+ kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> impl = createCompressionStreamImpl (
546667 kj::mv (format), Context::ContextFlags::NONE, js.getExternalMemoryTarget ());
547668
548669 auto & ioContext = IoContext::current ();
@@ -561,10 +682,11 @@ jsg::Ref<DecompressionStream> DecompressionStream::constructor(jsg::Lock& js, kj
561682 JSG_REQUIRE (format == " deflate" || format == " gzip" || format == " deflate-raw" , TypeError,
562683 " The compression format must be either 'deflate', 'deflate-raw' or 'gzip'." );
563684
564- auto impl = kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(kj::mv (format),
565- FeatureFlags::get (js).getStrictCompression () ? Context::ContextFlags::STRICT
566- : Context::ContextFlags::NONE,
567- js.getExternalMemoryTarget ());
685+ kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> impl =
686+ createDecompressionStreamImpl (kj::mv (format),
687+ FeatureFlags::get (js).getStrictCompression () ? Context::ContextFlags::STRICT
688+ : Context::ContextFlags::NONE,
689+ js.getExternalMemoryTarget ());
568690
569691 auto & ioContext = IoContext::current ();
570692
0 commit comments