Skip to content

Commit e992e2c

Browse files
committed
feat(ws-manual-ack): add ack decorator
1 parent aed0771 commit e992e2c

File tree

7 files changed

+56
-11
lines changed

7 files changed

+56
-11
lines changed

packages/common/interfaces/websockets/web-socket-adapter.interface.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { Observable } from 'rxjs';
66
export interface WsMessageHandler<T = string> {
77
message: T;
88
callback: (...args: any[]) => Observable<any> | Promise<any>;
9+
isAckHandledManually: boolean;
910
}
1011

1112
/**

packages/platform-socket.io/adapters/io-adapter.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter {
4444
first(),
4545
);
4646

47-
handlers.forEach(({ message, callback }) => {
47+
handlers.forEach(({ message, callback, isAckHandledManually }) => {
4848
const source$ = fromEvent(socket, message).pipe(
4949
mergeMap((payload: any) => {
5050
const { data, ack } = this.mapPayload(payload);
5151
return transform(callback(data, ack)).pipe(
5252
filter((response: any) => !isNil(response)),
53-
map((response: any) => [response, ack]),
53+
map((response: any) => [response, ack, isAckHandledManually]),
5454
);
5555
}),
5656
takeUntil(disconnect$),
5757
);
58-
source$.subscribe(([response, ack]) => {
58+
source$.subscribe(([response, ack, isAckHandledManually]) => {
5959
if (response.event) {
6060
return socket.emit(response.event, response.data);
6161
}
62-
isFunction(ack) && ack(response);
62+
if (!isAckHandledManually && isFunction(ack)) {
63+
ack(response);
64+
}
6365
});
6466
});
6567
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import { WsParamtype } from '../enums/ws-paramtype.enum';
2+
import { createPipesWsParamDecorator } from '../utils/param.utils';
3+
4+
/**
5+
* WebSockets `ack` parameter decorator.
6+
* Extracts the `ack` callback function from the arguments of a ws event.
7+
*
8+
* This decorator signals to the framework that the `ack` callback will be
9+
* handled manually within the method, preventing the framework from
10+
* automatically sending an acknowledgement based on the return value.
11+
*
12+
* @example
13+
* ```typescript
14+
* @SubscribeMessage('events')
15+
* onEvent(
16+
* @MessageBody() data: string,
17+
* @Ack() ack: (response: any) => void
18+
* ) {
19+
* // Manually call the ack callback
20+
* ack({ status: 'ok' });
21+
* }
22+
* ```
23+
*
24+
* @publicApi
25+
*/
26+
export function Ack(): ParameterDecorator {
27+
return createPipesWsParamDecorator(WsParamtype.ACK)();
28+
}

packages/websockets/decorators/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ export * from './gateway-server.decorator';
33
export * from './message-body.decorator';
44
export * from './socket-gateway.decorator';
55
export * from './subscribe-message.decorator';
6+
export * from './ack.decorator';

packages/websockets/factories/ws-params-factory.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { isFunction } from '@nestjs/common/utils/shared.utils';
12
import { WsParamtype } from '../enums/ws-paramtype.enum';
23

34
export class WsParamsFactory {
@@ -14,6 +15,9 @@ export class WsParamsFactory {
1415
return args[0];
1516
case WsParamtype.PAYLOAD:
1617
return data ? args[1]?.[data] : args[1];
18+
case WsParamtype.ACK: {
19+
return args.find(arg => isFunction(arg));
20+
}
1721
default:
1822
return null;
1923
}

packages/websockets/gateway-metadata-explorer.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
import { NestGateway } from './interfaces/nest-gateway.interface';
1111
import { ParamsMetadata } from '@nestjs/core/helpers/interfaces';
1212
import { WsParamtype } from './enums/ws-paramtype.enum';
13+
import { ContextUtils } from '@nestjs/core/helpers/context-utils';
1314

1415
export interface MessageMappingProperties {
1516
message: any;
@@ -19,6 +20,7 @@ export interface MessageMappingProperties {
1920
}
2021

2122
export class GatewayMetadataExplorer {
23+
private readonly contextUtils = new ContextUtils();
2224
constructor(private readonly metadataScanner: MetadataScanner) {}
2325

2426
public explore(instance: NestGateway): MessageMappingProperties[] {
@@ -68,9 +70,12 @@ export class GatewayMetadataExplorer {
6870
if (!paramsMetadata) {
6971
return false;
7072
}
73+
const metadataKeys = Object.keys(paramsMetadata);
74+
return metadataKeys.some(key => {
75+
const type = this.contextUtils.mapParamType(key);
7176

72-
const params = Object.values(paramsMetadata);
73-
return params.some((param: any) => param.type === WsParamtype.ACK);
77+
return (Number(type) as WsParamtype) === WsParamtype.ACK;
78+
});
7479
}
7580

7681
public *scanForServerHooks(instance: NestGateway): IterableIterator<string> {

packages/websockets/web-sockets-controller.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export class WebSocketsController {
7272
) {
7373
const nativeMessageHandlers = this.metadataExplorer.explore(instance);
7474
const messageHandlers = nativeMessageHandlers.map(
75-
({ callback, message, methodName }) => ({
75+
({ callback, isAckHandledManually, message, methodName }) => ({
7676
message,
7777
methodName,
7878
callback: this.contextCreator.create(
@@ -81,6 +81,7 @@ export class WebSocketsController {
8181
moduleKey,
8282
methodName,
8383
),
84+
isAckHandledManually,
8485
}),
8586
);
8687

@@ -174,10 +175,13 @@ export class WebSocketsController {
174175
instance: NestGateway,
175176
) {
176177
const adapter = this.config.getIoAdapter();
177-
const handlers = subscribersMap.map(({ callback, message }) => ({
178-
message,
179-
callback: callback.bind(instance, client),
180-
}));
178+
const handlers = subscribersMap.map(
179+
({ callback, message, isAckHandledManually }) => ({
180+
message,
181+
callback: callback.bind(instance, client),
182+
isAckHandledManually,
183+
}),
184+
);
181185
adapter.bindMessageHandlers(client, handlers, data =>
182186
fromPromise(this.pickResult(data)).pipe(mergeAll()),
183187
);

0 commit comments

Comments
 (0)