@@ -6,7 +6,7 @@ let newLine = 0x0A
66let headerPreamble = " codervpn "
77
88/// A message that has the `rpc` property for recording participation in a unary RPC.
9- protocol RPCMessage {
9+ protocol RPCMessage : Sendable {
1010 var rpc : Vpn_RPC { get set }
1111 /// Returns true if `rpc` has been explicitly set.
1212 var hasRpc : Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
4949 }
5050}
5151
52- /// An abstract base class for implementations that need to communicate using the VPN protocol.
53- class Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
52+ /// An actor that communicates using the VPN protocol
53+ actor Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
5454 private let logger = Logger ( subsystem: " com.coder.Coder-Desktop " , category: " proto " )
5555 private let writeFD : FileHandle
5656 private let readFD : FileHandle
@@ -93,43 +93,6 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
9393 try _ = await hndsh. handshake ( )
9494 }
9595
96- /// Reads and handles protocol messages.
97- func readLoop( ) async throws {
98- for try await msg in try await receiver. messages ( ) {
99- guard msg. hasRpc else {
100- handleMessage ( msg)
101- continue
102- }
103- guard msg. rpc. msgID == 0 else {
104- let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender)
105- handleRPC ( req)
106- continue
107- }
108- guard msg. rpc. responseTo == 0 else {
109- logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
110- do throws ( RPCError) {
111- try await self . secretary. route ( reply: msg)
112- } catch {
113- logger. error (
114- " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
115- }
116- continue
117- }
118- }
119- }
120-
121- /// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122- func handleMessage( _ msg: RecvMsg) {
123- // just log
124- logger. debug ( " got non-RPC message \( msg. textFormatString ( ) ) " )
125- }
126-
127- /// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128- func handleRPC( _ req: RPCRequest < SendMsg , RecvMsg > ) {
129- // just log
130- logger. debug ( " got RPC message \( req. msg. textFormatString ( ) ) " )
131- }
132-
13396 /// Send a unary RPC message and handle the response
13497 func unaryRPC( _ req: SendMsg ) async throws -> RecvMsg {
13598 return try await withCheckedThrowingContinuation { continuation in
@@ -166,10 +129,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166129 logger. error ( " failed to close read file handle: \( error) " )
167130 }
168131 }
132+
133+ enum IncomingMessage {
134+ case message( RecvMsg )
135+ case RPC( RPCRequest < SendMsg , RecvMsg > )
136+ }
137+ }
138+
139+ extension Speaker : AsyncSequence , AsyncIteratorProtocol {
140+ typealias Element = IncomingMessage
141+
142+ public nonisolated func makeAsyncIterator( ) -> Speaker < SendMsg , RecvMsg > {
143+ self
144+ }
145+
146+ func next( ) async throws -> IncomingMessage ? {
147+ for try await msg in try await receiver. messages ( ) {
148+ guard msg. hasRpc else {
149+ return . message( msg)
150+ }
151+ guard msg. rpc. msgID == 0 else {
152+ return . RPC( RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender) )
153+ }
154+ guard msg. rpc. responseTo == 0 else {
155+ logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
156+ do throws ( RPCError) {
157+ try await self . secretary. route ( reply: msg)
158+ } catch {
159+ logger. error (
160+ " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
161+ }
162+ continue
163+ }
164+ }
165+ return nil
166+ }
169167}
170168
171- /// A class that performs the initial VPN protocol handshake and version negotiation.
172- class Handshaker : @ unchecked Sendable {
169+ /// An actor performs the initial VPN protocol handshake and version negotiation.
170+ actor Handshaker {
173171 private let writeFD: FileHandle
174172 private let dispatch : DispatchIO
175173 private var theirData : Data = . init( )
@@ -193,17 +191,19 @@ class Handshaker: @unchecked Sendable {
193191 func handshake( ) async throws -> ProtoVersion {
194192 // kick off the read async before we try to write, synchronously, so we don't deadlock, both
195193 // waiting to write with nobody reading.
196- async let theirs = try withCheckedThrowingContinuation { cont in
197- continuation = cont
198- // send in a nil read to kick us off
199- handleRead ( false , nil , 0 )
194+ let readTask = Task {
195+ try await withCheckedThrowingContinuation { cont in
196+ self . continuation = cont
197+ // send in a nil read to kick us off
198+ self . handleRead ( false , nil , 0 )
199+ }
200200 }
201201
202202 let vStr = versions. map { $0. description } . joined ( separator: " , " )
203203 let ours = String ( format: " \( headerPreamble) \( role) \( vStr) \n " )
204204 try writeFD. write ( contentsOf: ours. data ( using: . utf8) !)
205205
206- let theirData = try await theirs
206+ let theirData = try await readTask . value
207207 guard let theirsString = String ( bytes: theirData, encoding: . utf8) else {
208208 throw HandshakeError . invalidHeader ( " <unparsable: \( theirData) " )
209209 }
0 commit comments