diff --git a/src/workerd/api/streams/internal-test.c++ b/src/workerd/api/streams/internal-test.c++ index 5dd6215125f..c21519f7ad9 100644 --- a/src/workerd/api/streams/internal-test.c++ +++ b/src/workerd/api/streams/internal-test.c++ @@ -4,6 +4,7 @@ #include "internal.h" #include "readable.h" +#include "standard.h" #include "writable.h" #include @@ -352,5 +353,105 @@ KJ_TEST("WritableStreamInternalController observability") { KJ_ASSERT(observer.queueSizeBytes == 0); } +// Test for use-after-free fix in pipeLoop when abort is called during pending read. +// This tests the scenario where: +// 1. A JavaScript-backed ReadableStream is piped to an internal WritableStream +// 2. The pipeLoop is waiting for a read from the JS stream +// 3. abort() is called on the writable stream, which triggers drain() +// 4. drain() destroys the Pipe object +// 5. The pending read callback must not access the freed Pipe +// +// The fix ensures the Pipe::State is ref-counted and survives until all callbacks complete. +KJ_TEST("WritableStreamInternalController pipeLoop abort during pending read") { + capnp::MallocMessageBuilder message; + auto flags = message.initRoot(); + flags.setNodeJsCompat(true); + flags.setWorkerdExperimental(true); + flags.setStreamsJavaScriptControllers(true); + // Enable the flag that causes abort to call drain() immediately + flags.setInternalWritableStreamAbortClearsQueue(true); + + TestFixture fixture({.featureFlags = flags.asReader()}); + + class MySink final: public WritableStreamSink { + public: + kj::Promise write(kj::ArrayPtr buffer) override { + return kj::READY_NOW; + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return kj::READY_NOW; + } + kj::Promise end() override { + return kj::READY_NOW; + } + void abort(kj::Exception reason) override {} + }; + + fixture.runInIoContext([&](const TestFixture::Environment& env) { + // Create a JavaScript-backed ReadableStream. + // The pull function will be called when the pipe tries to read. + // We use a JS-backed stream so that pipeLoop is used (not the kj pipe path). + // + // We need to simulate: + // 1. First read succeeds with some data + // 2. Second read is pending (the promise from pull is not resolved) + // 3. While pending, we abort the writable stream + // + // Using an UnderlyingSource with a pull callback that enqueues data once, + // then on the second call returns without enqueuing (leaving the read pending). + + int pullCount = 0; + jsg::Ref source = ReadableStream::constructor(env.js, + UnderlyingSource{.pull = + [&pullCount](jsg::Lock& js, UnderlyingSource::Controller controller) { + pullCount++; + auto& c = KJ_ASSERT_NONNULL(controller.tryGet>()); + if (pullCount == 1) { + // First pull: enqueue some data so the pipe loop can make progress + auto data = js.bytes(kj::heapArray({1, 2, 3, 4})); + c->enqueue(js, data.getHandle(js)); + } + // Second pull onwards: don't enqueue anything, leaving the read pending. + // This simulates an async data source that hasn't received data yet. + // The promise returned by read() will be pending. + return js.resolvedPromise(); + }}, + kj::none); + + jsg::Ref sink = + env.js.alloc(env.context, kj::heap(), kj::none); + + // Start the pipe. This will: + // 1. Call pull() which enqueues data + // 2. pipeLoop reads the data and writes it to the sink + // 3. pipeLoop calls read() again, which calls pull() + // 4. pull() returns without enqueuing, so read() returns a pending promise + // 5. pipeLoop's callback is now waiting for that promise + auto pipeTo = source->pipeTo(env.js, sink.addRef(), PipeToOptions{}); + pipeTo.markAsHandled(env.js); + + // Run microtasks to let the pipe make progress (first read/write cycle) + env.js.runMicrotasks(); + + // At this point, pipeLoop should be waiting for the second read. + // Now abort the writable stream. This should: + // 1. Call doAbort() which calls drain() + // 2. drain() destroys the Pipe (setting state->aborted = true) + // 3. The pending read callback should check aborted and bail out safely + + // Before the fix, this would cause a use-after-free when the pending callback + // tried to access the freed Pipe. + auto abortPromise = sink->getController().abort(env.js, env.js.v8TypeError("Test abort"_kj)); + abortPromise.markAsHandled(env.js); + + // Run microtasks to process the abort and any pending callbacks + env.js.runMicrotasks(); + + // If we get here without crashing, the test passes. + // The fix ensures that the Pipe::State survives until all callbacks complete. + KJ_ASSERT(pullCount >= 1); // Verify pull was called at least once + }); +} + } // namespace } // namespace workerd::api diff --git a/src/workerd/api/streams/internal.c++ b/src/workerd/api/streams/internal.c++ index 366acb07d08..02bba17cf25 100644 --- a/src/workerd/api/streams/internal.c++ +++ b/src/workerd/api/streams/internal.c++ @@ -1282,13 +1282,8 @@ kj::Maybe> WritableStreamInternalController::tryPipeFrom( } queue.push_back(WriteEvent{ .outputLock = IoContext::current().waitForOutputLocksIfNecessaryIoOwn(), - .event = kj::heap({.parent = *this, - .source = sourceLock, - .promise = kj::mv(prp.resolver), - .preventAbort = preventAbort, - .preventClose = preventClose, - .preventCancel = preventCancel, - .maybeSignal = kj::mv(options.signal)}), + .event = kj::heap(*this, sourceLock, kj::mv(prp.resolver), preventAbort, preventClose, + preventCancel, kj::mv(options.signal)), }); ensureWriting(js); return kj::mv(prp.promise); @@ -1647,11 +1642,11 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo // The readable side should *should* still be readable here but let's double check, just // to be safe, both for closed state and errored states. - if (request->source.isClosed()) { - request->source.release(js); + if (request->source().isClosed()) { + request->source().release(js); // If the source is closed, the spec requires us to close the destination unless the // preventClose option is true. - if (!request->preventClose && !isClosedOrClosing()) { + if (!request->preventClose() && !isClosedOrClosing()) { doClose(js); } else { writeState.init(); @@ -1659,11 +1654,11 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo return js.resolvedPromise(); } - KJ_IF_SOME(errored, request->source.tryGetErrored(js)) { - request->source.release(js); + KJ_IF_SOME(errored, request->source().tryGetErrored(js)) { + request->source().release(js); // If the source is errored, the spec requires us to error the destination unless the // preventAbort option is true. - if (!request->preventAbort) { + if (!request->preventAbort()) { auto ex = js.exceptionToKj(js.v8Ref(errored)); writable->abort(kj::mv(ex)); drain(js, errored); @@ -1682,7 +1677,7 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo // loop to pass the data into the destination. const auto handlePromise = [this, &ioContext, check = makeChecker(), - preventAbort = request->preventAbort]( + preventAbort = request->preventAbort()]( jsg::Lock& js, auto promise) { return promise.then(js, ioContext.addFunctor([this, check](jsg::Lock& js) mutable { // Under some conditions, the clean up has already happened. @@ -1694,22 +1689,23 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo // In that case, we need to treat preventAbort the same as preventClose. Be // sure to check this before calling sourceLock.close() or the error detail will // be lost. - KJ_IF_SOME(errored, request.source.tryGetErrored(js)) { - if (request.preventAbort) request.preventClose = true; + // Capture preventClose now so we can modify it locally if needed. + bool preventClose = request.preventClose(); + KJ_IF_SOME(errored, request.source().tryGetErrored(js)) { + if (request.preventAbort()) preventClose = true; // Even through we're not going to close the destination, we still want the // pipe promise itself to be rejected in this case. - maybeRejectPromise(js, request.promise, errored); + maybeRejectPromise(js, request.promise(), errored); } else KJ_IF_SOME(errored, state.tryGet()) { - maybeRejectPromise(js, request.promise, errored.getHandle(js)); + maybeRejectPromise(js, request.promise(), errored.getHandle(js)); } else { - maybeResolvePromise(js, request.promise); + maybeResolvePromise(js, request.promise()); } // Always transition the readable side to the closed state, because we read until EOF. // Note that preventClose (below) means "don't close the writable side", i.e. don't // call end(). - request.source.close(js); - auto preventClose = request.preventClose; + request.source().close(js); queue.pop_front(); if (!preventClose) { @@ -1724,7 +1720,7 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo [this, check, preventAbort](jsg::Lock& js, jsg::Value reason) mutable { auto handle = reason.getHandle(js); auto& request = check.template operator()(); - maybeRejectPromise(js, request.promise, handle); + maybeRejectPromise(js, request.promise(), handle); // TODO(conform): Remember all those checks we performed in ReadableStream::pipeTo()? // We're supposed to perform the same checks continually, e.g., errored writes should // cancel the readable side unless preventCancel is truthy... This would require @@ -1732,7 +1728,7 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo // of this is that if there is an error on the writable side, we error the readable // side, rather than close (cancel) it, which is what the spec would have us do. // TODO(now): Warn on the console about this. - request.source.error(js, handle); + request.source().error(js, handle); queue.pop_front(); if (!preventAbort) { return abort(js, handle); @@ -1742,11 +1738,11 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo })); }; - KJ_IF_SOME(promise, request->source.tryPumpTo(*writable->sink, !request->preventClose)) { + KJ_IF_SOME(promise, request->source().tryPumpTo(*writable->sink, !request->preventClose())) { return handlePromise(js, ioContext.awaitIo(js, writable->canceler.wrap( - AbortSignal::maybeCancelWrap(js, request->maybeSignal, kj::mv(promise))))); + AbortSignal::maybeCancelWrap(js, request->maybeSignal(), kj::mv(promise))))); } // The ReadableStream is JavaScript-backed. We can still pipe the data but it's going to be @@ -1797,42 +1793,49 @@ jsg::Promise WritableStreamInternalController::writeLoopAfterFrontOutputLo KJ_UNREACHABLE; } -bool WritableStreamInternalController::Pipe::checkSignal(jsg::Lock& js) { +bool WritableStreamInternalController::Pipe::State::checkSignal(jsg::Lock& js) { + // Returns true if the caller should bail out and stop processing. This happens in two cases: + // 1. The State was aborted (e.g., by drain()) - the Pipe is being torn down + // 2. The AbortSignal was triggered - we handle the abort and return true + // In both cases, the caller should return a resolved promise and not continue the pipe loop. + if (aborted) return true; + KJ_IF_SOME(signal, maybeSignal) { if (signal->getAborted(js)) { auto reason = signal->getReason(js); // abort process might call parent.drain which will delete this, // move/copy everything we need after into temps. - auto& parent = this->parent; - auto& source = this->source; - auto preventCancel = this->preventCancel; - auto promise = kj::mv(this->promise); + auto& parentRef = this->parent; + auto& sourceRef = this->source; + auto preventCancelCopy = this->preventCancel; + auto promiseCopy = kj::mv(this->promise); if (!preventAbort) { - KJ_IF_SOME(writable, parent.state.tryGet>()) { + KJ_IF_SOME(writable, parentRef.state.tryGet>()) { auto ex = js.exceptionToKj(reason); writable->abort(kj::mv(ex)); - parent.drain(js, reason); + parentRef.drain(js, reason); } else { - parent.writeState.init(); + parentRef.writeState.init(); } } else { - parent.writeState.init(); + parentRef.writeState.init(); } - if (!preventCancel) { - source.release(js, v8::Local(reason)); + if (!preventCancelCopy) { + sourceRef.release(js, v8::Local(reason)); } else { - source.release(js); + sourceRef.release(js); } - maybeRejectPromise(js, promise, reason); + maybeRejectPromise(js, promiseCopy, reason); return true; } } return false; } -jsg::Promise WritableStreamInternalController::Pipe::write(v8::Local handle) { +jsg::Promise WritableStreamInternalController::Pipe::State::write( + v8::Local handle) { auto& writable = parent.state.get>(); // TODO(soon): Once jsg::BufferSource lands and we're able to use it, this can be simplified. KJ_ASSERT(handle->IsArrayBuffer() || handle->IsArrayBufferView()); @@ -1859,7 +1862,7 @@ jsg::Promise WritableStreamInternalController::Pipe::write(v8::Local WritableStreamInternalController::Pipe::pipeLoop(jsg::Lock& js) { +jsg::Promise WritableStreamInternalController::Pipe::State::pipeLoop(jsg::Lock& js) { // This is a bit of dance. We got here because the source ReadableStream does not support // the internal, more efficient kj pipe (which means it is a JavaScript-backed ReadableStream). // We need to call read() on the source which returns a JavaScript Promise, wait on it to resolve, @@ -1870,6 +1873,10 @@ jsg::Promise WritableStreamInternalController::Pipe::pipeLoop(jsg::Lock& j auto& ioContext = IoContext::current(); + if (aborted) { + return js.resolvedPromise(); + } + if (checkSignal(js)) { // If the signal is triggered, checkSignal will handle erroring the source and destination. return js.resolvedPromise(); @@ -1916,11 +1923,16 @@ jsg::Promise WritableStreamInternalController::Pipe::pipeLoop(jsg::Lock& j if (!parent.isClosedOrClosing()) { // We'll only be here if the sink is in the Writable state. auto& ioContext = IoContext::current(); + // Capture a ref to the state to keep it alive during async operations. return ioContext .awaitIo(js, parent.state.get>()->sink->end(), [](jsg::Lock&) {}) - .then(js, ioContext.addFunctor([this](jsg::Lock& js) { parent.finishClose(js); }), - ioContext.addFunctor([this](jsg::Lock& js, jsg::Value reason) { - parent.finishError(js, reason.getHandle(js)); + .then(js, ioContext.addFunctor([state = kj::addRef(*this)](jsg::Lock& js) { + if (state->aborted) return; + state->parent.finishClose(js); + }), + ioContext.addFunctor([state = kj::addRef(*this)](jsg::Lock& js, jsg::Value reason) { + if (state->aborted) return; + state->parent.finishError(js, reason.getHandle(js)); })); } parent.writeState.init(); @@ -1942,8 +1954,9 @@ jsg::Promise WritableStreamInternalController::Pipe::pipeLoop(jsg::Lock& j } return source.read(js).then(js, - ioContext.addFunctor([this](jsg::Lock& js, ReadResult result) -> jsg::Promise { - if (checkSignal(js) || result.done) { + ioContext.addFunctor([state = kj::addRef(*this)]( + jsg::Lock& js, ReadResult result) mutable -> jsg::Promise { + if (state->aborted || state->checkSignal(js) || result.done) { return js.resolvedPromise(); } @@ -1953,27 +1966,40 @@ jsg::Promise WritableStreamInternalController::Pipe::pipeLoop(jsg::Lock& j KJ_IF_SOME(value, result.value) { auto handle = value.getHandle(js); if (handle->IsArrayBuffer() || handle->IsArrayBufferView()) { - return write(handle).then(js, [this](jsg::Lock& js) -> jsg::Promise { + return state->write(handle).then(js, + [state = kj::addRef(*state)](jsg::Lock& js) mutable -> jsg::Promise { + if (state->aborted) { + return js.resolvedPromise(); + } // The signal will be checked again at the start of the next loop iteration. - return pipeLoop(js); - }, [this](jsg::Lock& js, jsg::Value reason) -> jsg::Promise { - parent.doError(js, reason.getHandle(js)); - return pipeLoop(js); + return state->pipeLoop(js); + }, + [state = kj::addRef(*state)]( + jsg::Lock& js, jsg::Value reason) mutable -> jsg::Promise { + if (state->aborted) { + return js.resolvedPromise(); + } + state->parent.doError(js, reason.getHandle(js)); + return state->pipeLoop(js); }); } } // Undefined and null are perfectly valid values to pass through a ReadableStream, // but we can't interpret them as bytes so if we get them here, we error the pipe. auto error = js.v8TypeError("This WritableStream only supports writing byte types."_kj); - auto& writable = parent.state.get>(); + auto& writable = state->parent.state.get>(); auto ex = js.exceptionToKj(js.v8Ref(error)); writable->abort(kj::mv(ex)); // The error condition will be handled at the start of the next iteration. - return pipeLoop(js); + return state->pipeLoop(js); }), - ioContext.addFunctor([this](jsg::Lock& js, jsg::Value reason) -> jsg::Promise { + ioContext.addFunctor([state = kj::addRef(*this)]( + jsg::Lock& js, jsg::Value reason) mutable -> jsg::Promise { + if (state->aborted) { + return js.resolvedPromise(); + } // The error will be processed and propagated in the next iteration. - return pipeLoop(js); + return state->pipeLoop(js); })); } @@ -1985,10 +2011,10 @@ void WritableStreamInternalController::drain(jsg::Lock& js, v8::Local maybeRejectPromise(js, writeRequest->promise, reason); } KJ_CASE_ONEOF(pipeRequest, kj::Own) { - if (!pipeRequest->preventCancel) { - pipeRequest->source.cancel(js, reason); + if (!pipeRequest->preventCancel()) { + pipeRequest->source().cancel(js, reason); } - maybeRejectPromise(js, pipeRequest->promise, reason); + maybeRejectPromise(js, pipeRequest->promise(), reason); } KJ_CASE_ONEOF(closeRequest, kj::Own) { maybeRejectPromise(js, closeRequest->promise, reason); @@ -2014,7 +2040,7 @@ void WritableStreamInternalController::visitForGc(jsg::GcVisitor& visitor) { visitor.visit(flush->promise); } KJ_CASE_ONEOF(pipe, kj::Own) { - visitor.visit(pipe->maybeSignal, pipe->promise); + visitor.visit(pipe->maybeSignal(), pipe->promise()); } } } diff --git a/src/workerd/api/streams/internal.h b/src/workerd/api/streams/internal.h index 3a926ebbb5a..73f52ff227c 100644 --- a/src/workerd/api/streams/internal.h +++ b/src/workerd/api/streams/internal.h @@ -11,6 +11,8 @@ #include #include +#include + namespace workerd::api { // ======================================================================================= @@ -320,21 +322,102 @@ class WritableStreamInternalController: public WritableStreamController { } }; struct Pipe { - WritableStreamInternalController& parent; - ReadableStreamController::PipeController& source; - kj::Maybe::Resolver> promise; - bool preventAbort; - bool preventClose; - bool preventCancel; - kj::Maybe> maybeSignal; + // PipeState is ref-counted so that it can be safely captured by lambdas in pipeLoop(). + // When drain() destroys the Pipe, the state survives as long as pending callbacks need it. + // The `aborted` flag is set when the Pipe is destroyed. + struct State: public kj::Refcounted { + WritableStreamInternalController& parent; + ReadableStreamController::PipeController& source; + kj::Maybe::Resolver> promise; + kj::Maybe> maybeSignal; + + bool preventAbort; + bool preventClose; + bool preventCancel; + + // True when the Pipe is being destroyed + bool aborted = false; + + State(WritableStreamInternalController& parent, + ReadableStreamController::PipeController& source, + kj::Maybe::Resolver> promise, + bool preventAbort, + bool preventClose, + bool preventCancel, + kj::Maybe> maybeSignal) + : parent(parent), + source(source), + promise(kj::mv(promise)), + maybeSignal(kj::mv(maybeSignal)), + preventAbort(preventAbort), + preventClose(preventClose), + preventCancel(preventCancel) {} + + bool checkSignal(jsg::Lock& js); + jsg::Promise pipeLoop(jsg::Lock& js); + jsg::Promise write(v8::Local value); + + JSG_MEMORY_INFO(State) { + tracker.trackField("resolver", promise); + tracker.trackField("signal", maybeSignal); + } + }; + + kj::Own state; + + Pipe(WritableStreamInternalController& parent, + ReadableStreamController::PipeController& source, + kj::Maybe::Resolver> promise, + bool preventAbort, + bool preventClose, + bool preventCancel, + kj::Maybe> maybeSignal) + : state(kj::refcounted(parent, + source, + kj::mv(promise), + preventAbort, + preventClose, + preventCancel, + kj::mv(maybeSignal))) {} + + ~Pipe() noexcept(false) { + state->aborted = true; + } + + WritableStreamInternalController& parent() { + return state->parent; + } + ReadableStreamController::PipeController& source() { + return state->source; + } + kj::Maybe::Resolver>& promise() { + return state->promise; + } + bool preventAbort() const { + return state->preventAbort; + } + bool preventClose() const { + return state->preventClose; + } + bool preventCancel() const { + return state->preventCancel; + } + kj::Maybe>& maybeSignal() { + return state->maybeSignal; + } - bool checkSignal(jsg::Lock& js); - jsg::Promise pipeLoop(jsg::Lock& js); - jsg::Promise write(v8::Local value); + bool checkSignal(jsg::Lock& js) { + return state->checkSignal(js); + } + jsg::Promise pipeLoop(jsg::Lock& js) { + return state->pipeLoop(js); + } + jsg::Promise write(v8::Local value) { + return state->write(value); + } JSG_MEMORY_INFO(Pipe) { - tracker.trackField("resolver", promise); - tracker.trackField("signal", maybeSignal); + tracker.trackField("state", state); } }; struct WriteEvent {