Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 196 additions & 74 deletions src/workerd/api/streams/compression.c++
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

#include <workerd/api/system-streams.h>
#include <workerd/io/features.h>
#include <workerd/util/autogate.h>
#include <workerd/util/ring-buffer.h>
#include <workerd/util/state-machine.h>

namespace workerd::api {
CompressionAllocator::CompressionAllocator(
Expand Down Expand Up @@ -235,94 +237,74 @@ class LazyBuffer {
size_t valid_size_;
};

// Uncompressed data goes in. Compressed data comes out.
// Because we have to use an autogate to switch things over to the new state manager, we need
// to separate out a common base class for the compression stream internal state and separate
// two separate impls that differ only in how they manage state. Once the autogate is removed,
// we can delete the first impl class and merge everything back together.
template <Context::Mode mode>
class CompressionStreamImpl: public kj::Refcounted,
class CompressionStreamBase: public kj::Refcounted,
public kj::AsyncInputStream,
public capnp::ExplicitEndOutputStream {
public:
explicit CompressionStreamImpl(kj::String format,
explicit CompressionStreamBase(kj::String format,
Context::ContextFlags flags,
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
: context(mode, format, flags, kj::mv(externalMemoryTarget)) {}

// WritableStreamSink implementation ---------------------------------------------------

kj::Promise<void> write(kj::ArrayPtr<const byte> buffer) override {
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(ended, Ended) {
JSG_FAIL_REQUIRE(Error, "Write after close");
}
KJ_CASE_ONEOF(exception, kj::Exception) {
kj::throwFatalException(kj::cp(exception));
}
KJ_CASE_ONEOF(open, Open) {
context.setInput(buffer.begin(), buffer.size());
writeInternal(Z_NO_FLUSH);
co_return;
}
}
KJ_UNREACHABLE;
kj::Promise<void> write(kj::ArrayPtr<const byte> buffer) override final {
requireActive("Write after close");
context.setInput(buffer.begin(), buffer.size());
writeInternal(Z_NO_FLUSH);
co_return;
}

kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override {
// We check for Ended, Exception here so that we catch
// these even if pieces is empty.
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(ended, Ended) {
JSG_FAIL_REQUIRE(Error, "Write after close");
}
KJ_CASE_ONEOF(exception, kj::Exception) {
kj::throwFatalException(kj::cp(exception));
}
KJ_CASE_ONEOF(open, Open) {
for (auto piece: pieces) {
co_await write(piece);
}
co_return;
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const kj::byte>> pieces) override final {
// We check state here so that we catch errors even if pieces is empty.
requireActive("Write after close");
for (auto piece: pieces) {
co_await write(piece);
}
KJ_UNREACHABLE;
co_return;
}

kj::Promise<void> end() override {
state = Ended();
kj::Promise<void> end() override final {
transitionToEnded();
writeInternal(Z_FINISH);
co_return;
}

kj::Promise<void> whenWriteDisconnected() override {
kj::Promise<void> whenWriteDisconnected() override final {
return kj::NEVER_DONE;
}

void abortWrite(kj::Exception&& reason) override {
void abortWrite(kj::Exception&& reason) override final {
cancelInternal(kj::mv(reason));
}

// AsyncInputStream implementation -----------------------------------------------------

kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override final {
KJ_ASSERT(minBytes <= maxBytes);
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(ended, Ended) {
// There might still be data in the output buffer remaining to read.
if (output.empty()) {
co_return static_cast<size_t>(0);
}
co_return co_await tryReadInternal(
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
}
KJ_CASE_ONEOF(exception, kj::Exception) {
kj::throwFatalException(kj::cp(exception));
}
KJ_CASE_ONEOF(open, Open) {
co_return co_await tryReadInternal(
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
}
// Re-throw any stored exception
throwIfException();
// If stream has ended normally and no buffered data, return EOF
if (isInTerminalState() && output.empty()) {
co_return static_cast<size_t>(0);
}
KJ_UNREACHABLE;
// Active or terminal with data remaining
co_return co_await tryReadInternal(
kj::arrayPtr(reinterpret_cast<kj::byte*>(buffer), maxBytes), minBytes);
}

protected:
virtual void requireActive(kj::StringPtr errorMessage) = 0;
virtual void transitionToEnded() = 0;
virtual void transitionToErrored(kj::Exception&& reason) = 0;
virtual void throwIfException() = 0;
virtual bool isInTerminalState() = 0;

private:
struct PendingRead {
kj::ArrayPtr<kj::byte> buffer;
Expand All @@ -343,7 +325,8 @@ class CompressionStreamImpl: public kj::Refcounted,
}

canceler.cancel(kj::cp(reason));
state = kj::mv(reason);
transitionToErrored(kj::mv(reason));
//state = kj::mv(reason);
}

kj::Promise<size_t> tryReadInternal(kj::ArrayPtr<kj::byte> dest, size_t minBytes) {
Expand All @@ -357,9 +340,9 @@ class CompressionStreamImpl: public kj::Refcounted,
// If the output currently contains >= minBytes, then we'll fulfill
// the read immediately, removing as many bytes as possible from the
// output queue.
// If we reached the end, resolve the read immediately as well, since no
// new data is expected.
if (output.size() >= minBytes || state.template is<Ended>()) {
// If we reached the end (terminal state), resolve the read immediately
// as well, since no new data is expected.
if (output.size() >= minBytes || isInTerminalState()) {
co_return copyIntoBuffer(dest);
}

Expand All @@ -385,7 +368,7 @@ class CompressionStreamImpl: public kj::Refcounted,
void writeInternal(int flush) {
// TODO(later): This does not yet implement any backpressure. A caller can keep calling
// write without reading, which will continue to fill the internal buffer.
KJ_ASSERT(flush == Z_FINISH || state.template is<Open>());
KJ_ASSERT(flush == Z_FINISH || !isInTerminalState());
Context::Result result;

while (true) {
Expand Down Expand Up @@ -460,7 +443,7 @@ class CompressionStreamImpl: public kj::Refcounted,
KJ_ASSERT(output.empty());
}

if (state.template is<Ended>() && !pendingReads.empty()) {
if (isInTerminalState() && !pendingReads.empty()) {
// We are ended and we have pending reads. Because of the loop above,
// one of either pendingReads or output must be empty, so if we got this
// far, output.empty() must be true. Let's check.
Expand All @@ -477,17 +460,128 @@ class CompressionStreamImpl: public kj::Refcounted,
}
}

struct Ended {};
struct Open {};

kj::OneOf<Open, Ended, kj::Exception> state = Open();
Context context;

kj::Canceler canceler;
LazyBuffer output;
RingBuffer<PendingRead, 8> pendingReads;
};

// Uncompressed data goes in. Compressed data comes out.
// TODO(cleanup): Once the autogate is removed, delete this class and merge CompressionStreamBase
// and CompressionStreamImplV2 back into a single class.
template <Context::Mode mode>
class CompressionStreamImpl final: public CompressionStreamBase<mode> {
public:
explicit CompressionStreamImpl(kj::String format,
Context::ContextFlags flags,
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
: CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)) {}

protected:
void requireActive(kj::StringPtr errorMessage) override {
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(ended, Ended) {
JSG_FAIL_REQUIRE(Error, errorMessage);
}
KJ_CASE_ONEOF(exception, kj::Exception) {
kj::throwFatalException(kj::cp(exception));
}
KJ_CASE_ONEOF(open, Open) {
return;
}
}
KJ_UNREACHABLE;
}

void transitionToEnded() override {
state = Ended();
}

void transitionToErrored(kj::Exception&& reason) override {
state = kj::mv(reason);
}

void throwIfException() override {
KJ_IF_SOME(exception, state.template tryGet<kj::Exception>()) {
kj::throwFatalException(kj::cp(exception));
}
}

virtual bool isInTerminalState() override {
// Ended or Exception are both terminal states.
return state.template is<Ended>() || state.template is<kj::Exception>();
}

private:
struct Ended {};
struct Open {};

kj::OneOf<Open, Ended, kj::Exception> state = Open();
};

template <Context::Mode mode>
class CompressionStreamImplV2 final: public CompressionStreamBase<mode> {
public:
explicit CompressionStreamImplV2(kj::String format,
Context::ContextFlags flags,
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
: CompressionStreamBase<mode>(kj::mv(format), flags, kj::mv(externalMemoryTarget)),
state(decltype(state)::template create<Open>()) {}

protected:
void requireActive(kj::StringPtr errorMessage) override {
KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) {
kj::throwFatalException(kj::cp(exception));
}
// isActive() returns true only if in Open state (the ActiveState)
JSG_REQUIRE(state.isActive(), Error, errorMessage);
}

void transitionToEnded() override {
// Use transitionFromTo to ensure we're in Open state before ending.
// This provides a clearer error if end() is called twice.
auto result = state.template transitionFromTo<Open, Ended>();
KJ_REQUIRE(result != kj::none, "Stream already ended or errored");
}

void transitionToErrored(kj::Exception&& reason) override {
// Use forceTransitionTo because cancelInternal may be called when already
// in an error state (e.g., from writeInternal error handling).
state.template forceTransitionTo<kj::Exception>(kj::mv(reason));
}

void throwIfException() override {
KJ_IF_SOME(exception, state.tryGetErrorUnsafe()) {
kj::throwFatalException(kj::cp(exception));
}
}

virtual bool isInTerminalState() override {
return state.isTerminal();
}

private:
struct Ended {
static constexpr kj::StringPtr NAME KJ_UNUSED = "ended"_kj;
};
struct Open {
static constexpr kj::StringPtr NAME KJ_UNUSED = "open"_kj;
};

// State machine for tracking compression stream lifecycle:
// Open -> Ended (normal close via end())
// Open -> kj::Exception (error via abortWrite())
// Ended is terminal, kj::Exception is implicitly terminal via ErrorState.
StateMachine<TerminalStates<Ended>,
ErrorState<kj::Exception>,
ActiveState<Open>,
Open,
Ended,
kj::Exception>
state;
};

// Adapter to bridge CompressionStreamImpl (which implements AsyncInputStream and
// ExplicitEndOutputStream) to the ReadableStreamSource/WritableStreamSink interfaces.
// TODO(soon): This class is intended to be replaced by the new ReadableSource/WritableSink
Expand All @@ -500,7 +594,7 @@ class CompressionStreamAdapter final: public kj::Refcounted,
public ReadableStreamSource,
public WritableStreamSink {
public:
explicit CompressionStreamAdapter(kj::Rc<CompressionStreamImpl<mode>> impl)
explicit CompressionStreamAdapter(kj::Rc<CompressionStreamBase<mode>> impl)
: impl(kj::mv(impl)),
ioContext(IoContext::current()) {}

Expand Down Expand Up @@ -532,17 +626,44 @@ class CompressionStreamAdapter final: public kj::Refcounted,
}

private:
kj::Rc<CompressionStreamImpl<mode>> impl;
kj::Rc<CompressionStreamBase<mode>> impl;
IoContext& ioContext;
};

kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> createCompressionStreamImpl(
kj::String format,
Context::ContextFlags flags,
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
return kj::rc<CompressionStreamImplV2<Context::Mode::COMPRESS>>(
kj::mv(format), flags, kj::mv(externalMemoryTarget));
}
return kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
kj::mv(format), flags, kj::mv(externalMemoryTarget));
}

kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> createDecompressionStreamImpl(
kj::String format,
Context::ContextFlags flags,
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget) {
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
if (util::Autogate::isEnabled(util::AutogateKey::COMPRESSION_STREAM_USE_STATE_MACHINE)) {
return kj::rc<CompressionStreamImplV2<Context::Mode::DECOMPRESS>>(
kj::mv(format), flags, kj::mv(externalMemoryTarget));
}
return kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(
kj::mv(format), flags, kj::mv(externalMemoryTarget));
}

} // namespace

jsg::Ref<CompressionStream> CompressionStream::constructor(jsg::Lock& js, kj::String format) {
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");

auto impl = kj::rc<CompressionStreamImpl<Context::Mode::COMPRESS>>(
// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> impl = createCompressionStreamImpl(
kj::mv(format), Context::ContextFlags::NONE, js.getExternalMemoryTarget());

auto& ioContext = IoContext::current();
Expand All @@ -561,10 +682,11 @@ jsg::Ref<DecompressionStream> DecompressionStream::constructor(jsg::Lock& js, kj
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");

auto impl = kj::rc<CompressionStreamImpl<Context::Mode::DECOMPRESS>>(kj::mv(format),
FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT
: Context::ContextFlags::NONE,
js.getExternalMemoryTarget());
kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> impl =
createDecompressionStreamImpl(kj::mv(format),
FeatureFlags::get(js).getStrictCompression() ? Context::ContextFlags::STRICT
: Context::ContextFlags::NONE,
js.getExternalMemoryTarget());

auto& ioContext = IoContext::current();

Expand Down
8 changes: 8 additions & 0 deletions src/workerd/api/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,11 @@ wd_test(
args = ["--experimental"],
data = ["headers-immutable-prototype-test.js"],
)

# TODO(cleanup): This is a copy of an existing test in streams-test. Once the autogate is remvoed,
# this separate test can be deleted.
wd_test(
src = "compression-streams-test.wd-test",
args = ["--experimental"],
data = ["compression-streams-test.js"],
)
Loading