@@ -7,13 +7,44 @@ const createMockResponse = () => {
7
7
writeHead : jest . fn < http . ServerResponse [ 'writeHead' ] > ( ) ,
8
8
write : jest . fn < http . ServerResponse [ 'write' ] > ( ) . mockReturnValue ( true ) ,
9
9
on : jest . fn < http . ServerResponse [ 'on' ] > ( ) ,
10
+ end : jest . fn < http . ServerResponse [ 'end' ] > ( ) ,
10
11
} ;
11
12
res . writeHead . mockReturnThis ( ) ;
12
13
res . on . mockReturnThis ( ) ;
13
14
14
15
return res as unknown as http . ServerResponse ;
15
16
} ;
16
17
18
+ const createMockRequest = ( { headers = { } , body } : { headers ?: Record < string , string > , body ?: string } = { } ) => {
19
+ const mockReq = {
20
+ headers,
21
+ body : body ? body : undefined ,
22
+ auth : {
23
+ token : 'test-token' ,
24
+ } ,
25
+ on : jest . fn < http . IncomingMessage [ 'on' ] > ( ) . mockImplementation ( ( event , listener ) => {
26
+ const mockListener = listener as unknown as ( ...args : unknown [ ] ) => void ;
27
+ if ( event === 'data' ) {
28
+ mockListener ( Buffer . from ( body || '' ) as unknown as Error ) ;
29
+ }
30
+ if ( event === 'error' ) {
31
+ mockListener ( new Error ( 'test' ) ) ;
32
+ }
33
+ if ( event === 'end' ) {
34
+ mockListener ( ) ;
35
+ }
36
+ if ( event === 'close' ) {
37
+ setTimeout ( listener , 100 ) ;
38
+ }
39
+ return mockReq ;
40
+ } ) ,
41
+ listeners : jest . fn < http . IncomingMessage [ 'listeners' ] > ( ) ,
42
+ removeListener : jest . fn < http . IncomingMessage [ 'removeListener' ] > ( ) ,
43
+ } as unknown as http . IncomingMessage ;
44
+
45
+ return mockReq ;
46
+ } ;
47
+
17
48
describe ( 'SSEServerTransport' , ( ) => {
18
49
describe ( 'start method' , ( ) => {
19
50
it ( 'should correctly append sessionId to a simple relative endpoint' , async ( ) => {
@@ -106,4 +137,124 @@ describe('SSEServerTransport', () => {
106
137
) ;
107
138
} ) ;
108
139
} ) ;
109
- } ) ;
140
+
141
+ describe ( 'handlePostMessage method' , ( ) => {
142
+ it ( 'should return 500 if server has not started' , async ( ) => {
143
+ const mockReq = createMockRequest ( ) ;
144
+ const mockRes = createMockResponse ( ) ;
145
+ const endpoint = '/messages' ;
146
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
147
+
148
+ const error = 'SSE connection not established' ;
149
+ await expect ( transport . handlePostMessage ( mockReq , mockRes ) )
150
+ . rejects . toThrow ( error ) ;
151
+ expect ( mockRes . writeHead ) . toHaveBeenCalledWith ( 500 ) ;
152
+ expect ( mockRes . end ) . toHaveBeenCalledWith ( error ) ;
153
+ } ) ;
154
+
155
+ it ( 'should return 400 if content-type is not application/json' , async ( ) => {
156
+ const mockReq = createMockRequest ( { headers : { 'content-type' : 'text/plain' } } ) ;
157
+ const mockRes = createMockResponse ( ) ;
158
+ const endpoint = '/messages' ;
159
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
160
+ await transport . start ( ) ;
161
+
162
+ transport . onerror = jest . fn ( ) ;
163
+ const error = 'Unsupported content-type: text/plain' ;
164
+ await expect ( transport . handlePostMessage ( mockReq , mockRes ) )
165
+ . resolves . toBe ( undefined ) ;
166
+ expect ( mockRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
167
+ expect ( mockRes . end ) . toHaveBeenCalledWith ( expect . stringContaining ( error ) ) ;
168
+ expect ( transport . onerror ) . toHaveBeenCalledWith ( new Error ( error ) ) ;
169
+ } ) ;
170
+
171
+ it ( 'should return 400 if message has not a valid schema' , async ( ) => {
172
+ const invalidMessage = JSON . stringify ( {
173
+ // missing jsonrpc field
174
+ method : 'call' ,
175
+ params : [ 1 , 2 , 3 ] ,
176
+ id : 1 ,
177
+ } )
178
+ const mockReq = createMockRequest ( {
179
+ headers : { 'content-type' : 'application/json' } ,
180
+ body : invalidMessage ,
181
+ } ) ;
182
+ const mockRes = createMockResponse ( ) ;
183
+ const endpoint = '/messages' ;
184
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
185
+ await transport . start ( ) ;
186
+
187
+ transport . onmessage = jest . fn ( ) ;
188
+ await transport . handlePostMessage ( mockReq , mockRes ) ;
189
+ expect ( mockRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
190
+ expect ( transport . onmessage ) . not . toHaveBeenCalled ( ) ;
191
+ expect ( mockRes . end ) . toHaveBeenCalledWith ( `Invalid message: ${ invalidMessage } ` ) ;
192
+ } ) ;
193
+
194
+ it ( 'should return 202 if message has a valid schema' , async ( ) => {
195
+ const validMessage = JSON . stringify ( {
196
+ jsonrpc : "2.0" ,
197
+ method : 'call' ,
198
+ params : {
199
+ a : 1 ,
200
+ b : 2 ,
201
+ c : 3 ,
202
+ } ,
203
+ id : 1
204
+ } )
205
+ const mockReq = createMockRequest ( {
206
+ headers : { 'content-type' : 'application/json' } ,
207
+ body : validMessage ,
208
+ } ) ;
209
+ const mockRes = createMockResponse ( ) ;
210
+ const endpoint = '/messages' ;
211
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
212
+ await transport . start ( ) ;
213
+
214
+ transport . onmessage = jest . fn ( ) ;
215
+ await transport . handlePostMessage ( mockReq , mockRes ) ;
216
+ expect ( mockRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
217
+ expect ( mockRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
218
+ expect ( transport . onmessage ) . toHaveBeenCalledWith ( {
219
+ jsonrpc : "2.0" ,
220
+ method : 'call' ,
221
+ params : {
222
+ a : 1 ,
223
+ b : 2 ,
224
+ c : 3 ,
225
+ } ,
226
+ id : 1
227
+ } , {
228
+ authInfo : {
229
+ token : 'test-token' ,
230
+ }
231
+ } ) ;
232
+ } ) ;
233
+ } ) ;
234
+
235
+ describe ( 'close method' , ( ) => {
236
+ it ( 'should call onclose' , async ( ) => {
237
+ const mockRes = createMockResponse ( ) ;
238
+ const endpoint = '/messages' ;
239
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
240
+ await transport . start ( ) ;
241
+ transport . onclose = jest . fn ( ) ;
242
+ await transport . close ( ) ;
243
+ expect ( transport . onclose ) . toHaveBeenCalled ( ) ;
244
+ } ) ;
245
+ } ) ;
246
+
247
+ describe ( 'send method' , ( ) => {
248
+ it ( 'should call onsend' , async ( ) => {
249
+ const mockRes = createMockResponse ( ) ;
250
+ const endpoint = '/messages' ;
251
+ const transport = new SSEServerTransport ( endpoint , mockRes ) ;
252
+ await transport . start ( ) ;
253
+ expect ( mockRes . write ) . toHaveBeenCalledTimes ( 1 ) ;
254
+ expect ( mockRes . write ) . toHaveBeenCalledWith (
255
+ expect . stringContaining ( 'event: endpoint' ) ) ;
256
+ expect ( mockRes . write ) . toHaveBeenCalledWith (
257
+ expect . stringContaining ( `data: /messages?sessionId=${ transport . sessionId } ` ) ) ;
258
+ } ) ;
259
+ } ) ;
260
+ } ) ;
0 commit comments