1
1
use futures:: { SinkExt , StreamExt } ;
2
+ use thiserror:: Error ;
2
3
3
4
use super :: * ;
4
5
use crate :: model:: {
5
- CancelledNotification , CancelledNotificationParam , ClientInfo , ClientNotification ,
6
- ClientRequest , ClientResult , CreateMessageRequest , CreateMessageRequestParam ,
7
- CreateMessageResult , ListRootsRequest , ListRootsResult , LoggingMessageNotification ,
8
- LoggingMessageNotificationParam , ProgressNotification , ProgressNotificationParam ,
9
- PromptListChangedNotification , ResourceListChangedNotification , ResourceUpdatedNotification ,
10
- ResourceUpdatedNotificationParam , ServerInfo , ServerMessage , ServerNotification , ServerRequest ,
11
- ServerResult , ToolListChangedNotification ,
6
+ CancelledNotification , CancelledNotificationParam , ClientInfo , ClientJsonRpcMessage ,
7
+ ClientMessage , ClientNotification , ClientRequest , ClientResult , CreateMessageRequest ,
8
+ CreateMessageRequestParam , CreateMessageResult , ListRootsRequest , ListRootsResult ,
9
+ LoggingMessageNotification , LoggingMessageNotificationParam , ProgressNotification ,
10
+ ProgressNotificationParam , PromptListChangedNotification , ResourceListChangedNotification ,
11
+ ResourceUpdatedNotification , ResourceUpdatedNotificationParam , ServerInfo , ServerMessage ,
12
+ ServerNotification , ServerRequest , ServerResult , ToolListChangedNotification ,
12
13
} ;
13
14
14
15
#[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
@@ -26,6 +27,24 @@ impl ServiceRole for RoleServer {
26
27
const IS_CLIENT : bool = false ;
27
28
}
28
29
30
+ /// It represents the error that may occur when serving the server.
31
+ ///
32
+ /// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
33
+ #[ derive( Error , Debug ) ]
34
+ pub enum ServerError {
35
+ #[ error( "expect initialized request, but received: {0:?}" ) ]
36
+ ExpectedInitRequest ( Option < ClientMessage > ) ,
37
+
38
+ #[ error( "expect initialized notification, but received: {0:?}" ) ]
39
+ ExpectedInitNotification ( Option < ClientMessage > ) ,
40
+
41
+ #[ error( "connection closed: {0}" ) ]
42
+ ConnectionClosed ( String ) ,
43
+
44
+ #[ error( "IO error: {0}" ) ]
45
+ Io ( #[ from] std:: io:: Error ) ,
46
+ }
47
+
29
48
pub type ClientSink = Peer < RoleServer > ;
30
49
31
50
impl < S : Service < RoleServer > > ServiceExt < RoleServer > for S {
55
74
serve_server_with_ct ( service, transport, CancellationToken :: new ( ) ) . await
56
75
}
57
76
77
+ /// Helper function to get the next message from the stream
78
+ async fn expect_next_message < S > ( stream : & mut S , context : & str ) -> Result < ClientMessage , ServerError >
79
+ where
80
+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
81
+ {
82
+ Ok ( stream
83
+ . next ( )
84
+ . await
85
+ . ok_or_else ( || ServerError :: ConnectionClosed ( context. to_string ( ) ) ) ?
86
+ . into_message ( ) )
87
+ }
88
+
89
+ /// Helper function to expect a request from the stream
90
+ async fn expect_request < S > (
91
+ stream : & mut S ,
92
+ context : & str ,
93
+ ) -> Result < ( ClientRequest , RequestId ) , ServerError >
94
+ where
95
+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
96
+ {
97
+ let msg = expect_next_message ( stream, context) . await ?;
98
+ let msg_clone = msg. clone ( ) ;
99
+ msg. into_request ( )
100
+ . ok_or ( ServerError :: ExpectedInitRequest ( Some ( msg_clone) ) )
101
+ }
102
+
103
+ /// Helper function to expect a notification from the stream
104
+ async fn expect_notification < S > (
105
+ stream : & mut S ,
106
+ context : & str ,
107
+ ) -> Result < ClientNotification , ServerError >
108
+ where
109
+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
110
+ {
111
+ let msg = expect_next_message ( stream, context) . await ?;
112
+ let msg_clone = msg. clone ( ) ;
113
+ msg. into_notification ( )
114
+ . ok_or ( ServerError :: ExpectedInitNotification ( Some ( msg_clone) ) )
115
+ }
116
+
58
117
pub async fn serve_server_with_ct < S , T , E , A > (
59
118
service : S ,
60
119
transport : T ,
@@ -70,54 +129,45 @@ where
70
129
let mut stream = Box :: pin ( stream) ;
71
130
let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
72
131
73
- // service
74
- let ( request, id) = stream
75
- . next ( )
132
+ // Convert ServerError to std::io::Error, then to E
133
+ let handle_server_error = |e : ServerError | -> E {
134
+ match e {
135
+ ServerError :: Io ( io_err) => io_err. into ( ) ,
136
+ other => std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{}" , other) ) . into ( ) ,
137
+ }
138
+ } ;
139
+
140
+ // Get initialize request
141
+ let ( request, id) = expect_request ( & mut stream, "initialized request" )
76
142
. await
77
- . ok_or ( std:: io:: Error :: new (
78
- std:: io:: ErrorKind :: UnexpectedEof ,
79
- "expect initialize request" ,
80
- ) ) ?
81
- . into_message ( )
82
- . into_request ( )
83
- . ok_or ( std:: io:: Error :: new (
84
- std:: io:: ErrorKind :: InvalidData ,
85
- "expect initialize request" ,
86
- ) ) ?;
143
+ . map_err ( handle_server_error) ?;
144
+
87
145
let ClientRequest :: InitializeRequest ( peer_info) = request else {
88
- return Err ( std:: io:: Error :: new (
89
- std:: io:: ErrorKind :: InvalidData ,
90
- "expect initialize request" ,
91
- )
92
- . into ( ) ) ;
146
+ return Err ( handle_server_error ( ServerError :: ExpectedInitRequest ( Some (
147
+ ClientMessage :: Request ( request, id) ,
148
+ ) ) ) ) ;
93
149
} ;
150
+
151
+ // Send initialize response
94
152
let init_response = service. get_info ( ) ;
95
153
sink. send (
96
154
ServerMessage :: Response ( ServerResult :: InitializeResult ( init_response) , id)
97
155
. into_json_rpc_message ( ) ,
98
156
)
99
157
. await ?;
100
- // waiting for notification
101
- let notification = stream
102
- . next ( )
158
+
159
+ // Wait for initialize notification
160
+ let notification = expect_notification ( & mut stream , "initialize notification" )
103
161
. await
104
- . ok_or ( std:: io:: Error :: new (
105
- std:: io:: ErrorKind :: UnexpectedEof ,
106
- "expect initialize notification" ,
107
- ) ) ?
108
- . into_message ( )
109
- . into_notification ( )
110
- . ok_or ( std:: io:: Error :: new (
111
- std:: io:: ErrorKind :: InvalidData ,
112
- "expect initialize notification" ,
113
- ) ) ?;
162
+ . map_err ( handle_server_error) ?;
163
+
114
164
let ClientNotification :: InitializedNotification ( _) = notification else {
115
- return Err ( std:: io:: Error :: new (
116
- std:: io:: ErrorKind :: InvalidData ,
117
- "expect initialize notification" ,
118
- )
119
- . into ( ) ) ;
165
+ return Err ( handle_server_error ( ServerError :: ExpectedInitNotification (
166
+ Some ( ClientMessage :: Notification ( notification) ) ,
167
+ ) ) ) ;
120
168
} ;
169
+
170
+ // Continue processing service
121
171
serve_inner ( service, ( sink, stream) , peer_info. params , id_provider, ct) . await
122
172
}
123
173
0 commit comments