Skip to content

Commit fbd6b65

Browse files
committed
feat: add async semaphore
1 parent 2160c12 commit fbd6b65

File tree

8 files changed

+332
-6
lines changed

8 files changed

+332
-6
lines changed

.gitignore

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ playground.xcworkspace
3737
# Swift Package Manager
3838
#
3939
# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
40-
# Packages/
41-
# Package.pins
42-
# Package.resolved
43-
# *.xcodeproj
44-
#
40+
Packages/
41+
Package.pins
42+
Package.resolved
43+
*.xcodeproj
44+
4545
# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
4646
# hence it is not needed unless you have added a package configuration file to your project
47-
# .swiftpm
47+
.swiftpm
4848

4949
.build/
5050

Package.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// swift-tools-version: 5.5
2+
3+
import PackageDescription
4+
5+
let package = Package(
6+
name: "AsyncObject",
7+
platforms: [
8+
.macOS(.v10_15),
9+
.iOS(.v13),
10+
.tvOS(.v13),
11+
.watchOS(.v6),
12+
],
13+
products: [
14+
.library(
15+
name: "AsyncObject",
16+
targets: ["AsyncObject"]
17+
),
18+
],
19+
dependencies: [
20+
.package(
21+
url: "https://github.com/apple/swift-collections.git",
22+
.upToNextMajor(from: "1.0.0")
23+
),
24+
],
25+
targets: [
26+
.target(
27+
name: "AsyncObject",
28+
dependencies: [
29+
.product(name: "OrderedCollections", package: "swift-collections"),
30+
]
31+
),
32+
.testTarget(
33+
name: "AsyncObjectTests",
34+
dependencies: ["AsyncObject"]
35+
),
36+
]
37+
)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import Foundation
2+
3+
public actor AsyncMutex {
4+
private typealias Continuation = UnsafeContinuation<Void, Error>
5+
private var continuations: [UUID: Continuation] = [:]
6+
private var locked: Bool
7+
8+
private func addContinuation(
9+
_ continuation: Continuation,
10+
withKey key: UUID
11+
) {
12+
continuations[key] = continuation
13+
}
14+
15+
private func removeContinuation(withKey key: UUID) {
16+
continuations.removeValue(forKey: key)
17+
}
18+
19+
public init(lockedInitially locked: Bool = true) {
20+
self.locked = locked
21+
}
22+
23+
public func lock() {
24+
locked = true
25+
}
26+
27+
public func release() {
28+
continuations.forEach { $0.value.resume() }
29+
continuations = [:]
30+
locked = false
31+
}
32+
33+
public func wait() async {
34+
guard locked else { return }
35+
let key = UUID()
36+
do {
37+
try await withUnsafeThrowingContinuationCancellationHandler(
38+
handler: { (continuation: Continuation) in
39+
Task { await removeContinuation(withKey: key) }
40+
},
41+
{ addContinuation($0, withKey: key) }
42+
)
43+
} catch {
44+
debugPrint(
45+
"Wait on mutex for continuation task with key: \(key)"
46+
+ " cancelled with error \(error)"
47+
)
48+
}
49+
}
50+
51+
@discardableResult
52+
public func wait(
53+
forNanoseconds duration: UInt64
54+
) async -> TaskTimeoutResult {
55+
guard locked else { return .success }
56+
await withTaskGroup(of: Void.self) { group in
57+
group.addTask { [weak self] in await self?.wait() }
58+
group.addTask {
59+
do {
60+
try await Task.sleep(nanoseconds: duration)
61+
} catch {}
62+
}
63+
64+
for await _ in group.prefix(1) {
65+
group.cancelAll()
66+
}
67+
}
68+
return locked ? .timedOut : .success
69+
}
70+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import Foundation
2+
import OrderedCollections
3+
4+
actor AsyncSemaphore {
5+
private typealias Continuation = UnsafeContinuation<Void, Error>
6+
private var continuations: OrderedDictionary<UUID, Continuation> = [:]
7+
private var limit: UInt
8+
private var count: Int
9+
10+
private func addContinuation(
11+
_ continuation: Continuation,
12+
withKey key: UUID
13+
) {
14+
continuations[key] = continuation
15+
}
16+
17+
private func removeContinuation(withKey key: UUID) {
18+
continuations.removeValue(forKey: key)
19+
}
20+
21+
public init(value count: UInt = 0) {
22+
self.limit = count + 1
23+
self.count = Int(limit)
24+
}
25+
26+
public func signal() {
27+
guard count < limit else { return }
28+
count += 1
29+
guard !continuations.isEmpty else { return }
30+
let (_, continuation) = continuations.removeFirst()
31+
continuation.resume()
32+
}
33+
34+
public func wait() async {
35+
count -= 1
36+
if count > 0 { return }
37+
let key = UUID()
38+
do {
39+
try await withUnsafeThrowingContinuationCancellationHandler(
40+
handler: { (continuation: Continuation) in
41+
Task { await removeContinuation(withKey: key) }
42+
},
43+
{ addContinuation($0, withKey: key) }
44+
)
45+
} catch {
46+
debugPrint(
47+
"Wait on semaphore for continuation task with key: \(key)"
48+
+ " cancelled with error \(error)"
49+
)
50+
}
51+
}
52+
53+
@discardableResult
54+
public func wait(
55+
forNanoseconds duration: UInt64
56+
) async -> TaskTimeoutResult {
57+
var timedOut = true
58+
await withTaskGroup(of: Bool.self) { group in
59+
group.addTask {
60+
[weak self] in await self?.wait()
61+
return true
62+
}
63+
group.addTask {
64+
do {
65+
try await Task.sleep(nanoseconds: duration)
66+
return false
67+
} catch {
68+
return true
69+
}
70+
}
71+
72+
for await result in group.prefix(1) {
73+
timedOut = !result
74+
group.cancelAll()
75+
}
76+
}
77+
return timedOut ? .timedOut : .success
78+
}
79+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
public enum TaskTimeoutResult {
2+
case success
3+
case timedOut
4+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
func withUnsafeThrowingContinuationCancellationHandler<T: Sendable>(
2+
handler: @Sendable (UnsafeContinuation<T, Error>) -> Void,
3+
_ fn: (UnsafeContinuation<T, Error>) -> Void
4+
) async throws -> T {
5+
typealias Continuation = UnsafeContinuation<T, Error>
6+
let wrapper = Continuation.Wrapper()
7+
let value = try await withTaskCancellationHandler {
8+
guard let continuation = wrapper.value else { return }
9+
wrapper.cancel(withError: CancellationError())
10+
handler(continuation)
11+
} operation: { () -> T in
12+
let value = try await withUnsafeThrowingContinuation { (c: Continuation) in
13+
wrapper.value = c
14+
fn(c)
15+
}
16+
return value
17+
}
18+
return value
19+
}
20+
21+
extension UnsafeContinuation {
22+
class Wrapper {
23+
var value: UnsafeContinuation?
24+
25+
init(value: UnsafeContinuation? = nil) {
26+
self.value = value
27+
}
28+
29+
func cancel(withError error: E) {
30+
value?.resume(throwing: error)
31+
}
32+
}
33+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import XCTest
2+
@testable import AsyncObject
3+
4+
class AsyncMutexTests: XCTestCase {
5+
6+
func checkWait(
7+
for mutex: AsyncMutex,
8+
durationInSeconds seconds: Int = 0
9+
) async throws {
10+
Task.detached {
11+
try await Task.sleep(nanoseconds: UInt64(5E9))
12+
await mutex.release()
13+
}
14+
await checkExecInterval(
15+
for: { await mutex.wait() },
16+
durationInSeconds: seconds
17+
)
18+
}
19+
20+
func testMutexWait() async throws {
21+
let mutex = AsyncMutex()
22+
try await checkWait(for: mutex, durationInSeconds: 5)
23+
}
24+
25+
func testMutexLockAndWait() async throws {
26+
let mutex = AsyncMutex(lockedInitially: false)
27+
await mutex.lock()
28+
try await checkWait(for: mutex, durationInSeconds: 5)
29+
}
30+
31+
func testReleasedMutexWait() async throws {
32+
let mutex = AsyncMutex(lockedInitially: false)
33+
try await checkWait(for: mutex)
34+
}
35+
36+
func testMutexWaitWithTimeout() async throws {
37+
let mutex = AsyncMutex()
38+
var result: TaskTimeoutResult = .success
39+
await checkExecInterval(
40+
for: {
41+
result = await mutex.wait(forNanoseconds: UInt64(4E9))
42+
},
43+
durationInSeconds: 4
44+
)
45+
XCTAssertEqual(result, .timedOut)
46+
}
47+
48+
func testMutexWaitSuccessWithoutTimeout() async throws {
49+
let mutex = AsyncMutex()
50+
var result: TaskTimeoutResult = .timedOut
51+
Task.detached {
52+
try await Task.sleep(nanoseconds: UInt64(5E9))
53+
await mutex.release()
54+
}
55+
await checkExecInterval(
56+
for: {
57+
result = await mutex.wait(forNanoseconds: UInt64(10E9))
58+
},
59+
durationInSeconds: 5
60+
)
61+
XCTAssertEqual(result, .success)
62+
}
63+
64+
func testReleasedMutexWaitSuccessWithoutTimeout() async throws {
65+
let mutex = AsyncMutex(lockedInitially: false)
66+
var result: TaskTimeoutResult = .timedOut
67+
await checkExecInterval(
68+
for: {
69+
result = await mutex.wait(forNanoseconds: UInt64(10E9))
70+
},
71+
durationInSeconds: 0
72+
)
73+
XCTAssertEqual(result, .success)
74+
}
75+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import XCTest
2+
3+
extension XCTestCase {
4+
func checkExecInterval(
5+
for task: () async throws -> Void,
6+
durationInSeconds seconds: Int = 0
7+
) async rethrows {
8+
let time = DispatchTime.now()
9+
try await task()
10+
let execTime = time.distance(to: DispatchTime.now())
11+
switch execTime {
12+
case .seconds(let value):
13+
XCTAssertEqual(seconds, value)
14+
case .microseconds(let value):
15+
XCTAssertEqual(seconds, value/Int(1E6))
16+
case .milliseconds(let value):
17+
XCTAssertEqual(seconds, value/Int(1E3))
18+
case .nanoseconds(let value):
19+
XCTAssertEqual(seconds, value/Int(1E9))
20+
case .never: fallthrough
21+
@unknown default:
22+
NSException(
23+
name: NSExceptionName(rawValue: "UnExpectedInterval"),
24+
reason: "UnExpected time interval"
25+
).raise()
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)