Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/core/router/router-execution-context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ export class RouterExecutionContext {
}
const isSseHandler = !!this.reflectSse(callback);
if (isSseHandler) {
return <
return async <
TResult extends Observable<unknown> = any,
TResponse extends HeaderStream = any,
TRequest extends IncomingMessage = any,
Expand All @@ -443,7 +443,7 @@ export class RouterExecutionContext {
res: TResponse,
req: TRequest,
) => {
this.responseController.sse(
await this.responseController.sse(
result,
(res as any).raw || res,
(req as any).raw || req,
Expand Down
20 changes: 15 additions & 5 deletions packages/core/router/router-response-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ export class RouterResponseController {
this.applicationRef.status(response, statusCode);
}

public sse<
public async sse<
TInput extends Observable<unknown> = any,
TResponse extends WritableHeaderStream = any,
TRequest extends IncomingMessage = any,
>(
result: TInput,
result: TInput | Promise<TInput>,
response: TResponse,
request: TRequest,
options?: { additionalHeaders: AdditionalHeaders },
Expand All @@ -114,12 +114,22 @@ export class RouterResponseController {
return;
}

this.assertObservable(result);
const observableResult = await Promise.resolve(result);

this.assertObservable(observableResult);

const stream = new SseStream(request);
stream.pipe(response, options);

const subscription = result
// Extract custom status code from response if it was set
const customStatusCode = (response as any).statusCode;
const pipeOptions =
customStatusCode && customStatusCode !== 200
? { ...options, statusCode: customStatusCode }
: options;

stream.pipe(response, pipeOptions);

const subscription = observableResult
.pipe(
map((message): MessageEvent => {
if (isObject(message)) {
Expand Down
4 changes: 3 additions & 1 deletion packages/core/router/sse-stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ export class SseStream extends Transform {
destination: T,
options?: {
additionalHeaders?: AdditionalHeaders;
statusCode?: number;
end?: boolean;
},
): T {
if (destination.writeHead) {
destination.writeHead(200, {
const statusCode = options?.statusCode ?? 200;
destination.writeHead(statusCode, {
...options?.additionalHeaders,
// See https://github.com/dunglas/mercure/blob/master/hub/subscribe.go#L124-L130
'Content-Type': 'text/event-stream',
Expand Down
72 changes: 71 additions & 1 deletion packages/core/test/router/router-response-controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ describe('RouterResponseController', () => {
it('should accept only observables', async () => {
const result = Promise.resolve('test');
try {
routerResponseController.sse(
await routerResponseController.sse(
result as unknown as any,
{} as unknown as ServerResponse,
{} as unknown as IncomingMessage,
Expand All @@ -275,6 +275,76 @@ describe('RouterResponseController', () => {
}
});

it('should accept Promise<Observable>', async () => {
class Sink extends Writable {
private readonly chunks: string[] = [];

_write(
chunk: any,
encoding: string,
callback: (error?: Error | null) => void,
): void {
this.chunks.push(chunk);
callback();
}

get content() {
return this.chunks.join('');
}
}

const written = (stream: Writable) =>
new Promise((resolve, reject) =>
stream.on('error', reject).on('finish', resolve),
);

const result = Promise.resolve(of('test'));
const response = new Sink();
const request = new PassThrough();
await routerResponseController.sse(
result,
response as unknown as ServerResponse,
request as unknown as IncomingMessage,
);
request.destroy();
await written(response);
expect(response.content).to.eql(
`
id: 1
data: test

`,
);
});

it('should use custom status code from response', async () => {
class SinkWithStatusCode extends Writable {
statusCode = 404;
writeHead = sinon.spy();
flushHeaders = sinon.spy();

_write(
chunk: any,
encoding: string,
callback: (error?: Error | null) => void,
): void {
callback();
}
}

const result = of('test');
const response = new SinkWithStatusCode();
const request = new PassThrough();
await routerResponseController.sse(
result,
response as unknown as ServerResponse,
request as unknown as IncomingMessage,
);

expect(response.writeHead.firstCall.args[0]).to.equal(404);
request.destroy();
});

it('should write string', async () => {
class Sink extends Writable {
private readonly chunks: string[] = [];
Expand Down
28 changes: 28 additions & 0 deletions packages/core/test/router/sse-stream.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ data: hello
});
});

it('sets custom status code when provided', callback => {
const sse = new SseStream();
const sink = new Sink(
(status: number, headers: string | OutgoingHttpHeaders) => {
expect(status).to.equal(404);
callback();
return sink;
},
);

sse.pipe(sink, {
statusCode: 404,
});
});

it('defaults to 200 status code when not provided', callback => {
const sse = new SseStream();
const sink = new Sink(
(status: number, headers: string | OutgoingHttpHeaders) => {
expect(status).to.equal(200);
callback();
return sink;
},
);

sse.pipe(sink);
});

it('allows an eventsource to connect', callback => {
let sse: SseStream;
const server = createServer((req, res) => {
Expand Down