Skip to content

Commit 5f014f0

Browse files
committed
Convert CompressionStream to autogate
1 parent 7a5c452 commit 5f014f0

File tree

6 files changed

+248
-74
lines changed

6 files changed

+248
-74
lines changed

src/workerd/api/streams/compression.c++

Lines changed: 196 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
#include <workerd/api/system-streams.h>
1010
#include <workerd/io/features.h>
11+
#include <workerd/util/autogate.h>
1112
#include <workerd/util/ring-buffer.h>
13+
#include <workerd/util/state-machine.h>
1214

1315
namespace workerd::api {
1416
CompressionAllocator::CompressionAllocator(
@@ -235,94 +237,74 @@ class LazyBuffer {
235237
size_t valid_size_;
236238
};
237239

238-
// Uncompressed data goes in. Compressed data comes out.
240+
// Because we have to use an autogate to switch things over to the new state manager, we need
241+
// to separate out a common base class for the compression stream internal state and separate
242+
// two separate impls that differ only in how they manage state. Once the autogate is removed,
243+
// we can delete the first impl class and merge everything back together.
239244
template <Context::Mode mode>
240-
class CompressionStreamImpl: public kj::Refcounted,
245+
class CompressionStreamBase: public kj::Refcounted,
241246
public kj::AsyncInputStream,
242247
public capnp::ExplicitEndOutputStream {
243248
public:
244-
explicit CompressionStreamImpl(kj::String format,
249+
explicit CompressionStreamBase(kj::String format,
245250
Context::ContextFlags flags,
246251
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
247252
: context(mode, format, flags, kj::mv(externalMemoryTarget)) {}
248253

249254
// WritableStreamSink implementation ---------------------------------------------------
250255

251-
kj::Promise<void> write(kj::ArrayPtr<const byte> buffer) override {
252-
KJ_SWITCH_ONEOF(state) {
253-
KJ_CASE_ONEOF(ended, Ended) {
254-
JSG_FAIL_REQUIRE(Error, "Write after close");
255-
}
256-
KJ_CASE_ONEOF(exception, kj::Exception) {
257-
kj::throwFatalException(kj::cp(exception));
258-
}
259-
KJ_CASE_ONEOF(open, Open) {
260-
context.setInput(buffer.begin(), buffer.size());
261-
writeInternal(Z_NO_FLUSH);
262-
co_return;
263-
}
264-
}
265-
KJ_UNREACHABLE;
256+
kj::Promise<void> write(kj::ArrayPtr<const byte> buffer) override final {
257+
requireActive("Write after close");
258+
context.setInput(buffer.begin(), buffer.size());
259+
writeInternal(Z_NO_FLUSH);
260+
co_return;
266261
}
267262

268-
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override {
269-
// We check for Ended, Exception here so that we catch
270-
// these even if pieces is empty.
271-
KJ_SWITCH_ONEOF(state) {
272-
KJ_CASE_ONEOF(ended, Ended) {
273-
JSG_FAIL_REQUIRE(Error, "Write after close");
274-
}
275-
KJ_CASE_ONEOF(exception, kj::Exception) {
276-
kj::throwFatalException(kj::cp(exception));
277-
}
278-
KJ_CASE_ONEOF(open, Open) {
279-
for (auto piece: pieces) {
280-
co_await write(piece);
281-
}
282-
co_return;
283-
}
263+
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override final {
264+
// We check state here so that we catch errors even if pieces is empty.
265+
requireActive("Write after close");
266+
for (auto piece: pieces) {
267+
co_await write(piece);
284268
}
285-
KJ_UNREACHABLE;
269+
co_return;
286270
}
287271

288-
kj::Promise<void> end() override {
289-
state = Ended();
272+
kj::Promise<void> end() override final {
273+
transitionToEnded();
290274
writeInternal(Z_FINISH);
291275
co_return;
292276
}
293277

294-
kj::Promise<void> whenWriteDisconnected() override {
278+
kj::Promise<void> whenWriteDisconnected() override final {
295279
return kj::NEVER_DONE;
296280
}
297281

298-
void abortWrite(kj::Exception&& reason) override {
282+
void abortWrite(kj::Exception&& reason) override final {
299283
cancelInternal(kj::mv(reason));
300284
}
301285

302286
// AsyncInputStream implementation -----------------------------------------------------
303287

304-
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
288+
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override final {
305289
KJ_ASSERT(minBytes <= maxBytes);
306-
KJ_SWITCH_ONEOF(state) {
307-
KJ_CASE_ONEOF(ended, Ended) {
308-
// There might still be data in the output buffer remaining to read.
309-
if (output.empty()) {
310-
co_return static_cast<size_t>(0);
311-
}
312-
co_return co_await tryReadInternal(
313-
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
314-
}
315-
KJ_CASE_ONEOF(exception, kj::Exception) {
316-
kj::throwFatalException(kj::cp(exception));
317-
}
318-
KJ_CASE_ONEOF(open, Open) {
319-
co_return co_await tryReadInternal(
320-
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
321-
}
290+
// Re-throw any stored exception
291+
throwIfException();
292+
// If stream has ended normally and no buffered data, return EOF
293+
if (isInTerminalState() && output.empty()) {
294+
co_return static_cast<size_t>(0);
322295
}
323-
KJ_UNREACHABLE;
296+
// Active or terminal with data remaining
297+
co_return co_await tryReadInternal(
298+
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
324299
}
325300

301+
protected:
302+
virtual void requireActive(kj::StringPtr errorMessage) = 0;
303+
virtual void transitionToEnded() = 0;
304+
virtual void transitionToErrored(kj::Exception&& reason) = 0;
305+
virtual void throwIfException() = 0;
306+
virtual bool isInTerminalState() = 0;
307+
326308
private:
327309
struct PendingRead {
328310
kj::ArrayPtr<kj::byte> buffer;
@@ -343,7 +325,8 @@ class CompressionStreamImpl: public kj::Refcounted,
343325
}
344326

345327
canceler.cancel(kj::cp(reason));
346-
state = kj::mv(reason);
328+
transitionToErrored(kj::mv(reason));
329+
//state = kj::mv(reason);
347330
}
348331

349332
kj::Promise<size_t> tryReadInternal(kj::ArrayPtr<kj::byte> dest, size_t minBytes) {
@@ -357,9 +340,9 @@ class CompressionStreamImpl: public kj::Refcounted,
357340
// If the output currently contains >= minBytes, then we'll fulfill
358341
// the read immediately, removing as many bytes as possible from the
359342
// output queue.
360-
// If we reached the end, resolve the read immediately as well, since no
361-
// new data is expected.
362-
if (output.size() >= minBytes || state.template is<Ended>()) {
343+
// If we reached the end (terminal state), resolve the read immediately
344+
// as well, since no new data is expected.
345+
if (output.size() >= minBytes || isInTerminalState()) {
363346
co_return copyIntoBuffer(dest);
364347
}
365348

@@ -385,7 +368,7 @@ class CompressionStreamImpl: public kj::Refcounted,
385368
void writeInternal(int flush) {
386369
// TODO(later): This does not yet implement any backpressure. A caller can keep calling
387370
// write without reading, which will continue to fill the internal buffer.
388-
KJ_ASSERT(flush == Z_FINISH || state.template is<Open>());
371+
KJ_ASSERT(flush == Z_FINISH || !isInTerminalState());
389372
Context::Result result;
390373

391374
while (true) {
@@ -460,7 +443,7 @@ class CompressionStreamImpl: public kj::Refcounted,
460443
KJ_ASSERT(output.empty());
461444
}
462445

463-
if (state.template is<Ended>() && !pendingReads.empty()) {
446+
if (isInTerminalState() && !pendingReads.empty()) {
464447
// We are ended and we have pending reads. Because of the loop above,
465448
// one of either pendingReads or output must be empty, so if we got this
466449
// far, output.empty() must be true. Let's check.
@@ -477,17 +460,128 @@ class CompressionStreamImpl: public kj::Refcounted,
477460
}
478461
}
479462

480-
struct Ended {};
481-
struct Open {};
482-
483-
kj::OneOf<Open, Ended, kj::Exception> state = Open();
484463
Context context;
485464

486465
kj::Canceler canceler;
487466
LazyBuffer output;
488467
RingBuffer<PendingRead, 8> pendingReads;
489468
};
490469

470+
// Uncompressed data goes in. Compressed data comes out.
471+
// TODO(cleanup): Once the autogate is removed, delete this class and merge CompressionStreamBase
472+
// and CompressionStreamImplV2 back into a single class.
473+
template <Context::Mode mode>
474+
class CompressionStreamImpl final: public CompressionStreamBase<mode> {
475+
public:
476+
explicit CompressionStreamImpl(kj::String format,
477+
Context::ContextFlags flags,
478+
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
479+
: CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)) {}
480+
481+
protected:
482+
void requireActive(kj::StringPtr errorMessage) override {
483+
KJ_SWITCH_ONEOF(state) {
484+
KJ_CASE_ONEOF(ended, Ended) {
485+
JSG_FAIL_REQUIRE(Error, errorMessage);
486+
}
487+
KJ_CASE_ONEOF(exception, kj::Exception) {
488+
kj::throwFatalException(kj::cp(exception));
489+
}
490+
KJ_CASE_ONEOF(open, Open) {
491+
return;
492+
}
493+
}
494+
KJ_UNREACHABLE;
495+
}
496+
497+
void transitionToEnded() override {
498+
state = Ended();
499+
}
500+
501+
void transitionToErrored(kj::Exception&& reason) override {
502+
state = kj::mv(reason);
503+
}
504+
505+
void throwIfException() override {
506+
KJ_IF_SOME(exception, state.template tryGet<kj::Exception>()) {
507+
kj::throwFatalException(kj::cp(exception));
508+
}
509+
}
510+
511+
virtual bool isInTerminalState() override {
512+
// Ended or Exception are both terminal states.
513+
return state.template is<Ended>() || state.template is<kj::Exception>();
514+
}
515+
516+
private:
517+
struct Ended {};
518+
struct Open {};
519+
520+
kj::OneOf<Open, Ended, kj::Exception> state = Open();
521+
};
522+
523+
template <Context::Mode mode>
524+
class CompressionStreamImplV2 final: public CompressionStreamBase<mode> {
525+
public:
526+
explicit CompressionStreamImplV2(kj::String format,
527+
Context::ContextFlags flags,
528+
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
529+
: CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)),
530+
state(decltype(state)::template create<Open>()) {}
531+
532+
protected:
533+
void requireActive(kj::StringPtr errorMessage) override {
534+
KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) {
535+
kj::throwFatalException(kj::cp(exception));
536+
}
537+
// isActive() returns true only if in Open state (the ActiveState)
538+
JSG_REQUIRE(state.isActive(), Error, errorMessage);
539+
}
540+
541+
void transitionToEnded() override {
542+
// Use transitionFromTo to ensure we're in Open state before ending.
543+
// This provides a clearer error if end() is called twice.
544+
auto result = state.template transitionFromTo<Open, Ended>();
545+
KJ_REQUIRE(result != kj::none, "Stream already ended or errored");
546+
}
547+
548+
void transitionToErrored(kj::Exception&& reason) override {
549+
// Use forceTransitionTo because cancelInternal may be called when already
550+
// in an error state (e.g., from writeInternal error handling).
551+
state.template forceTransitionTo<kj::Exception>(kj::mv(reason));
552+
}
553+
554+
void throwIfException() override {
555+
KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) {
556+
kj::throwFatalException(kj::cp(exception));
557+
}
558+
}
559+
560+
virtual bool isInTerminalState() override {
561+
return state.isTerminal();
562+
}
563+
564+
private:
565+
struct Ended {
566+
static constexpr kj::StringPtr NAME KJ_UNUSED = "ended"_kj;
567+
};
568+
struct Open {
569+
static constexpr kj::StringPtr NAME KJ_UNUSED = "open"_kj;
570+
};
571+
572+
// State machine for tracking compression stream lifecycle:
573+
// Open -> Ended (normal close via end())
574+
// Open -> kj::Exception (error via abortWrite())
575+
// Ended is terminal, kj::Exception is implicitly terminal via ErrorState.
576+
StateMachine<TerminalStates<Ended>,
577+
ErrorState<kj::Exception>,
578+
ActiveState<Open>,
579+
Open,
580+
Ended,
581+
kj::Exception>
582+
state;
583+
};
584+
491585
// Adapter to bridge CompressionStreamImpl (which implements AsyncInputStream and
492586
// ExplicitEndOutputStream) to the ReadableStreamSource/WritableStreamSink interfaces.
493587
// TODO(soon): This class is intended to be replaced by the new ReadableSource/WritableSink
@@ -500,7 +594,7 @@ class CompressionStreamAdapter final: public kj::Refcounted,
500594
public ReadableStreamSource,
501595
public WritableStreamSink {
502596
public:
503-
explicit CompressionStreamAdapter(kj::Rc<CompressionStreamImpl<mode>> impl)
597+
explicit CompressionStreamAdapter(kj::Rc<CompressionStreamBase<mode>> impl)
504598
: impl(kj::mv(impl)),
505599
ioContext(IoContext::current()) {}
506600

@@ -532,17 +626,44 @@ class CompressionStreamAdapter final: public kj::Refcounted,
532626
}
533627

534628
private:
535-
kj::Rc<CompressionStreamImpl<mode>> impl;
629+
kj::Rc<CompressionStreamBase<mode>> impl;
536630
IoContext& ioContext;
537631
};
538632

633+
kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> createCompressionStreamImpl(
634+
kj::String format,
635+
Context::ContextFlags flags,
636+
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
637+
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
638+
if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
639+
return kj::rc<CompressionStreamImplV2<Context::Mode::COMPRESS>>(
640+
kj::mv(format), flags, kj::mv(externalMemoryTarget));
641+
}
642+
return kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
643+
kj::mv(format), flags, kj::mv(externalMemoryTarget));
644+
}
645+
646+
kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> createDecompressionStreamImpl(
647+
kj::String format,
648+
Context::ContextFlags flags,
649+
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
650+
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
651+
if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
652+
return kj::rc<CompressionStreamImplV2<Context::Mode::DECOMPRESS>>(
653+
kj::mv(format), flags, kj::mv(externalMemoryTarget));
654+
}
655+
return kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(
656+
kj::mv(format), flags, kj::mv(externalMemoryTarget));
657+
}
658+
539659
} // namespace
540660

541661
jsg::Ref<CompressionStream> CompressionStream::constructor(jsg::Lock& js, kj::String format) {
542662
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
543663
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");
544664

545-
auto impl = kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
665+
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
666+
kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> impl = createCompressionStreamImpl(
546667
kj::mv(format), Context::ContextFlags::NONE, js.getExternalMemoryTarget());
547668

548669
auto& ioContext = IoContext::current();
@@ -561,10 +682,11 @@ jsg::Ref<DecompressionStream> DecompressionStream::constructor(jsg::Lock& js, kj
561682
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
562683
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");
563684

564-
auto impl = kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(kj::mv(format),
565-
FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT
566-
: Context::ContextFlags::NONE,
567-
js.getExternalMemoryTarget());
685+
kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> impl =
686+
createDecompressionStreamImpl(kj::mv(format),
687+
FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT
688+
: Context::ContextFlags::NONE,
689+
js.getExternalMemoryTarget());
568690

569691
auto& ioContext = IoContext::current();
570692

src/workerd/api/tests/BUILD.bazel

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,11 @@ wd_test(
578578
args = ["--experimental"],
579579
data = ["headers-immutable-prototype-test.js"],
580580
)
581+
582+
# TODO(cleanup): This is a copy of an existing test in streams-test. Once the autogate is remvoed,
583+
# this separate test can be deleted.
584+
wd_test(
585+
src = "compression-streams-test.wd-test",
586+
args = ["--experimental"],
587+
data = ["compression-streams-test.js"],
588+
)

0 commit comments

Comments
 (0)