Skip to content

Commit 8f8fef0

Browse files
pyrtsastephencelis
andauthored
Fix deadlock when accessing @Shared in a Dependency (#153)
* Add failing test for persistence locking order w.r.t. Dependencies * Avoid deadlock when accessing `@Shared` in a `@Dependency` ...by not holding locks while accessing `Dependency` logic with locks of its own. Also simplify `PersistentReferences` implementation by maintaining weak references to `_PersistentReference<Key>` objects directly, without custom reference counting. Fixes #149. * Simplify ContinuationBox implementation * Update Reference.swift * Remove needless withExtendedLifetime and change guard to if/else --------- Co-authored-by: Stephen Celis <[email protected]>
1 parent 2298167 commit 8f8fef0

File tree

4 files changed

+75
-134
lines changed

4 files changed

+75
-134
lines changed

Sources/Sharing/Internal/PersistentReferences.swift

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@ final class PersistentReferences: @unchecked Sendable, DependencyKey {
66
static var liveValue: PersistentReferences { PersistentReferences() }
77
static var testValue: PersistentReferences { PersistentReferences() }
88

9-
struct Pair<Key: SharedReaderKey> {
10-
var cachedValue: Key.Value
11-
var reference: _PersistentReference<Key>?
9+
struct Weak<Key: SharedReaderKey> {
10+
weak var reference: _PersistentReference<Key>?
1211
}
1312

1413
private var storage: [AnyHashable: Any] = [:]
@@ -18,37 +17,33 @@ final class PersistentReferences: @unchecked Sendable, DependencyKey {
1817
forKey key: Key,
1918
default value: @autoclosure () throws -> Key.Value,
2019
skipInitialLoad: Bool
21-
) rethrows -> _ManagedReference<Key> {
22-
try lock.withLock {
23-
guard var pair = storage[key.id] as? Pair<Key> else {
24-
let value = try value()
25-
let persistentReference = _PersistentReference(
26-
key: key,
27-
value: value,
28-
skipInitialLoad: skipInitialLoad
29-
)
30-
storage[key.id] = Pair(cachedValue: value, reference: persistentReference)
31-
return _ManagedReference(persistentReference)
20+
) rethrows -> _PersistentReference<Key> {
21+
if let reference = lock.withLock({ (storage[key.id] as? Weak<Key>)?.reference }) {
22+
return reference
23+
} else {
24+
let value = try value()
25+
let reference = _PersistentReference(
26+
key: key,
27+
value: value,
28+
skipInitialLoad: skipInitialLoad
29+
)
30+
return lock.withLock {
31+
if let reference = (storage[key.id] as? Weak<Key>)?.reference {
32+
return reference
33+
} else {
34+
storage[key.id] = Weak(reference: reference)
35+
reference.onDeinit = { [self] in
36+
removeReference(forKey: key)
37+
}
38+
return reference
39+
}
3240
}
33-
guard let persistentReference = pair.reference else {
34-
let persistentReference = _PersistentReference(
35-
key: key,
36-
value: skipInitialLoad ? (try? value()) ?? pair.cachedValue : pair.cachedValue,
37-
skipInitialLoad: skipInitialLoad
38-
)
39-
pair.reference = persistentReference
40-
storage[key.id] = pair
41-
return _ManagedReference(persistentReference)
42-
}
43-
return _ManagedReference(persistentReference)
4441
}
4542
}
4643

4744
func removeReference<Key: SharedReaderKey>(forKey key: Key) {
4845
lock.withLock {
49-
guard var pair = storage[key.id] as? Pair<Key> else { return }
50-
pair.reference = nil
51-
storage[key.id] = pair
46+
_ = storage.removeValue(forKey: key.id)
5247
}
5348
}
5449
}

Sources/Sharing/Internal/Reference.swift

Lines changed: 5 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ final class _PersistentReference<Key: SharedReaderKey>:
194194
private var _isLoading = false
195195
private var _loadError: (any Error)?
196196
private var _saveError: (any Error)?
197-
private var _referenceCount = 0
198197
private var subscription: SharedSubscription?
198+
internal var onDeinit: (() -> Void)?
199199

200200
init(key: Key, value initialValue: Key.Value, skipInitialLoad: Bool) {
201201
self.key = key
@@ -231,6 +231,10 @@ final class _PersistentReference<Key: SharedReaderKey>:
231231
)
232232
}
233233

234+
deinit {
235+
onDeinit?()
236+
}
237+
234238
var id: ObjectIdentifier { ObjectIdentifier(self) }
235239

236240
var isLoading: Bool {
@@ -309,20 +313,6 @@ final class _PersistentReference<Key: SharedReaderKey>:
309313
withMutation(keyPath: \._saveError) {}
310314
}
311315

312-
func retain() {
313-
lock.withLock { _referenceCount += 1 }
314-
}
315-
316-
func release() {
317-
let shouldRelease = lock.withLock {
318-
_referenceCount -= 1
319-
return _referenceCount <= 0
320-
}
321-
guard shouldRelease else { return }
322-
@Dependency(PersistentReferences.self) var persistentReferences
323-
persistentReferences.removeReference(forKey: key)
324-
}
325-
326316
func access<Member>(
327317
keyPath: KeyPath<_PersistentReference, Member>,
328318
fileID: StaticString = #fileID,
@@ -454,85 +444,6 @@ extension _PersistentReference: MutableReference, Equatable where Key: SharedKey
454444
}
455445
}
456446

457-
final class _ManagedReference<Key: SharedReaderKey>: Reference, Observable {
458-
private let base: _PersistentReference<Key>
459-
460-
init(_ base: _PersistentReference<Key>) {
461-
base.retain()
462-
self.base = base
463-
}
464-
465-
deinit {
466-
base.release()
467-
}
468-
469-
var id: ObjectIdentifier {
470-
base.id
471-
}
472-
473-
var isLoading: Bool {
474-
base.isLoading
475-
}
476-
477-
var loadError: (any Error)? {
478-
base.loadError
479-
}
480-
481-
var wrappedValue: Key.Value {
482-
base.wrappedValue
483-
}
484-
485-
func load() async throws {
486-
try await base.load()
487-
}
488-
489-
func touch() {
490-
base.touch()
491-
}
492-
493-
#if canImport(Combine)
494-
var publisher: any Publisher<Key.Value, Never> {
495-
base.publisher
496-
}
497-
#endif
498-
499-
var description: String {
500-
base.description
501-
}
502-
}
503-
504-
extension _ManagedReference: MutableReference, Equatable where Key: SharedKey {
505-
var saveError: (any Error)? {
506-
base.saveError
507-
}
508-
509-
var snapshot: Key.Value? {
510-
base.snapshot
511-
}
512-
513-
func takeSnapshot(
514-
_ value: Key.Value,
515-
fileID: StaticString,
516-
filePath: StaticString,
517-
line: UInt,
518-
column: UInt
519-
) {
520-
base.takeSnapshot(value, fileID: fileID, filePath: filePath, line: line, column: column)
521-
}
522-
523-
func withLock<R>(_ body: (inout Key.Value) throws -> R) rethrows -> R {
524-
try base.withLock(body)
525-
}
526-
527-
func save() async throws {
528-
try await base.save()
529-
}
530-
531-
static func == (lhs: _ManagedReference, rhs: _ManagedReference) -> Bool {
532-
lhs.base == rhs.base
533-
}
534-
}
535-
536447
final class _AppendKeyPathReference<
537448
Base: Reference, Value, Path: KeyPath<Base.Value, Value> & Sendable
538449
>: Reference, Observable {

Sources/Sharing/SharedContinuations.swift

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ public struct SaveContinuation: Sendable {
210210
private final class ContinuationBox<Value>: Sendable {
211211
private let callback: Mutex<(@Sendable (Result<Value?, any Error>) -> Void)?>
212212
private let description: @Sendable () -> String
213-
private let resumeCount = Mutex(0)
214213

215214
init(
216215
callback: @escaping @Sendable (Result<Value?, any Error>) -> Void,
@@ -221,34 +220,30 @@ private final class ContinuationBox<Value>: Sendable {
221220
}
222221

223222
deinit {
224-
let isComplete = resumeCount.withLock { $0 } > 0
225-
if !isComplete {
223+
if let callback = callback.withLock({ $0 }) {
226224
reportIssue(
227225
"""
228226
\(description()) leaked its continuation without one of its resume methods being \
229227
invoked. This will cause tasks waiting on it to resume immediately.
230228
"""
231229
)
232-
callback.withLock { $0?(.success(nil)) }
230+
callback(.success(nil))
233231
}
234232
}
235233

236234
func resume(with result: Result<Value?, any Error>) {
237-
let resumeCount = resumeCount.withLock {
238-
$0 += 1
239-
return $0
235+
let callback = callback.withLock { callback in
236+
defer { callback = nil }
237+
return callback
240238
}
241-
guard resumeCount == 1 else {
239+
guard let callback else {
242240
reportIssue(
243241
"""
244242
\(description()) tried to resume its continuation more than once.
245243
"""
246244
)
247245
return
248246
}
249-
callback.withLock { callback in
250-
defer { callback = nil }
251-
callback?(result)
252-
}
247+
callback(result)
253248
}
254249
}

Tests/SharingTests/SharedTests.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Dependencies
12
import Foundation
23
import IdentifiedCollections
34
import PerceptionCore
@@ -35,6 +36,45 @@ import Testing
3536
let a = A()
3637
#expect(a.b.c == C())
3738
}
39+
40+
@Test func lockingOrderWithDependencies() async {
41+
struct D: TestDependencyKey {
42+
@Shared(.inMemory("count")) var count = 0
43+
init() {
44+
Thread.sleep(forTimeInterval: 0.2)
45+
$count.withLock { $0 += 1 }
46+
}
47+
static var testValue: D { D() }
48+
}
49+
let a = Task {
50+
do {
51+
try await Task.sleep(nanoseconds: 100_000_000)
52+
@Dependency(D.self) var d
53+
#expect(d.count == 1)
54+
} catch {}
55+
}
56+
let b = Task {
57+
@Shared(.inMemory("count")) var count: Int = {
58+
Thread.sleep(forTimeInterval: 0.2)
59+
return 2
60+
}()
61+
#expect(count == 1)
62+
}
63+
let c = Task {
64+
do {
65+
try await Task.sleep(nanoseconds: 500_000_000)
66+
Issue.record("Deadlock detected")
67+
exit(1)
68+
} catch {}
69+
}
70+
await withTaskGroup(of: Void.self) { taskGroup in
71+
taskGroup.addTask {
72+
_ = await (a.value, b.value)
73+
c.cancel()
74+
}
75+
taskGroup.addTask { await c.value }
76+
}
77+
}
3878
}
3979

4080
@Suite struct BoxReference {

0 commit comments

Comments
 (0)