diff --git a/Sources/Sharing/Internal/PersistentReferences.swift b/Sources/Sharing/Internal/PersistentReferences.swift index 87a1a12..506c219 100644 --- a/Sources/Sharing/Internal/PersistentReferences.swift +++ b/Sources/Sharing/Internal/PersistentReferences.swift @@ -6,9 +6,8 @@ final class PersistentReferences: @unchecked Sendable, DependencyKey { static var liveValue: PersistentReferences { PersistentReferences() } static var testValue: PersistentReferences { PersistentReferences() } - struct Pair { - var cachedValue: Key.Value - var reference: _PersistentReference? + struct Weak { + weak var reference: _PersistentReference? } private var storage: [AnyHashable: Any] = [:] @@ -18,37 +17,33 @@ final class PersistentReferences: @unchecked Sendable, DependencyKey { forKey key: Key, default value: @autoclosure () throws -> Key.Value, skipInitialLoad: Bool - ) rethrows -> _ManagedReference { - try lock.withLock { - guard var pair = storage[key.id] as? Pair else { - let value = try value() - let persistentReference = _PersistentReference( - key: key, - value: value, - skipInitialLoad: skipInitialLoad - ) - storage[key.id] = Pair(cachedValue: value, reference: persistentReference) - return _ManagedReference(persistentReference) + ) rethrows -> _PersistentReference { + if let reference = lock.withLock({ (storage[key.id] as? Weak)?.reference }) { + return reference + } else { + let value = try value() + let reference = _PersistentReference( + key: key, + value: value, + skipInitialLoad: skipInitialLoad + ) + return lock.withLock { + if let reference = (storage[key.id] as? Weak)?.reference { + return reference + } else { + storage[key.id] = Weak(reference: reference) + reference.onDeinit = { [self] in + removeReference(forKey: key) + } + return reference + } } - guard let persistentReference = pair.reference else { - let persistentReference = _PersistentReference( - key: key, - value: skipInitialLoad ? (try? value()) ?? pair.cachedValue : pair.cachedValue, - skipInitialLoad: skipInitialLoad - ) - pair.reference = persistentReference - storage[key.id] = pair - return _ManagedReference(persistentReference) - } - return _ManagedReference(persistentReference) } } func removeReference(forKey key: Key) { lock.withLock { - guard var pair = storage[key.id] as? Pair else { return } - pair.reference = nil - storage[key.id] = pair + _ = storage.removeValue(forKey: key.id) } } } diff --git a/Sources/Sharing/Internal/Reference.swift b/Sources/Sharing/Internal/Reference.swift index b537a52..8610bbd 100644 --- a/Sources/Sharing/Internal/Reference.swift +++ b/Sources/Sharing/Internal/Reference.swift @@ -194,8 +194,8 @@ final class _PersistentReference: private var _isLoading = false private var _loadError: (any Error)? private var _saveError: (any Error)? - private var _referenceCount = 0 private var subscription: SharedSubscription? + internal var onDeinit: (() -> Void)? init(key: Key, value initialValue: Key.Value, skipInitialLoad: Bool) { self.key = key @@ -231,6 +231,10 @@ final class _PersistentReference: ) } + deinit { + onDeinit?() + } + var id: ObjectIdentifier { ObjectIdentifier(self) } var isLoading: Bool { @@ -309,20 +313,6 @@ final class _PersistentReference: withMutation(keyPath: \._saveError) {} } - func retain() { - lock.withLock { _referenceCount += 1 } - } - - func release() { - let shouldRelease = lock.withLock { - _referenceCount -= 1 - return _referenceCount <= 0 - } - guard shouldRelease else { return } - @Dependency(PersistentReferences.self) var persistentReferences - persistentReferences.removeReference(forKey: key) - } - func access( keyPath: KeyPath<_PersistentReference, Member>, fileID: StaticString = #fileID, @@ -454,85 +444,6 @@ extension _PersistentReference: MutableReference, Equatable where Key: SharedKey } } -final class _ManagedReference: Reference, Observable { - private let base: _PersistentReference - - init(_ base: _PersistentReference) { - base.retain() - self.base = base - } - - deinit { - base.release() - } - - var id: ObjectIdentifier { - base.id - } - - var isLoading: Bool { - base.isLoading - } - - var loadError: (any Error)? { - base.loadError - } - - var wrappedValue: Key.Value { - base.wrappedValue - } - - func load() async throws { - try await base.load() - } - - func touch() { - base.touch() - } - - #if canImport(Combine) - var publisher: any Publisher { - base.publisher - } - #endif - - var description: String { - base.description - } -} - -extension _ManagedReference: MutableReference, Equatable where Key: SharedKey { - var saveError: (any Error)? { - base.saveError - } - - var snapshot: Key.Value? { - base.snapshot - } - - func takeSnapshot( - _ value: Key.Value, - fileID: StaticString, - filePath: StaticString, - line: UInt, - column: UInt - ) { - base.takeSnapshot(value, fileID: fileID, filePath: filePath, line: line, column: column) - } - - func withLock(_ body: (inout Key.Value) throws -> R) rethrows -> R { - try base.withLock(body) - } - - func save() async throws { - try await base.save() - } - - static func == (lhs: _ManagedReference, rhs: _ManagedReference) -> Bool { - lhs.base == rhs.base - } -} - final class _AppendKeyPathReference< Base: Reference, Value, Path: KeyPath & Sendable >: Reference, Observable { diff --git a/Sources/Sharing/SharedContinuations.swift b/Sources/Sharing/SharedContinuations.swift index aaac3a0..dbf1bc2 100644 --- a/Sources/Sharing/SharedContinuations.swift +++ b/Sources/Sharing/SharedContinuations.swift @@ -210,7 +210,6 @@ public struct SaveContinuation: Sendable { private final class ContinuationBox: Sendable { private let callback: Mutex<(@Sendable (Result) -> Void)?> private let description: @Sendable () -> String - private let resumeCount = Mutex(0) init( callback: @escaping @Sendable (Result) -> Void, @@ -221,24 +220,23 @@ private final class ContinuationBox: Sendable { } deinit { - let isComplete = resumeCount.withLock { $0 } > 0 - if !isComplete { + if let callback = callback.withLock({ $0 }) { reportIssue( """ \(description()) leaked its continuation without one of its resume methods being \ invoked. This will cause tasks waiting on it to resume immediately. """ ) - callback.withLock { $0?(.success(nil)) } + callback(.success(nil)) } } func resume(with result: Result) { - let resumeCount = resumeCount.withLock { - $0 += 1 - return $0 + let callback = callback.withLock { callback in + defer { callback = nil } + return callback } - guard resumeCount == 1 else { + guard let callback else { reportIssue( """ \(description()) tried to resume its continuation more than once. @@ -246,9 +244,6 @@ private final class ContinuationBox: Sendable { ) return } - callback.withLock { callback in - defer { callback = nil } - callback?(result) - } + callback(result) } } diff --git a/Tests/SharingTests/SharedTests.swift b/Tests/SharingTests/SharedTests.swift index 0262420..a461555 100644 --- a/Tests/SharingTests/SharedTests.swift +++ b/Tests/SharingTests/SharedTests.swift @@ -1,3 +1,4 @@ +import Dependencies import Foundation import IdentifiedCollections import PerceptionCore @@ -35,6 +36,45 @@ import Testing let a = A() #expect(a.b.c == C()) } + + @Test func lockingOrderWithDependencies() async { + struct D: TestDependencyKey { + @Shared(.inMemory("count")) var count = 0 + init() { + Thread.sleep(forTimeInterval: 0.2) + $count.withLock { $0 += 1 } + } + static var testValue: D { D() } + } + let a = Task { + do { + try await Task.sleep(nanoseconds: 100_000_000) + @Dependency(D.self) var d + #expect(d.count == 1) + } catch {} + } + let b = Task { + @Shared(.inMemory("count")) var count: Int = { + Thread.sleep(forTimeInterval: 0.2) + return 2 + }() + #expect(count == 1) + } + let c = Task { + do { + try await Task.sleep(nanoseconds: 500_000_000) + Issue.record("Deadlock detected") + exit(1) + } catch {} + } + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + _ = await (a.value, b.value) + c.cancel() + } + taskGroup.addTask { await c.value } + } + } } @Suite struct BoxReference {