Skip to content

Commit 8159d8c

Browse files
authored
Merge pull request #3740 from cloudflare/maizatskyi/2025-03-17-atomicref
kj::Arc<Worker>
2 parents ec83a32 + 1ebab7e commit 8159d8c

File tree

11 files changed

+42
-43
lines changed

11 files changed

+42
-43
lines changed

src/workerd/api/cache.c++

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ kj::Own<kj::HttpClient> Cache::getHttpClient(IoContext& context,
553553
.featureFlagsForFl = kj::none,
554554
};
555555
if (enableCompatFlags) {
556-
metadata.featureFlagsForFl = context.getWorker().getIsolate().getFeatureFlagsForFl();
556+
metadata.featureFlagsForFl = context.getWorker()->getIsolate().getFeatureFlagsForFl();
557557
}
558558
auto httpClient =
559559
cacheName.map([&](kj::String& n) {

src/workerd/api/queue.c++

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::run(
600600
// returns and the promise that it returns (if any) has resolved.
601601
// * In the disabled path, the queue event isn't complete until all waitUntil'ed promises resolve.
602602
// This was how Queues originally worked, but made for a poor user experience.
603-
auto compatFlags = context.getWorker().getIsolate().getApi().getFeatureFlags();
603+
auto compatFlags = context.getWorker()->getIsolate().getApi().getFeatureFlags();
604604
if (compatFlags.getQueueConsumerNoWaitForWaitUntil()) {
605605
// The user has opted in to only waiting on their event handler rather than all waitUntil'd
606606
// promises.
@@ -691,7 +691,7 @@ kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::run(
691691
}
692692
}
693693
auto& ioContext = incomingRequest->getContext();
694-
auto scriptId = ioContext.getWorker().getScript().getId();
694+
auto scriptId = ioContext.getWorker()->getScript().getId();
695695
auto tasks = ioContext.getWaitUntilTasks().trace();
696696
if (result == IoContext_IncomingRequest::FinishScheduledResult::TIMEOUT) {
697697
KJ_LOG(WARNING, "NOSENTRY queue event hit timeout", scriptId, status, tasks);

src/workerd/io/io-context.c++

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class IoContext::TimeoutManagerImpl::TimeoutState {
125125
};
126126

127127
IoContext::IoContext(ThreadContext& thread,
128-
kj::Own<const Worker> workerParam,
128+
kj::Arc<Worker> workerParam,
129129
kj::Maybe<Worker::Actor&> actorParam,
130130
kj::Own<LimitEnforcer> limitEnforcerParam)
131131
: thread(thread),
@@ -1382,7 +1382,7 @@ kj::Promise<void> IoContext::startDeleteQueueSignalTask(IoContext* context) {
13821382
// ======================================================================================
13831383

13841384
WarningAggregator::WarningAggregator(IoContext& context, EmitCallback emitter)
1385-
: worker(kj::atomicAddRef(context.getWorker())),
1385+
: worker(context.getWorker().addRef()),
13861386
requestMetrics(kj::addRef(context.getMetrics())),
13871387
emitter(kj::mv(emitter)) {}
13881388

src/workerd/io/io-context.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class WarningAggregator final: public kj::AtomicRefcounted {
8787
using Map = kj::HashMap<const Key&, kj::Own<WarningAggregator>>;
8888

8989
private:
90-
kj::Own<const Worker> worker;
90+
kj::Arc<Worker> worker;
9191
kj::Own<RequestObserver> requestMetrics;
9292
EmitCallback emitter;
9393
kj::MutexGuarded<kj::Vector<kj::Own<WarningContext>>> warnings;
@@ -224,7 +224,7 @@ class IoContext final: public kj::Refcounted, private kj::TaskSet::ErrorHandler
224224

225225
// Construct a new IoContext. Before using it, you must also create an IncomingRequest.
226226
IoContext(ThreadContext& thread,
227-
kj::Own<const Worker> worker,
227+
kj::Arc<Worker> worker,
228228
kj::Maybe<Worker::Actor&> actor,
229229
kj::Own<LimitEnforcer> limitEnforcer);
230230

@@ -233,8 +233,8 @@ class IoContext final: public kj::Refcounted, private kj::TaskSet::ErrorHandler
233233

234234
using IncomingRequest = IoContext_IncomingRequest;
235235

236-
const Worker& getWorker() {
237-
return *worker;
236+
const kj::Arc<Worker>& getWorker() {
237+
return worker;
238238
}
239239
Worker::Lock& getCurrentLock() {
240240
return KJ_REQUIRE_NONNULL(currentLock);
@@ -835,7 +835,7 @@ class IoContext final: public kj::Refcounted, private kj::TaskSet::ErrorHandler
835835

836836
kj::Own<WeakRef> selfRef = kj::refcounted<WeakRef>(kj::Badge<IoContext>(), *this);
837837

838-
kj::Own<const Worker> worker;
838+
kj::Arc<Worker> worker;
839839
kj::Maybe<Worker::Actor&> actor;
840840
kj::Own<LimitEnforcer> limitEnforcer;
841841

src/workerd/io/worker-entrypoint.c++

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class WorkerEntrypoint final: public WorkerInterface {
4343
// topLevelRequest.getZoneDefaultWorkerLimits(), since the top level request may be shared between
4444
// zone and non-zone workers.
4545
static kj::Own<WorkerInterface> construct(ThreadContext& threadContext,
46-
kj::Own<const Worker> worker,
46+
kj::Arc<Worker> worker,
4747
kj::Maybe<kj::StringPtr> entrypointName,
4848
Frankenvalue props,
4949
kj::Maybe<kj::Own<Worker::Actor>> actor,
@@ -93,7 +93,7 @@ class WorkerEntrypoint final: public WorkerInterface {
9393
kj::Maybe<kj::Own<WorkerInterface>> failOpenService;
9494
bool loggedExceptionEarlier = false;
9595

96-
void init(kj::Own<const Worker> worker,
96+
void init(kj::Arc<Worker> worker,
9797
kj::Maybe<kj::Own<Worker::Actor>> actor,
9898
kj::Own<LimitEnforcer> limitEnforcer,
9999
kj::Own<void> ioContextDependency,
@@ -153,7 +153,7 @@ class WorkerEntrypoint::ResponseSentTracker final: public kj::HttpService::Respo
153153
};
154154

155155
kj::Own<WorkerInterface> WorkerEntrypoint::construct(ThreadContext& threadContext,
156-
kj::Own<const Worker> worker,
156+
kj::Arc<Worker> worker,
157157
kj::Maybe<kj::StringPtr> entrypointName,
158158
Frankenvalue props,
159159
kj::Maybe<kj::Own<Worker::Actor>> actor,
@@ -197,7 +197,7 @@ WorkerEntrypoint::WorkerEntrypoint(kj::Badge<WorkerEntrypoint> badge,
197197
props(kj::mv(props)),
198198
cfBlobJson(kj::mv(cfBlobJson)) {}
199199

200-
void WorkerEntrypoint::init(kj::Own<const Worker> worker,
200+
void WorkerEntrypoint::init(kj::Arc<Worker> worker,
201201
kj::Maybe<kj::Own<Worker::Actor>> actor,
202202
kj::Own<LimitEnforcer> limitEnforcer,
203203
kj::Own<void> ioContextDependency,
@@ -753,10 +753,10 @@ kj::Promise<WorkerInterface::CustomEvent::Result> WorkerEntrypoint::customEvent(
753753
}
754754

755755
#ifdef KJ_DEBUG
756-
void requestGc(const Worker& worker) {
756+
void requestGc(kj::Arc<Worker> worker) {
757757
TRACE_EVENT("workerd", "Debug: requestGc()");
758758
jsg::runInV8Stack([&](jsg::V8StackScope& stackScope) {
759-
auto& isolate = worker.getIsolate();
759+
auto& isolate = worker->getIsolate();
760760
auto lock = isolate.getApi().lock(stackScope);
761761
lock->requestGcForTesting();
762762
});
@@ -765,13 +765,13 @@ void requestGc(const Worker& worker) {
765765
template <typename T>
766766
kj::Promise<T> addGcPassForTest(IoContext& context, kj::Promise<T> promise) {
767767
TRACE_EVENT("workerd", "Debug: addGcPassForTest");
768-
auto worker = kj::atomicAddRef(context.getWorker());
768+
auto worker = context.getWorker().addRef();
769769
if constexpr (kj::isSameType<T, void>()) {
770770
co_await promise;
771-
requestGc(*worker);
771+
requestGc(kj::mv(worker));
772772
} else {
773773
auto ret = co_await promise;
774-
requestGc(*worker);
774+
requestGc(kj::mv(worker));
775775
co_return kj::mv(ret);
776776
}
777777
}
@@ -790,7 +790,7 @@ kj::Promise<T> WorkerEntrypoint::maybeAddGcPassForTest(IoContext& context, kj::P
790790
} // namespace
791791

792792
kj::Own<WorkerInterface> newWorkerEntrypoint(ThreadContext& threadContext,
793-
kj::Own<const Worker> worker,
793+
kj::Arc<Worker> worker,
794794
kj::Maybe<kj::StringPtr> entrypointName,
795795
Frankenvalue props,
796796
kj::Maybe<kj::Own<Worker::Actor>> actor,

src/workerd/io/worker-entrypoint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class InvocationSpanContext;
2929
// - Or, falling back to proxying if passThroughOnException() was used.
3030
// - Finish waitUntil() tasks.
3131
kj::Own<WorkerInterface> newWorkerEntrypoint(ThreadContext& threadContext,
32-
kj::Own<const Worker> worker,
32+
kj::Arc<Worker> worker,
3333
kj::Maybe<kj::StringPtr> entrypointName,
3434
Frankenvalue props,
3535
kj::Maybe<kj::Own<Worker::Actor>> actor,

src/workerd/io/worker.c++

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ struct Worker::Script::Impl {
803803
KJ_ASSERT(!FeatureFlags::get(js).getNewModuleRegistry(),
804804
"legacy dynamic imports must not be used with the new module registry");
805805
static auto constexpr handleDynamicImport =
806-
[](kj::Own<const Worker> worker, DynamicImportHandler handler,
806+
[](kj::Arc<Worker> worker, DynamicImportHandler handler,
807807
kj::Maybe<jsg::Ref<jsg::AsyncContextFrame>> asyncContext)
808808
-> kj::Promise<DynamicImportResult> {
809809
co_await kj::yield();
@@ -853,7 +853,7 @@ struct Worker::Script::Impl {
853853
auto& context = IoContext::current();
854854

855855
return context.awaitIo(js,
856-
handleDynamicImport(kj::atomicAddRef(context.getWorker()), kj::mv(handler),
856+
handleDynamicImport(context.getWorker().addRef(), kj::mv(handler),
857857
jsg::AsyncContextFrame::currentRef(js)),
858858
[](jsg::Lock& js, DynamicImportResult result) {
859859
if (result.isException) {
@@ -3431,7 +3431,7 @@ kj::Maybe<Worker::ConnectFn&> Worker::getConnectOverride(kj::StringPtr networkAd
34313431
return connectOverrides.find(networkAddress);
34323432
}
34333433

3434-
Worker::Actor::Actor(const Worker& worker,
3434+
Worker::Actor::Actor(kj::Arc<Worker> worker,
34353435
kj::Maybe<RequestTracker&> tracker,
34363436
Actor::Id actorId,
34373437
bool hasTransient,
@@ -3445,7 +3445,7 @@ Worker::Actor::Actor(const Worker& worker,
34453445
kj::Maybe<kj::Own<HibernationManager>> manager,
34463446
kj::Maybe<uint16_t> hibernationEventType,
34473447
kj::Maybe<rpc::Container::Client> container)
3448-
: worker(kj::atomicAddRef(worker)),
3448+
: worker(kj::mv(worker)),
34493449
tracker(tracker.map([](RequestTracker& tracker) { return tracker.addRef(); })) {
34503450
impl = kj::heap<Impl>(*this, lock, kj::mv(actorId), hasTransient, kj::mv(makeActorCache),
34513451
kj::mv(makeStorage), kj::mv(loopback), timerChannel, kj::mv(metrics), kj::mv(manager),

src/workerd/io/worker.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ class Worker::Actor final: public kj::Refcounted {
763763

764764
// Create a new Actor hosted by this Worker. Note that this Actor object may only be manipulated
765765
// from the thread that created it.
766-
Actor(const Worker& worker,
766+
Actor(kj::Arc<Worker> worker,
767767
kj::Maybe<RequestTracker&> tracker,
768768
Id actorId,
769769
bool hasTransient,
@@ -849,8 +849,8 @@ class Worker::Actor final: public kj::Refcounted {
849849
// Only needs to be called when allocating a HibernationManager!
850850
kj::Maybe<uint16_t> getHibernationEventType();
851851

852-
inline const Worker& getWorker() {
853-
return *worker;
852+
inline kj::Arc<Worker> getWorker() {
853+
return worker.addRef();
854854
}
855855

856856
void assertCanSetAlarm();
@@ -872,7 +872,7 @@ class Worker::Actor final: public kj::Refcounted {
872872
private:
873873
kj::Promise<WorkerInterface::ScheduleAlarmResult> handleAlarm(kj::Date scheduledTime);
874874

875-
kj::Own<const Worker> worker;
875+
kj::Arc<Worker> worker;
876876
kj::Maybe<kj::Own<RequestTracker>> tracker;
877877
struct Impl;
878878
kj::Own<Impl> impl;

src/workerd/server/server.c++

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ class Server::WorkerService final: public Service,
17341734
using AbortActorsCallback = kj::Function<void()>;
17351735

17361736
WorkerService(ThreadContext& threadContext,
1737-
kj::Own<const Worker> worker,
1737+
kj::Arc<Worker> worker,
17381738
kj::Maybe<kj::HashSet<kj::String>> defaultEntrypointHandlers,
17391739
kj::HashMap<kj::String, kj::HashSet<kj::String>> namedEntrypoints,
17401740
kj::HashSet<kj::String> actorClassEntrypoints,
@@ -1921,8 +1921,8 @@ class Server::WorkerService final: public Service,
19211921
observer = kj::refcounted<RequestObserverWithTracer>(
19221922
mapAddRef(workerTracer), kj::mv(streamingTailWorkers), waitUntilTasks);
19231923

1924-
return newWorkerEntrypoint(threadContext, kj::atomicAddRef(*worker), entrypointName,
1925-
kj::mv(props), kj::mv(actor), kj::Own<LimitEnforcer>(this, kj::NullDisposer::instance),
1924+
return newWorkerEntrypoint(threadContext, worker.addRef(), entrypointName, kj::mv(props),
1925+
kj::mv(actor), kj::Own<LimitEnforcer>(this, kj::NullDisposer::instance),
19261926
{}, // ioContextDependency
19271927
kj::Own<IoChannelFactory>(this, kj::NullDisposer::instance), kj::mv(observer),
19281928
waitUntilTasks,
@@ -2081,16 +2081,15 @@ class Server::WorkerService final: public Service,
20812081
co_return;
20822082
}
20832083
KJ_IF_SOME(m, manager) {
2084-
auto& worker = a->getWorker();
2085-
auto workerStrongRef = kj::atomicAddRef(worker);
2084+
auto worker = a->getWorker();
20862085
// Take an async lock, we can't use `takeAsyncLock(RequestObserver&)` since we don't
20872086
// have an `IncomingRequest` at this point.
20882087
//
20892088
// Note that we do not have a race here because this is part of the `shutdownTask`
20902089
// promise. If a new request comes in while we're waiting to get the lock then we will
20912090
// cancel this promise.
2092-
Worker::AsyncLock asyncLock = co_await worker.takeAsyncLockWithoutRequest(nullptr);
2093-
workerStrongRef->runInLockScope(
2091+
Worker::AsyncLock asyncLock = co_await worker->takeAsyncLockWithoutRequest(nullptr);
2092+
worker->runInLockScope(
20942093
asyncLock, [&](Worker::Lock& lock) { m->hibernateWebSockets(lock); });
20952094
}
20962095
a->shutdown(
@@ -2375,7 +2374,7 @@ class Server::WorkerService final: public Service,
23752374
static constexpr uint16_t hibernationEventTypeId = 8;
23762375

23772376
actorContainer->actor.emplace(
2378-
kj::refcounted<Worker::Actor>(*service.worker, actorContainer->getTracker(),
2377+
kj::refcounted<Worker::Actor>(service.worker.addRef(), actorContainer->getTracker(),
23792378
kj::str(idPtr), true, kj::mv(makeActorCache), className, kj::mv(makeStorage),
23802379
lock, kj::mv(loopback), timerChannel, kj::refcounted<ActorObserver>(),
23812380
actorContainer->tryGetManagerRef(), hibernationEventTypeId));
@@ -2466,7 +2465,7 @@ class Server::WorkerService final: public Service,
24662465
// LinkedIoChannels owns the SqliteDatabase::Vfs, so make sure it is destroyed last.
24672466
kj::OneOf<LinkCallback, LinkedIoChannels> ioChannels;
24682467

2469-
kj::Own<const Worker> worker;
2468+
kj::Arc<Worker> worker;
24702469
kj::Maybe<kj::HashSet<kj::String>> defaultEntrypointHandlers;
24712470
kj::HashMap<kj::String, kj::HashSet<kj::String>> namedEntrypoints;
24722471
kj::HashSet<kj::String> actorClassEntrypoints;
@@ -3285,7 +3284,7 @@ kj::Own<Server::Service> Server::makeWorker(kj::StringPtr name,
32853284
}
32863285

32873286
jsg::V8Ref<v8::Object> ctxExportsHandle = nullptr;
3288-
auto worker = kj::atomicRefcounted<Worker>(kj::mv(script), kj::atomicRefcounted<WorkerObserver>(),
3287+
auto worker = kj::arc<Worker>(kj::mv(script), kj::atomicRefcounted<WorkerObserver>(),
32893288
[&](jsg::Lock& lock, const Worker::Api& api, v8::Local<v8::Object> target,
32903289
v8::Local<v8::Object> ctxExports) {
32913290
// We can't fill in ctx.exports yet because we need to run the validator first to discover

src/workerd/tests/test-fixture.c++

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ TestFixture::TestFixture(SetupParams&& params)
343343
IsolateObserver::StartType::COLD,
344344
false,
345345
nullptr)),
346-
worker(kj::atomicRefcounted<Worker>(kj::atomicAddRef(*workerScript),
346+
worker(kj::arc<Worker>(kj::atomicAddRef(*workerScript),
347347
kj::atomicRefcounted<WorkerObserver>(),
348348
[](jsg::Lock&, const Worker::Api&, v8::Local<v8::Object>, v8::Local<v8::Object>) {
349349
// no bindings, nothing to do
@@ -367,7 +367,7 @@ TestFixture::TestFixture(SetupParams&& params)
367367
return jsg::alloc<api::DurableObjectStorage>(
368368
IoContext::current().addObject(actorCache), /*enableSql=*/false);
369369
};
370-
actor = kj::refcounted<Worker::Actor>(*worker, /*tracker=*/kj::none, kj::mv(id),
370+
actor = kj::refcounted<Worker::Actor>(worker.addRef(), /*tracker=*/kj::none, kj::mv(id),
371371
/*hasTransient=*/false, makeActorCache,
372372
/*classname=*/kj::none, makeStorage, lock, kj::refcounted<MockActorLoopback>(),
373373
*timerChannel, kj::refcounted<ActorObserver>(), kj::none, kj::none);
@@ -408,7 +408,7 @@ void TestFixture::runInIoContext(kj::Function<kj::Promise<void>(const Environmen
408408

409409
kj::Own<IoContext::IncomingRequest> TestFixture::createIncomingRequest() {
410410
auto context = kj::refcounted<IoContext>(
411-
threadContext, kj::atomicAddRef(*worker), actor, kj::heap<MockLimitEnforcer>());
411+
threadContext, worker.addRef(), actor, kj::heap<MockLimitEnforcer>());
412412
auto invocationSpanContext = tracing::InvocationSpanContext::newForInvocation(kj::none, kj::none);
413413
auto incomingRequest = kj::heap<IoContext::IncomingRequest>(kj::addRef(*context),
414414
kj::heap<DummyIoChannelFactory>(*timerChannel), kj::refcounted<RequestObserver>(), nullptr,

0 commit comments

Comments (0)