diff --git a/packages/core/router/router-execution-context.ts b/packages/core/router/router-execution-context.ts index d1e98383402..9f783f06e7d 100644 --- a/packages/core/router/router-execution-context.ts +++ b/packages/core/router/router-execution-context.ts @@ -434,7 +434,7 @@ export class RouterExecutionContext { } const isSseHandler = !!this.reflectSse(callback); if (isSseHandler) { - return < + return async < TResult extends Observable = any, TResponse extends HeaderStream = any, TRequest extends IncomingMessage = any, @@ -443,7 +443,7 @@ export class RouterExecutionContext { res: TResponse, req: TRequest, ) => { - this.responseController.sse( + await this.responseController.sse( result, (res as any).raw || res, (req as any).raw || req, diff --git a/packages/core/router/router-response-controller.ts b/packages/core/router/router-response-controller.ts index e66544a1ded..d30ec33aae9 100644 --- a/packages/core/router/router-response-controller.ts +++ b/packages/core/router/router-response-controller.ts @@ -99,12 +99,12 @@ export class RouterResponseController { this.applicationRef.status(response, statusCode); } - public sse< + public async sse< TInput extends Observable = any, TResponse extends WritableHeaderStream = any, TRequest extends IncomingMessage = any, >( - result: TInput, + result: TInput | Promise, response: TResponse, request: TRequest, options?: { additionalHeaders: AdditionalHeaders }, @@ -114,12 +114,22 @@ export class RouterResponseController { return; } - this.assertObservable(result); + const observableResult = await Promise.resolve(result); + + this.assertObservable(observableResult); const stream = new SseStream(request); - stream.pipe(response, options); - const subscription = result + // Extract custom status code from response if it was set + const customStatusCode = (response as any).statusCode; + const pipeOptions = + customStatusCode && customStatusCode !== 200 + ? { ...options, statusCode: customStatusCode } + : options; + + stream.pipe(response, pipeOptions); + + const subscription = observableResult .pipe( map((message): MessageEvent => { if (isObject(message)) { diff --git a/packages/core/router/sse-stream.ts b/packages/core/router/sse-stream.ts index a354c8a890c..26b5a252202 100644 --- a/packages/core/router/sse-stream.ts +++ b/packages/core/router/sse-stream.ts @@ -66,11 +66,13 @@ export class SseStream extends Transform { destination: T, options?: { additionalHeaders?: AdditionalHeaders; + statusCode?: number; end?: boolean; }, ): T { if (destination.writeHead) { - destination.writeHead(200, { + const statusCode = options?.statusCode ?? 200; + destination.writeHead(statusCode, { ...options?.additionalHeaders, // See https://github.com/dunglas/mercure/blob/master/hub/subscribe.go#L124-L130 'Content-Type': 'text/event-stream', diff --git a/packages/core/test/router/router-response-controller.spec.ts b/packages/core/test/router/router-response-controller.spec.ts index 20c4ed8ba7c..997e335aaee 100644 --- a/packages/core/test/router/router-response-controller.spec.ts +++ b/packages/core/test/router/router-response-controller.spec.ts @@ -263,7 +263,7 @@ describe('RouterResponseController', () => { it('should accept only observables', async () => { const result = Promise.resolve('test'); try { - routerResponseController.sse( + await routerResponseController.sse( result as unknown as any, {} as unknown as ServerResponse, {} as unknown as IncomingMessage, @@ -275,6 +275,76 @@ describe('RouterResponseController', () => { } }); + it('should accept Promise', async () => { + class Sink extends Writable { + private readonly chunks: string[] = []; + + _write( + chunk: any, + encoding: string, + callback: (error?: Error | null) => void, + ): void { + this.chunks.push(chunk); + callback(); + } + + get content() { + return this.chunks.join(''); + } + } + + const written = (stream: Writable) => + new Promise((resolve, reject) => + stream.on('error', reject).on('finish', resolve), + ); + + const result = Promise.resolve(of('test')); + const response = new Sink(); + const request = new PassThrough(); + await routerResponseController.sse( + result, + response as unknown as ServerResponse, + request as unknown as IncomingMessage, + ); + request.destroy(); + await written(response); + expect(response.content).to.eql( + ` +id: 1 +data: test + +`, + ); + }); + + it('should use custom status code from response', async () => { + class SinkWithStatusCode extends Writable { + statusCode = 404; + writeHead = sinon.spy(); + flushHeaders = sinon.spy(); + + _write( + chunk: any, + encoding: string, + callback: (error?: Error | null) => void, + ): void { + callback(); + } + } + + const result = of('test'); + const response = new SinkWithStatusCode(); + const request = new PassThrough(); + await routerResponseController.sse( + result, + response as unknown as ServerResponse, + request as unknown as IncomingMessage, + ); + + expect(response.writeHead.firstCall.args[0]).to.equal(404); + request.destroy(); + }); + it('should write string', async () => { class Sink extends Writable { private readonly chunks: string[] = []; diff --git a/packages/core/test/router/sse-stream.spec.ts b/packages/core/test/router/sse-stream.spec.ts index 4a41e4ecc73..52666806efd 100644 --- a/packages/core/test/router/sse-stream.spec.ts +++ b/packages/core/test/router/sse-stream.spec.ts @@ -160,6 +160,34 @@ data: hello }); }); + it('sets custom status code when provided', callback => { + const sse = new SseStream(); + const sink = new Sink( + (status: number, headers: string | OutgoingHttpHeaders) => { + expect(status).to.equal(404); + callback(); + return sink; + }, + ); + + sse.pipe(sink, { + statusCode: 404, + }); + }); + + it('defaults to 200 status code when not provided', callback => { + const sse = new SseStream(); + const sink = new Sink( + (status: number, headers: string | OutgoingHttpHeaders) => { + expect(status).to.equal(200); + callback(); + return sink; + }, + ); + + sse.pipe(sink); + }); + it('allows an eventsource to connect', callback => { let sse: SseStream; const server = createServer((req, res) => {