diff --git a/src/PolykeyAgent.ts b/src/PolykeyAgent.ts index df55cc239c..226216f18c 100644 --- a/src/PolykeyAgent.ts +++ b/src/PolykeyAgent.ts @@ -386,6 +386,14 @@ class PolykeyAgent { optionsDefaulted.nodes.connectionInitialMaxStreamsBidi, rpcParserBufferSize: optionsDefaulted.nodes.rpcParserBufferSize, rpcCallTimeoutTime: optionsDefaulted.nodes.rpcCallTimeoutTime, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + optionsDefaulted.network, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + optionsDefaulted.network, + ), logger: logger.getChild(NodeConnectionManager.name), }); nodeManager = new NodeManager({ diff --git a/src/client/ClientService.ts b/src/client/ClientService.ts index 9d48514e73..a75b432514 100644 --- a/src/client/ClientService.ts +++ b/src/client/ClientService.ts @@ -34,8 +34,9 @@ class ClientService { const conn = evt.detail; const streamHandler = (evt: wsEvents.EventWebSocketConnectionStream) => { const stream = evt.detail; + // If the RPCServer is stopping or stopped then we want to reject new streams outright if (!this.rpcServer[running] || this.rpcServer[status] === 'stopping') { - stream.cancel(Error('TMP RPCServer not running')); + stream.cancel(new errors.ErrorClientServiceNotRunning()); return; } this.rpcServer.handleStream(stream); diff --git a/src/nodes/NodeConnectionManager.ts b/src/nodes/NodeConnectionManager.ts index f55e2978e9..27bd2e6f62 100644 --- a/src/nodes/NodeConnectionManager.ts +++ b/src/nodes/NodeConnectionManager.ts @@ -1,6 +1,15 @@ import type { ResourceAcquire } from '@matrixai/resources'; import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts'; import type { QUICConnection } from '@matrixai/quic'; +import type { JSONRPCRequest, JSONRPCResponse } from '@matrixai/rpc'; +import type { + AuthenticateNetworkForwardCallback, + AuthenticateNetworkReverseCallback, + NodeId, + NodeIdString, +} from './types'; +import type { NodesAuthenticateConnectionMessage } from './agent/types'; +import type { AgentServerManifest } from './agent/handlers'; import type KeyRing from '../keys/KeyRing'; import type { CertificatePEM } from '../keys/types'; import type { @@ -10,8 +19,8 @@ import type { Port, TLSConfig, } from '../network/types'; -import type { AgentServerManifest } from './agent/handlers'; -import type { NodeId, NodeIdString } from './types'; +import type { JSONValue } from '../types'; +import { TransformStream } from 'stream/web'; import { events as quicEvents, QUICServer, @@ -42,6 +51,7 @@ import agentClientManifest from './agent/callers'; import * as nodesUtils from './utils'; import * as nodesErrors from './errors'; import * as nodesEvents from './events'; +import * as agentUtils from './agent/utils'; import * as keysUtils from '../keys/utils'; import * as networkUtils from '../network/utils'; import * as utils from '../utils'; @@ -54,9 +64,24 @@ type ConnectionAndTimer = { usageCount: number; }; +enum AuthenticatingState { + PENDING = 1, + SUCCESS = 2, + FAIL = 3, +} + type ConnectionsEntry = { activeConnection: string; connections: Record; + // This tracks the authentication state machine + authenticatedForward: AuthenticatingState; + reasonForward?: Error; + authenticatedReverse: AuthenticatingState; + reasonReverse?: Error; + authenticateComplete: boolean; + authenticatedP: Promise; + authenticatedResolveP: (value: void) => void; + authenticatedRejectP: (reason?: Error) => void; }; type ConnectionInfo = { @@ -76,12 +101,24 @@ const abortPendingConnectionsReason = Symbol( 'abort pending connections reason', ); +const timerCancellationReason = Symbol('timer cancellation reason'); + +const activePunchCancellationReason = Symbol( + 'active punch cancellation reason', +); + +const activeForwardAuthenticateCancellationReason = Symbol( + 'active forward authenticate cancellation reason', +); + +const rpcMethodsWhitelist = ['nodesAuthenticateConnection']; + /** * NodeConnectionManager is a server that manages all node connections. * It manages both initiated and received connections. * * It acts like a phone call system. - * It can maintain mulitple calls to other nodes. + * It can maintain multiple calls to other nodes. * There's no guarantee that we need to make it. * * Node connections make use of the QUIC protocol. @@ -154,7 +191,7 @@ class NodeConnectionManager { public readonly connectionInitialMaxStreamsUni: number; /** - * Max parse buffer size before RPC parser throws an parse error. + * Max parse buffer size before RPC parser throws a parse error. */ public readonly rpcParserBufferSize: number; @@ -189,6 +226,23 @@ class NodeConnectionManager { */ protected rateLimiter = new RateLimiter(60000, 20, 10, 1); + /** + * Used to track the active authentication RPC calls + */ + protected activeForwardAuthenticateCalls = new Map< + string, + PromiseCancellable + >(); + + /** + * Callback used to generate authentication data when making the authentication call + */ + protected authenticateNetworkForwardCallback: AuthenticateNetworkForwardCallback; + /** + * Callback used to authenticate the peer when processing an authentication request from the peer + */ + protected authenticateNetworkReverseCallback: AuthenticateNetworkReverseCallback; + protected logger: Logger; protected keyRing: KeyRing; protected tlsConfig: TLSConfig; @@ -197,8 +251,8 @@ class NodeConnectionManager { protected quicServer: QUICServer; /** - * Data structure to store all NodeConnections. If a connection to a node n does - * not exist, no entry for n will exist in the map. Alternatively, if a + * Data structure to store all NodeConnections. If a connection to a node `N` does + * not exist, no entry for `N` will exist in the map. Alternatively, if a * connection is currently being instantiated by some thread, an entry will * exist in the map, but only with the lock (no connection object). Once a * connection is instantiated, the entry in the map is updated to include the @@ -254,7 +308,7 @@ class NodeConnectionManager { const connectionAndTimer = connectionsEntry.connections[connectionId]; if (connectionAndTimer == null) utils.never('should have a connection'); connectionAndTimer.usageCount += 1; - connectionAndTimer.timer?.cancel(); + connectionAndTimer.timer?.cancel(timerCancellationReason); connectionAndTimer.timer = null; void stream.closedP.finally(() => { connectionAndTimer.usageCount -= 1; @@ -272,6 +326,8 @@ class NodeConnectionManager { await this.destroyConnection(nodeId, false, connectionId), delay, }); + // Prevent unhandled exceptions when cancelling + connectionAndTimer.timer.catch(() => {}); } }); }; @@ -292,7 +348,7 @@ class NodeConnectionManager { }; /** - * Redispatches `QUICSOcket` or `QUICServer` error events as `NodeConnectionManager` error events. + * Redispatches `QUICSocket` or `QUICServer` error events as `NodeConnectionManager` error events. * This should trigger the destruction of the `NodeConnection` through the * `EventNodeConnectionError` -> `EventNodeConnectionClose` event path. */ @@ -308,7 +364,7 @@ class NodeConnectionManager { /** * Handle unexpected stoppage of the QUICSocket. Not expected to happen - * without error but we have it just in case. + * without error, but we have it just in case. */ protected handleEventQUICSocketStopped = ( _evt: quicEvents.EventQUICSocketStopped, @@ -323,7 +379,7 @@ class NodeConnectionManager { /** * Handle unexpected stoppage of the QUICServer. Not expected to happen - * without error but we have it just in case. + * without error, but we have it just in case. */ protected handleEventQUICServerStopped = ( _evt: quicEvents.EventQUICServerStopped, @@ -338,8 +394,8 @@ class NodeConnectionManager { /** * Handles `EventQUICServerConnection` events. These are reverser or server - * peer initated connections that needs to be handled and added to the - * connectio map. + * peer initiated connections that needs to be handled and added to the + * connection map. */ protected handleEventQUICServerConnection = ( evt: quicEvents.EventQUICServerConnection, @@ -386,6 +442,8 @@ class NodeConnectionManager { .nodesConnectionInitialMaxStreamsUni, rpcParserBufferSize = config.defaultsSystem.rpcParserBufferSize, rpcCallTimeoutTime = config.defaultsSystem.rpcCallTimeoutTime, + authenticateNetworkForwardCallback = nodesUtils.nodesAuthenticateConnectionForwardDefault, + authenticateNetworkReverseCallback = nodesUtils.nodesAuthenticateConnectionReverseDefault, logger, }: { keyRing: KeyRing; @@ -403,6 +461,8 @@ class NodeConnectionManager { connectionInitialMaxStreamsUni?: number; rpcParserBufferSize?: number; rpcCallTimeoutTime?: number; + authenticateNetworkForwardCallback?: AuthenticateNetworkForwardCallback; + authenticateNetworkReverseCallback?: AuthenticateNetworkReverseCallback; logger?: Logger; }) { this.logger = logger ?? new Logger(this.constructor.name); @@ -421,6 +481,10 @@ class NodeConnectionManager { this.connectionInitialMaxStreamsUni = connectionInitialMaxStreamsUni; this.rpcParserBufferSize = rpcParserBufferSize; this.rpcCallTimeoutTime = rpcCallTimeoutTime; + this.authenticateNetworkForwardCallback = + authenticateNetworkForwardCallback; + this.authenticateNetworkReverseCallback = + authenticateNetworkReverseCallback; const quicSocket = new QUICSocket({ resolveHostname: () => { @@ -450,7 +514,7 @@ class NodeConnectionManager { }); const rpcServer = new RPCServer({ middlewareFactory: rpcMiddleware.defaultServerMiddlewareWrapper( - undefined, + this.authenticationMiddlewareClient, this.rpcParserBufferSize, ), fromError: networkUtils.fromError, @@ -590,25 +654,40 @@ class NodeConnectionManager { ); this.quicSocket.removeEventListener(EventAll.name, this.handleEventAll); - const destroyPs: Array> = []; - for (const [nodeId] of this.connections) { - // It exists so we want to destroy it - const destroyP = this.destroyConnection( - IdInternal.fromString(nodeId), - force, + const destroyConnectionPs: Array> = []; + const cancelSignallingPs: Array | Promise> = + []; + const authenticationCancelPs: Array> = []; + const cancelAuthenticationPs: Array> = []; + const cancelReason = new nodesErrors.ErrorNodeConnectionManagerStopping(); + for (const [nodeIdString] of this.connections) { + const destroyP = this.authenticateCancel(nodeIdString, cancelReason).then( + async () => { + return await this.destroyConnection( + IdInternal.fromString(nodeIdString), + force, + ); + }, ); - destroyPs.push(destroyP); + destroyConnectionPs.push(destroyP); } - await Promise.all(destroyPs); - const signallingPs: Array | Promise> = []; for (const [, activePunch] of this.activeHolePunchPs) { - signallingPs.push(activePunch); - activePunch.cancel(); + cancelSignallingPs.push(activePunch); + activePunch.cancel(activePunchCancellationReason); } for (const activeSignal of this.activeSignalFinalPs) { - signallingPs.push(activeSignal); + cancelSignallingPs.push(activeSignal); + } + for (const activeForwardAuthenticateCall of this.activeForwardAuthenticateCalls.values()) { + cancelAuthenticationPs.push(activeForwardAuthenticateCall); + activeForwardAuthenticateCall.cancel( + activeForwardAuthenticateCancellationReason, + ); } - await Promise.allSettled(signallingPs); + await Promise.all(destroyConnectionPs); + await Promise.allSettled(cancelSignallingPs); + await Promise.allSettled(authenticationCancelPs); + await Promise.allSettled(cancelAuthenticationPs); await this.quicServer.stop({ force: true }); await this.quicSocket.stop({ force: true }); await this.rpcServer.stop({ force: true }); @@ -616,6 +695,7 @@ class NodeConnectionManager { } /** + * This is the internal acquireConnection for using connections without authentication. * For usage with withF, to acquire a connection * This unique acquire function structure of returning the ResourceAcquire * itself is such that we can pass targetNodeId as a parameter (as opposed to @@ -623,8 +703,7 @@ class NodeConnectionManager { * @param targetNodeId Id of target node to communicate with * @returns ResourceAcquire Resource API for use in with contexts */ - @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) - public acquireConnection( + protected acquireConnectionInternal( targetNodeId: NodeId, ): ResourceAcquire { if (this.keyRing.getNodeId().equals(targetNodeId)) { @@ -647,7 +726,7 @@ class NodeConnectionManager { // Increment usage count, and cancel timer connectionAndTimer.usageCount += 1; - connectionAndTimer.timer?.cancel(); + connectionAndTimer.timer?.cancel(timerCancellationReason); connectionAndTimer.timer = null; // Return tuple of [ResourceRelease, Resource] return [ @@ -667,9 +746,15 @@ class NodeConnectionManager { ); connectionAndTimer.timer = new Timer({ handler: async () => - await this.destroyConnection(targetNodeId, false), + await this.destroyConnection( + targetNodeId, + false, + connectionAndTimer.connection.connectionId, + ), delay, }); + // Prevent unhandled exceptions when cancelling + connectionAndTimer.timer.catch(() => {}); } }, connectionAndTimer.connection, @@ -677,6 +762,23 @@ class NodeConnectionManager { }; } + /** + * For usage with withF, to acquire a connection + * This unique acquire function structure of returning the ResourceAcquire + * itself is such that we can pass targetNodeId as a parameter (as opposed to + * an acquire function with no parameters). + * @param targetNodeId Id of target node to communicate with + * @returns ResourceAcquire Resource API for use in with contexts + */ + public acquireConnection( + targetNodeId: NodeId, + ): ResourceAcquire { + return async () => { + await this.isAuthenticatedP(targetNodeId); + return await this.acquireConnectionInternal(targetNodeId)(); + }; + } + /** * Perform some function on another node over the network with a connection. * Will either retrieve an existing connection, or create a new one if it @@ -712,7 +814,7 @@ class NodeConnectionManager { ): AsyncGenerator { const acquire = this.acquireConnection(targetNodeId); const [release, conn] = await acquire(); - let caughtError; + let caughtError: Error | undefined; try { if (conn == null) utils.never('NodeConnection should exist'); return yield* g(conn); @@ -781,6 +883,13 @@ class NodeConnectionManager { detail: connectionData, }), ); + if (this.isAuthenticated(connectionData.remoteNodeId)) { + this.dispatchEvent( + new nodesEvents.EventNodeConnectionManagerConnectionAuthenticated({ + detail: connectionData, + }), + ); + } return nodeConnection; } @@ -945,13 +1054,32 @@ class NodeConnectionManager { await this.destroyConnection(nodeId, false, connectionId), delay: this.getStickyTimeoutValue(nodeId, true), }); + // Prevent unhandled exceptions when cancelling + newConnAndTimer.timer.catch(() => {}); + const { + p: authenticatedP, + resolveP: authenticatedResolveP, + rejectP: authenticatedRejectP, + } = utils.promise(); + // Prevent unhandled rejections + authenticatedP.then( + () => {}, + () => {}, + ); entry = { activeConnection: connectionId, connections: { [connectionId]: newConnAndTimer, }, + authenticatedForward: AuthenticatingState.PENDING, + authenticatedReverse: AuthenticatingState.PENDING, + authenticateComplete: false, + authenticatedP, + authenticatedResolveP, + authenticatedRejectP, }; this.connections.set(nodeIdString, entry); + this.initiateForwardAuthenticate(nodeId); } else { newConnAndTimer.timer = new Timer({ handler: async () => @@ -961,6 +1089,8 @@ class NodeConnectionManager { entry.activeConnection > connectionId, ), }); + // Prevent unhandled exceptions when cancelling + newConnAndTimer.timer.catch(() => {}); // Updating existing entry entry.connections[connectionId] = newConnAndTimer; // If the new connection ID is less than the old then replace it @@ -1012,14 +1142,27 @@ class NodeConnectionManager { ); destroyPs.push(connAndTimer.connection.destroy({ force })); // Destroying TTL timer - if (connAndTimer.timer != null) connAndTimer.timer.cancel(); + connAndTimer.timer?.cancel(timerCancellationReason); + connAndTimer.timer = null; delete connections[connectionId]; } } // If empty then remove the entry const remainingKeys = Object.keys(connectionsEntry.connections); if (remainingKeys.length === 0) { + // Clean up authentication + await this.authenticateCancel( + targetNodeIdString, + new nodesErrors.ErrorNodeManagerAuthenticationFailed( + 'Connection destroyed before authentication could complete', + ), + ); this.connections.delete(targetNodeIdString); + this.dispatchEvent( + new nodesEvents.EventNodeConnectionManagerConnectionDestroyed({ + detail: targetNodeId, + }), + ); } else { // Check if the active connection was removed. if (connections[connectionsEntry.activeConnection] == null) { @@ -1034,7 +1177,7 @@ class NodeConnectionManager { /** * Will determine how long to keep a node around for. * - * Timeout is scaled linearly from 1 min to 2 hours based on it's bucket. + * Timeout is scaled linearly from 1 min to 2 hours based on its bucket. * The value will be symmetric for two nodes, * they will assign the same timeout for each other. */ @@ -1101,6 +1244,13 @@ class NodeConnectionManager { detail: connectionData, }), ); + if (this.isAuthenticated(nodeId)) { + this.dispatchEvent( + new nodesEvents.EventNodeConnectionManagerConnectionAuthenticated({ + detail: connectionData, + }), + ); + } } /** @@ -1150,7 +1300,7 @@ class NodeConnectionManager { try { while (true) { const message = keysUtils.getRandomBytes(32); - // Since the intention is to abstract away the success/failure of the holepunch operation, + // Since the intention is to abstract away the success/failure of the hole-punch operation, // We should catch any errors thrown out of this, as the caller does not expect the method to throw await this.quicSocket .send(Buffer.from(message), port, host) @@ -1236,7 +1386,7 @@ class NodeConnectionManager { * Will validate the message, and initiate hole punching in the background and return immediately. * Attempts to the same host and port are coalesced. * Attempts to the same host are limited by a semaphore. - * Active attempts are tracked inside of the `activeHolePunchPs` set and are cancelled and awaited when the + * Active attempts are tracked inside the `activeHolePunchPs` set and are cancelled and awaited when the * `NodeConnectionManager` stops. */ @ready(new nodesErrors.ErrorNodeManagerNotRunning()) @@ -1252,17 +1402,24 @@ class NodeConnectionManager { } const holePunchAttempt = new PromiseCancellable( async (res, rej, signal) => { - await semaphore!.withF(async () => { - this.holePunch(host, port, { signal }) - .finally(() => { - this.activeHolePunchPs.delete(id); - if (semaphore!.count === 0) { - this.activeHolePunchAddresses.delete(host); - } - }) - .then(res, rej); - }); + await semaphore! + .withF(async () => { + await this.holePunch(host, port, { signal }); + }) + .finally(() => { + this.activeHolePunchPs.delete(id); + if (semaphore!.count === 0) { + this.activeHolePunchAddresses.delete(host); + } + }) + .then(res, rej); }, + ).finally(() => { + this.activeHolePunchPs.delete(id); + }); + holePunchAttempt.then( + () => {}, + () => {}, ); // Prevent promise rejection leak void holePunchAttempt.catch(() => {}); @@ -1396,6 +1553,353 @@ class NodeConnectionManager { return entryRecord; }); } + + public forwardAuthenticate( + nodeId: NodeId, + ctx?: Partial, + ): PromiseCancellable; + @timedCancellable( + true, + (nodeConnectionManager: NodeConnectionManager) => + nodeConnectionManager.connectionConnectTimeoutTime, + ) + public async forwardAuthenticate( + nodeId: NodeId, + @context ctx: ContextTimed, + ): Promise { + const targetNodeIdString = nodeId.toString() as NodeIdString; + const connectionsEntry = this.connections.get(targetNodeIdString); + if (connectionsEntry == null) { + throw new nodesErrors.ErrorNodeConnectionManagerConnectionNotFound(); + } + // Need to make an authenticate request here. Get the connection and RPC. + try { + const authenticateMessage = + await this.authenticateNetworkForwardCallback(ctx); + await withF([this.acquireConnectionInternal(nodeId)], async ([conn]) => { + await conn.rpcClient.methods.nodesAuthenticateConnection( + authenticateMessage, + ctx, + ); + }); + connectionsEntry.authenticatedForward = AuthenticatingState.SUCCESS; + } catch (e) { + const err = new nodesErrors.ErrorNodeManagerAuthenticationFailedForward( + undefined, + { cause: e }, + ); + connectionsEntry.authenticatedForward = AuthenticatingState.FAIL; + connectionsEntry.reasonForward = err; + this.authenticateFail(targetNodeIdString); + return; + } + // Check the reverse result + switch (connectionsEntry.authenticatedReverse) { + case AuthenticatingState.SUCCESS: + // Authentication succeeded + connectionsEntry.authenticatedResolveP(); + connectionsEntry.authenticateComplete = true; + // Dispatching authenticated events for every active connection + for (const connAndTimer of Object.values( + connectionsEntry.connections, + )) { + const connectionData: ConnectionData = { + remoteNodeId: connAndTimer.connection.nodeId, + remoteHost: connAndTimer.connection.host, + remotePort: connAndTimer.connection.port, + }; + this.dispatchEvent( + new nodesEvents.EventNodeConnectionManagerConnectionAuthenticated({ + detail: connectionData, + }), + ); + } + return; + case AuthenticatingState.FAIL: + // Authenticating failed + this.authenticateFail(targetNodeIdString); + return; + case AuthenticatingState.PENDING: + return; + default: + utils.never('authenticatedReverse has invalid state'); + } + } + + public handleReverseAuthenticate( + nodeId: NodeId, + message: NodesAuthenticateConnectionMessage, + ctx?: Partial, + ): PromiseCancellable; + @timedCancellable( + true, + (nodeConnectionManager: NodeConnectionManager) => + nodeConnectionManager.connectionConnectTimeoutTime, + ) + public async handleReverseAuthenticate( + nodeId: NodeId, + message: NodesAuthenticateConnectionMessage, + @context ctx: ContextTimed, + ): Promise { + const targetNodeIdString = nodeId.toString() as NodeIdString; + const connectionsEntry = this.connections.get(targetNodeIdString); + if (connectionsEntry == null) { + throw new nodesErrors.ErrorNodeConnectionManagerConnectionNotFound(); + } + try { + // Should resolve without issue if authentication succeeds. + await this.authenticateNetworkReverseCallback(message, ctx); + connectionsEntry.authenticatedReverse = AuthenticatingState.SUCCESS; + } catch (e) { + const err = new nodesErrors.ErrorNodeManagerAuthenticationFailedReverse( + undefined, + { cause: e }, + ); + connectionsEntry.authenticatedReverse = AuthenticatingState.FAIL; + connectionsEntry.reasonReverse = err; + this.authenticateFail(targetNodeIdString); + // Throw back up the RPC + throw err; + } + // Check the forward result + switch (connectionsEntry.authenticatedForward) { + case AuthenticatingState.SUCCESS: + // Authentication succeeded + connectionsEntry.authenticatedResolveP(); + connectionsEntry.authenticateComplete = true; + // Dispatching authenticated events for every active connection + for (const connAndTimer of Object.values( + connectionsEntry.connections, + )) { + const connectionData: ConnectionData = { + remoteNodeId: connAndTimer.connection.nodeId, + remoteHost: connAndTimer.connection.host, + remotePort: connAndTimer.connection.port, + }; + this.dispatchEvent( + new nodesEvents.EventNodeConnectionManagerConnectionAuthenticated({ + detail: connectionData, + }), + ); + } + return; + case AuthenticatingState.FAIL: + // Authenticating failed + this.authenticateFail(targetNodeIdString); + return; + case AuthenticatingState.PENDING: + return; + default: + utils.never('authenticatedForward has invalid state'); + } + } + + /** + * Will initiate a forward authentication call and coalesce + */ + public initiateForwardAuthenticate(nodeId: NodeId) { + // Needs check the map if one is already running, otherwise it needs to start one and manage it. + const nodeIdString = nodeId.toString() as NodeIdString; + const authenticationEntry = this.connections.get(nodeIdString); + if (authenticationEntry == null) { + utils.never('authenticationEntry must be defined'); + } + const existingAuthenticate = + this.activeForwardAuthenticateCalls.get(nodeIdString); + // If it exists in the map then we don't need to start one and can just return + if (existingAuthenticate != null) return; + if ( + authenticationEntry.authenticatedForward !== AuthenticatingState.PENDING + ) { + return; + } + // Otherwise we need to start one and add it to the map + const forwardAuthenticateP = this.forwardAuthenticate(nodeId).finally( + () => { + this.activeForwardAuthenticateCalls.delete(nodeIdString); + }, + ); + // Prevent unhandled errors + forwardAuthenticateP.then( + () => {}, + () => {}, + ); + this.activeForwardAuthenticateCalls.set(nodeIdString, forwardAuthenticateP); + } + + /** + * Returns true if the connection has been authenticated + */ + public isAuthenticated(nodeId: NodeId): boolean { + const targetNodeIdString = nodeId.toString() as NodeIdString; + const connectionsEntry = this.connections.get(targetNodeIdString); + if (connectionsEntry == null) return false; + const forwardAuthenticated = + connectionsEntry.authenticatedForward === AuthenticatingState.SUCCESS; + const reverseAuthenticated = + connectionsEntry.authenticatedReverse === AuthenticatingState.SUCCESS; + return forwardAuthenticated && reverseAuthenticated; + } + + /** + * Returns a promise that resolves once the connection has authenticated, + * otherwise it rejects with the authentication failure + * @param nodeId + */ + public async isAuthenticatedP(nodeId: NodeId): Promise { + const targetNodeIdString = nodeId.toString() as NodeIdString; + const connectionsEntry = this.connections.get(targetNodeIdString); + if (connectionsEntry == null) { + throw new nodesErrors.ErrorNodeConnectionManagerConnectionNotFound(); + } + try { + return await connectionsEntry.authenticatedP; + } catch (e) { + // Capture the stacktrace here since knowing where we're waiting for authentication is more useful + Error.captureStackTrace(e); + throw e; + } + } + + protected authenticateFail(targetNodeIdString: NodeIdString) { + const connectionsEntry = this.connections.get(targetNodeIdString); + if (connectionsEntry == null) { + return; + } + // Wait for both directions of authentication to complete first + if ( + connectionsEntry.authenticatedForward === AuthenticatingState.PENDING || + connectionsEntry.authenticatedReverse === AuthenticatingState.PENDING + ) { + return; + } + // Skip if already completed + if (connectionsEntry.authenticateComplete) { + return; + } + connectionsEntry.authenticateComplete = true; + const authenticatedRejectP = connectionsEntry.authenticatedRejectP; + let reason: Error; + if ( + connectionsEntry.reasonForward != null && + connectionsEntry.reasonReverse != null + ) { + // Both errors + reason = new AggregateError([ + connectionsEntry.reasonForward, + connectionsEntry.reasonReverse, + ]); + } else if (connectionsEntry.reasonForward != null) { + // Just the forward error + reason = connectionsEntry.reasonForward; + } else if (connectionsEntry.reasonReverse != null) { + // Just the reverse error + reason = connectionsEntry.reasonReverse; + } else { + utils.never('No reason was provided'); + } + // Removing authentication entry + authenticatedRejectP( + new nodesErrors.ErrorNodeManagerAuthenticationFailed(undefined, { + cause: reason, + }), + ); + } + + protected async authenticateCancel( + targetNodeIdString: NodeIdString, + reason: Error, + ) { + const authenticationEntry = this.connections.get(targetNodeIdString); + if (authenticationEntry == null) { + return; + } + if (authenticationEntry.authenticateComplete) { + return; + } + if ( + authenticationEntry!.authenticatedForward === AuthenticatingState.PENDING + ) { + authenticationEntry!.authenticatedForward = AuthenticatingState.FAIL; + authenticationEntry!.reasonForward = reason; + } + if ( + authenticationEntry!.authenticatedReverse === AuthenticatingState.PENDING + ) { + authenticationEntry!.authenticatedReverse = AuthenticatingState.FAIL; + authenticationEntry!.reasonReverse = reason; + } + if ( + authenticationEntry!.authenticatedForward === AuthenticatingState.FAIL || + authenticationEntry!.authenticatedReverse === AuthenticatingState.FAIL + ) { + this.authenticateFail(targetNodeIdString); + } + } + + public setAuthenticateNetworkForwardCallback( + authenticateNetworkForwardCallback: AuthenticateNetworkForwardCallback, + ) { + this.authenticateNetworkForwardCallback = + authenticateNetworkForwardCallback; + } + + public setAuthenticateNetworkReverseCallback( + authenticateNetworkReverseCallback: AuthenticateNetworkReverseCallback, + ) { + this.authenticateNetworkReverseCallback = + authenticateNetworkReverseCallback; + } + + protected authenticationMiddlewareClient = ( + _ctx: ContextTimed, + _cancel: (reason?: any) => void, + meta: Record | undefined, + ) => { + const nodeId = agentUtils.nodeIdFromMeta(meta); + if (nodeId == null) utils.never('NodeId should be defined here'); + let isAllowed = this.isAuthenticated(nodeId); + const { + p: waitP, + resolveP: resolveWaitP, + rejectP: rejectWaitP, + } = utils.promise(); + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + if (isAllowed) { + controller.enqueue(chunk); + } else { + if (rpcMethodsWhitelist.includes(chunk.method)) { + // Success + isAllowed = true; + controller.enqueue(chunk); + resolveWaitP(); + return; + } else { + // Fail + const e = new nodesErrors.ErrorNodeConnectionManagerRPCDenied(); + controller.error(e); + rejectWaitP(e); + return; + } + } + }, + }), + reverse: new TransformStream< + JSONRPCResponse, + JSONRPCResponse + >({ + transform: async (chunk, controller) => { + if (!isAllowed) { + await waitP.catch((e) => controller.error(e)); + return; + } + controller.enqueue(chunk); + }, + }), + }; + }; } export default NodeConnectionManager; diff --git a/src/nodes/NodeManager.ts b/src/nodes/NodeManager.ts index 6cfb2b4e4d..363ad25976 100644 --- a/src/nodes/NodeManager.ts +++ b/src/nodes/NodeManager.ts @@ -7,9 +7,9 @@ import type Sigchain from '../sigchain/Sigchain'; import type TaskManager from '../tasks/TaskManager'; import type GestaltGraph from '../gestalts/GestaltGraph'; import type { + Task, TaskHandler, TaskHandlerId, - Task, TaskInfo, } from '../tasks/types'; import type { SignedTokenEncoded } from '../tokens/types'; @@ -23,28 +23,28 @@ import type { import type { ClaimLinkNode } from '../claims/payloads'; import type NodeConnection from '../nodes/NodeConnection'; import type { + AgentClaimMessage, AgentRPCRequestParams, AgentRPCResponseResult, - AgentClaimMessage, } from './agent/types'; import type { - NodeId, NodeAddress, NodeBucket, NodeBucketIndex, NodeContactAddressData, + NodeId, NodeIdEncoded, } from './types'; import type NodeConnectionManager from './NodeConnectionManager'; import type NodeGraph from './NodeGraph'; import type { ServicePOJO } from '@matrixai/mdns'; +import { withF } from '@matrixai/resources'; +import { events as mdnsEvents, MDNS, utils as mdnsUtils } from '@matrixai/mdns'; import Logger from '@matrixai/logger'; -import { StartStop, ready } from '@matrixai/async-init/dist/StartStop'; -import { Semaphore, Lock } from '@matrixai/async-locks'; +import { ready, StartStop } from '@matrixai/async-init/dist/StartStop'; +import { Lock, LockBox, Semaphore } from '@matrixai/async-locks'; import { IdInternal } from '@matrixai/id'; -import { timedCancellable, context } from '@matrixai/contexts/dist/decorators'; -import { withF } from '@matrixai/resources'; -import { MDNS, events as mdnsEvents, utils as mdnsUtils } from '@matrixai/mdns'; +import { context, timedCancellable } from '@matrixai/contexts/dist/decorators'; import * as nodesUtils from './utils'; import * as nodesEvents from './events'; import * as nodesErrors from './errors'; @@ -122,6 +122,11 @@ class NodeManager { protected concurrencyLimit = 3; protected dnsServers: Array | undefined = undefined; + /** + * Used to track locks for authentication failure and acquiring connections + */ + protected connectionLockBox: LockBox = new LockBox(); + protected refreshBucketHandler: TaskHandler = async ( ctx, _taskInfo, @@ -303,8 +308,8 @@ class NodeManager { public readonly syncNodeGraphHandlerId: TaskHandlerId = `${this.tasksPath}.syncNodeGraphHandler` as TaskHandlerId; - protected handleEventNodeConnectionManagerConnection = async ( - e: nodesEvents.EventNodeConnectionManagerConnection, + protected handleEventNodeConnectionManagerConnectionAuthenticated = async ( + e: nodesEvents.EventNodeConnectionManagerConnectionAuthenticated, ) => { await this.setNode( e.detail.remoteNodeId, @@ -433,8 +438,8 @@ class NodeManager { } // Add handling for connections this.nodeConnectionManager.addEventListener( - nodesEvents.EventNodeConnectionManagerConnection.name, - this.handleEventNodeConnectionManagerConnection, + nodesEvents.EventNodeConnectionManagerConnectionAuthenticated.name, + this.handleEventNodeConnectionManagerConnectionAuthenticated, ); this.logger.info(`Started ${this.constructor.name}`); } @@ -443,8 +448,8 @@ class NodeManager { this.logger.info(`Stopping ${this.constructor.name}`); // Remove handling for connections this.nodeConnectionManager.removeEventListener( - nodesEvents.EventNodeConnectionManagerConnection.name, - this.handleEventNodeConnectionManagerConnection, + nodesEvents.EventNodeConnectionManagerConnectionAuthenticated.name, + this.handleEventNodeConnectionManagerConnectionAuthenticated, ); await this.mdns?.stop(); await this.stopTasks(); @@ -469,24 +474,29 @@ class NodeManager { ctx?: Partial, ): ResourceAcquire { if (this.keyRing.getNodeId().equals(nodeId)) { - this.logger.warn('Attempting connection to our own NodeId'); throw new nodesErrors.ErrorNodeManagerNodeIdOwn(); } return async () => { - // Checking if connection already exists - if (!this.nodeConnectionManager.hasConnection(nodeId)) { - // Establish the connection - const result = await this.findNode( - { - nodeId: nodeId, - }, - ctx, - ); - if (result == null) { - throw new nodesErrors.ErrorNodeManagerConnectionFailed(); - } - } - return await this.nodeConnectionManager.acquireConnection(nodeId)(); + return await this.connectionLockBox.withF( + [['acquireConnection', nodeId.toString()].join('.'), Lock], + async () => { + // Checking if connection already exists + if (!this.nodeConnectionManager.hasConnection(nodeId)) { + // Establish the connection + const result = await this.findNode( + { + nodeId: nodeId, + }, + ctx, + ); + if (result == null) { + throw new nodesErrors.ErrorNodeManagerConnectionFailed(); + } + } + // Initiate authentication and await + return await this.nodeConnectionManager.acquireConnection(nodeId)(); + }, + ); }; } @@ -535,7 +545,7 @@ class NodeManager { ): AsyncGenerator { const acquire = this.acquireConnection(nodeId, ctx); const [release, conn] = await acquire(); - let caughtError; + let caughtError: Error | undefined; try { if (conn == null) utils.never('NodeConnection should exist'); return yield* g(conn); @@ -637,7 +647,16 @@ class NodeManager { try { return await Promise.any([findBySignal, findByDirect, findByMDNS]); } catch (e) { - // FIXME: check error type and throw if not connection related failure + if (e instanceof AggregateError) { + for (const error of e.errors) { + // Checking if each error is an expected error + if (!(error instanceof nodesErrors.ErrorNodeManagerFindNodeFailed)) { + throw e; + } + } + } else if (!(e instanceof nodesErrors.ErrorNodeManagerFindNodeFailed)) { + throw e; + } return; } finally { abortController.abort(abortPendingConnectionsReason); @@ -1578,7 +1597,7 @@ class NodeManager { } catch { continue; } - // No need to check if local claims are correctly signed by an Network Authority. + // No need to check if local claims are correctly signed by a Network Authority. if ( authorityToken.verifyWithPublicKey( keysUtils.publicKeyFromNodeId( @@ -1657,8 +1676,7 @@ class NodeManager { ); } - // Need to await node connection verification. If failed, need to reject - // connection. + // Need to await node connection verification, if failed, need to reject connection. // When adding a node we need to handle 3 cases // 1. The node already exists. We need to update it's last updated field @@ -2177,7 +2195,7 @@ class NodeManager { * * From the spec: * To join the network, a node u must have a contact to an already participating node w. u inserts w into the - * appropriate k-bucket. u then performs a node lookup for its own node ID. Finally, u refreshes all kbuckets further + * appropriate k-bucket. u then performs a node lookup for its own node ID. Finally, u refreshes all k-buckets further * away than its closest neighbor. During the refreshes, u both populates its own k-buckets and inserts itself into * other nodes’ k-buckets as necessary. * diff --git a/src/nodes/agent/callers/index.ts b/src/nodes/agent/callers/index.ts index e213dbb4ec..c88bf3da2a 100644 --- a/src/nodes/agent/callers/index.ts +++ b/src/nodes/agent/callers/index.ts @@ -1,3 +1,4 @@ +import nodesAuthenticateConnection from './nodesAuthenticateConnection'; import nodesClaimsGet from './nodesClaimsGet'; import nodesClosestActiveConnectionsGet from './nodesClosestActiveConnectionsGet'; import nodesClosestLocalNodesGet from './nodesClosestLocalNodesGet'; @@ -15,6 +16,7 @@ import vaultsScan from './vaultsScan'; * Client manifest */ const manifestClient = { + nodesAuthenticateConnection, nodesClaimsGet, nodesClosestActiveConnectionsGet, nodesClosestLocalNodesGet, @@ -34,6 +36,7 @@ type AgentClientManifest = typeof manifestClient; export default manifestClient; export { + nodesAuthenticateConnection, nodesClaimsGet, nodesClosestActiveConnectionsGet, nodesClosestLocalNodesGet, diff --git a/src/nodes/agent/callers/nodesAuthenticateConnection.ts b/src/nodes/agent/callers/nodesAuthenticateConnection.ts new file mode 100644 index 0000000000..d8814a60cb --- /dev/null +++ b/src/nodes/agent/callers/nodesAuthenticateConnection.ts @@ -0,0 +1,12 @@ +import type { HandlerTypes } from '@matrixai/rpc'; +import type NodesAuthenticateConnection from '../handlers/NodesAuthenticateConnection'; +import { UnaryCaller } from '@matrixai/rpc'; + +type CallerTypes = HandlerTypes; + +const nodesAuthenticateConnection = new UnaryCaller< + CallerTypes['input'], + CallerTypes['output'] +>(); + +export default nodesAuthenticateConnection; diff --git a/src/nodes/agent/handlers/NodesAuthenticateConnection.ts b/src/nodes/agent/handlers/NodesAuthenticateConnection.ts new file mode 100644 index 0000000000..a4fbedfdeb --- /dev/null +++ b/src/nodes/agent/handlers/NodesAuthenticateConnection.ts @@ -0,0 +1,45 @@ +import type { + AgentRPCRequestParams, + AgentRPCResponseResult, + NodesAuthenticateConnectionMessage, + SuccessMessage, +} from '../types'; +import type NodeConnectionManager from '../../../nodes/NodeConnectionManager'; +import type { JSONValue } from '../../../types'; +import type { ContextTimed } from '@matrixai/contexts'; +import { UnaryHandler } from '@matrixai/rpc'; +import * as agentErrors from '../errors'; +import * as agentUtils from '../utils'; + +class NodesAuthenticateConnection extends UnaryHandler< + { + nodeConnectionManager: NodeConnectionManager; + }, + AgentRPCRequestParams, + AgentRPCResponseResult +> { + public handle = async ( + input: AgentRPCRequestParams, + _cancel, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise> => { + const { nodeConnectionManager } = this.container; + // Connections should always be validated + const requestingNodeId = agentUtils.nodeIdFromMeta(meta); + if (requestingNodeId == null) { + throw new agentErrors.ErrorAgentNodeIdMissing(); + } + await nodeConnectionManager.handleReverseAuthenticate( + requestingNodeId, + input, + ctx, + ); + return { + type: 'success', + success: true, + }; + }; +} + +export default NodesAuthenticateConnection; diff --git a/src/nodes/agent/handlers/index.ts b/src/nodes/agent/handlers/index.ts index 1a212c510d..4c0af86c5b 100644 --- a/src/nodes/agent/handlers/index.ts +++ b/src/nodes/agent/handlers/index.ts @@ -8,6 +8,7 @@ import type NodeManager from '../../../nodes/NodeManager'; import type NodeConnectionManager from '../../../nodes/NodeConnectionManager'; import type NotificationsManager from '../../../notifications/NotificationsManager'; import type VaultManager from '../../../vaults/VaultManager'; +import NodesAuthenticateConnection from './NodesAuthenticateConnection'; import NodesClaimsGet from './NodesClaimsGet'; import NodesClosestActiveConnectionsGet from './NodesClosestActiveConnectionsGet'; import NodesClosestLocalNodesGet from './NodesClosestLocalNodesGet'; @@ -36,6 +37,7 @@ const manifestServer = (container: { vaultManager: VaultManager; }) => { return { + nodesAuthenticateConnection: new NodesAuthenticateConnection(container), nodesClaimsGet: new NodesClaimsGet(container), nodesClosestActiveConnectionsGet: new NodesClosestActiveConnectionsGet( container, @@ -57,6 +59,7 @@ type AgentServerManifest = ReturnType; export default manifestServer; export { + NodesAuthenticateConnection, NodesClaimsGet, NodesClosestActiveConnectionsGet, NodesClosestLocalNodesGet, diff --git a/src/nodes/agent/types.ts b/src/nodes/agent/types.ts index abec3d5a49..32e26c4f7f 100644 --- a/src/nodes/agent/types.ts +++ b/src/nodes/agent/types.ts @@ -8,7 +8,7 @@ import type { ClaimIdEncoded, NodeIdEncoded, VaultIdEncoded } from '../../ids'; import type { VaultAction, VaultName } from '../../vaults/types'; import type { SignedNotification } from '../../notifications/types'; import type { Host, Hostname, Port } from '../../network/types'; -import type { NodeContact } from '../../nodes/types'; +import type { NetworkId, NodeContact } from '../../nodes/types'; type AgentRPCRequestParams = JSONRPCRequestParams; @@ -77,6 +77,23 @@ type VaultsScanMessage = VaultInfo & { vaultPermissions: Array; }; +type SuccessMessage = { + type: 'success'; + success: boolean; +}; + +type NodesAuthenticateConnectionMessage = + | NodesAuthenticateConnectionMessageBasicPublic + | NodesAuthenticateConnectionMessageNone; + +type NodesAuthenticateConnectionMessageBasicPublic = { + type: 'NodesAuthenticateConnectionMessageBasicPublic'; + networkId: NetworkId; +}; +type NodesAuthenticateConnectionMessageNone = { + type: 'NodesAuthenticateConnectionMessageNone'; +}; + export type { AgentRPCRequestParams, AgentRPCResponseResult, @@ -91,4 +108,8 @@ export type { SignedNotificationEncoded, VaultInfo, VaultsScanMessage, + SuccessMessage, + NodesAuthenticateConnectionMessage, + NodesAuthenticateConnectionMessageBasicPublic, + NodesAuthenticateConnectionMessageNone, }; diff --git a/src/nodes/errors.ts b/src/nodes/errors.ts index e5fae11982..76dbbcc737 100644 --- a/src/nodes/errors.ts +++ b/src/nodes/errors.ts @@ -35,6 +35,34 @@ class ErrorNodeManagerSyncNodeGraphFailed extends ErrorNodeManager { exitCode = sysexits.TEMPFAIL; } +class ErrorNodeManagerAuthenticationCallbackNotProvided< + T, +> extends ErrorNodeManager { + static description = 'Authentication callback was not provided'; + exitCode = sysexits.USAGE; +} + +class ErrorNodeManagerAuthenticationFailed extends ErrorNodeManager { + static description = + 'Node connection failed to authenticate, authentication message or token was not valid'; + exitCode = sysexits.NOPERM; +} + +class ErrorNodeManagerAuthenticationFailedForward extends ErrorNodes { + static description = 'Failed to complete forward authentication'; + exitCode = sysexits.USAGE; +} + +class ErrorNodeManagerAuthenticationFailedReverse extends ErrorNodes { + static description = 'Failed to complete reverse authentication'; + exitCode = sysexits.USAGE; +} + +class ErrorNodeManagerAuthenticatonTimedOut extends ErrorNodes { + static description = 'Failed to complete authentication before timing out'; + exitCode = sysexits.USAGE; +} + class ErrorNodeGraph extends ErrorNodes {} class ErrorNodeGraphRunning extends ErrorNodeGraph { @@ -200,6 +228,13 @@ class ErrorNodeConnectionManagerSignalFailed< exitCode = sysexits.TEMPFAIL; } +class ErrorNodeConnectionManagerRPCDenied< + T, +> extends ErrorNodeConnectionManager { + static description = 'RPC call was denied due to being unauthenticated'; + exitCode = sysexits.USAGE; +} + class ErrorNodePingFailed extends ErrorNodes { static description = 'Failed to ping the node when attempting to authenticate'; @@ -216,6 +251,11 @@ class ErrorNodeLookupNotFound extends ErrorNodes { exitCode = sysexits.NOHOST; } +class ErrorNodeAuthenticationFailed extends ErrorNodes { + static description = 'Node failed to authenticate'; + exitCode = sysexits.NOPERM; +} + export { ErrorNodes, ErrorNodeManager, @@ -225,6 +265,11 @@ export { ErrorNodeManagerFindNodeFailed, ErrorNodeManagerResolveNodeFailed, ErrorNodeManagerSyncNodeGraphFailed, + ErrorNodeManagerAuthenticationCallbackNotProvided, + ErrorNodeManagerAuthenticationFailed, + ErrorNodeManagerAuthenticationFailedForward, + ErrorNodeManagerAuthenticationFailedReverse, + ErrorNodeManagerAuthenticatonTimedOut, ErrorNodeGraph, ErrorNodeGraphRunning, ErrorNodeGraphNotRunning, @@ -254,7 +299,9 @@ export { ErrorNodeConnectionManagerConnectionNotFound, ErrorNodeConnectionManagerRequestRateExceeded, ErrorNodeConnectionManagerSignalFailed, + ErrorNodeConnectionManagerRPCDenied, ErrorNodePingFailed, ErrorNodePermissionDenied, ErrorNodeLookupNotFound, + ErrorNodeAuthenticationFailed, }; diff --git a/src/nodes/events.ts b/src/nodes/events.ts index bca173aadf..93cccd0a47 100644 --- a/src/nodes/events.ts +++ b/src/nodes/events.ts @@ -1,5 +1,6 @@ import type { QUICStream } from '@matrixai/quic'; import type { ConnectionData } from '../network/types'; +import type { NodeId } from '../ids/types'; import EventPolykey from '../EventPolykey'; abstract class EventNode extends EventPolykey {} @@ -36,6 +37,10 @@ class EventNodeConnectionManagerConnectionForward extends EventNodeConnectionMan class EventNodeConnectionManagerConnectionReverse extends EventNodeConnectionManagerConnection {} +class EventNodeConnectionManagerConnectionAuthenticated extends EventNodeConnectionManagerConnection {} + +class EventNodeConnectionManagerConnectionDestroyed extends EventNodeConnectionManager {} + abstract class EventNodeGraph extends EventPolykey {} class EventNodeGraphStart extends EventNodeGraph {} @@ -78,6 +83,8 @@ export { EventNodeConnectionManagerConnection, EventNodeConnectionManagerConnectionForward, EventNodeConnectionManagerConnectionReverse, + EventNodeConnectionManagerConnectionAuthenticated, + EventNodeConnectionManagerConnectionDestroyed, EventNodeGraph, EventNodeGraphStart, EventNodeGraphStarted, diff --git a/src/nodes/types.ts b/src/nodes/types.ts index 8add7555bd..ac08a6ffd2 100644 --- a/src/nodes/types.ts +++ b/src/nodes/types.ts @@ -1,5 +1,7 @@ +import type { ContextTimed } from '@matrixai/contexts'; import type { NodeId, NodeIdString, NodeIdEncoded } from '../ids/types'; import type { Host, Hostname, Port } from '../network/types'; +import type { NodesAuthenticateConnectionMessage } from '../nodes/agent/types'; import type { Opaque } from '../types'; /** @@ -71,6 +73,19 @@ enum ConnectionErrorReason { ForceClose = 'NodeConnection is forcing destruction', } +type NetworkId = string; +type AuthenticateNetworkForwardCallback = ( + ctx: ContextTimed, +) => Promise; + +/** + * Callback should throw on authentication failure + */ +type AuthenticateNetworkReverseCallback = ( + message: NodesAuthenticateConnectionMessage, + ctx: ContextTimed, +) => Promise; + export type { NodeId, NodeIdString, @@ -85,6 +100,9 @@ export type { NodeBucketMeta, NodeBucket, NodeGraphSpace, + NetworkId, + AuthenticateNetworkForwardCallback, + AuthenticateNetworkReverseCallback, }; export { ConnectionErrorCode, ConnectionErrorReason }; diff --git a/src/nodes/utils.ts b/src/nodes/utils.ts index 76ffd66e15..72a0a3fbe2 100644 --- a/src/nodes/utils.ts +++ b/src/nodes/utils.ts @@ -13,6 +13,11 @@ import type { NodeId, SeedNodes, } from './types'; +import type { + NodesAuthenticateConnectionMessage, + NodesAuthenticateConnectionMessageBasicPublic, + NodesAuthenticateConnectionMessageNone, +} from './agent/types'; import dns from 'dns'; import { utils as dbUtils } from '@matrixai/db'; import { IdInternal } from '@matrixai/id'; @@ -365,6 +370,9 @@ const reasonToCode = (_type: 'read' | 'write', reason?: any): number => { if (reason instanceof rpcErrors.ErrorRPCRemote) return 5; if (reason instanceof rpcErrors.ErrorRPCStreamEnded) return 6; if (reason instanceof rpcErrors.ErrorRPCTimedOut) return 7; + if (reason instanceof nodesErrors.ErrorNodeConnectionManagerRPCDenied) { + return 8; + } return 0; }; @@ -390,6 +398,8 @@ const codeToReason = (_type: 'read' | 'write', code: number): any => { return new rpcErrors.ErrorRPCStreamEnded(); case 7: return new rpcErrors.ErrorRPCTimedOut(); + case 8: + return new nodesErrors.ErrorNodeConnectionManagerRPCDenied(); // Base cases case 0: return new nodesErrors.ErrorNodeConnectionTransportGenericError(); @@ -772,6 +782,51 @@ async function* collectNodeContacts( if (nodeId != null) yield [nodeId, nodeContact]; } +// Authentication utils +async function nodesAuthenticateConnectionForwardDefault(): Promise { + return { + type: 'NodesAuthenticateConnectionMessageNone', + }; +} + +async function nodesAuthenticateConnectionReverseDefault(): Promise { + return; +} + +function nodesAuthenticateConnectionForwardBasicPublicFactory( + networkId: string, +) { + return async (): Promise => { + return { + type: 'NodesAuthenticateConnectionMessageBasicPublic', + networkId, + }; + }; +} + +function nodesAuthenticateConnectionReverseBasicPublicFactory( + networkId: string, +) { + return async (message: NodesAuthenticateConnectionMessage): Promise => { + if (message.type !== 'NodesAuthenticateConnectionMessageBasicPublic') { + throw new nodesErrors.ErrorNodeAuthenticationFailed( + 'must be basic message', + ); + } + if (message.networkId !== networkId) { + throw new nodesErrors.ErrorNodeAuthenticationFailed( + 'network must be "${networkId}"', + ); + } + }; +} + +async function nodesAuthenticateConnectionReverseDeny() { + throw new nodesErrors.ErrorNodeAuthenticationFailed( + 'All connections are being denied', + ); +} + export { sepBuffer, nodeContactAddress, @@ -806,6 +861,11 @@ export { quicClientCrypto, quicServerCrypto, collectNodeContacts, + nodesAuthenticateConnectionForwardDefault, + nodesAuthenticateConnectionReverseDefault, + nodesAuthenticateConnectionForwardBasicPublicFactory, + nodesAuthenticateConnectionReverseBasicPublicFactory, + nodesAuthenticateConnectionReverseDeny, }; export { encodeNodeId, decodeNodeId } from '../ids'; diff --git a/tests/PolykeyAgent.test.ts b/tests/PolykeyAgent.test.ts index b52a9d7bb7..b7031f0390 100644 --- a/tests/PolykeyAgent.test.ts +++ b/tests/PolykeyAgent.test.ts @@ -12,6 +12,7 @@ import config from '@/config'; import { promise } from '@/utils'; import * as keysUtils from '@/keys/utils'; import * as keysEvents from '@/keys/events'; +import * as testsUtils from './utils'; describe('PolykeyAgent', () => { const password = 'password'; @@ -37,6 +38,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -64,6 +66,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, workers: 0, agentServiceHost: localhost, clientServiceHost: localhost, @@ -108,6 +111,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -147,6 +151,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -178,6 +183,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -203,6 +209,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -223,6 +230,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -258,6 +266,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -293,6 +302,7 @@ describe('PolykeyAgent', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { diff --git a/tests/PolykeyClient.test.ts b/tests/PolykeyClient.test.ts index 62efd49aa8..537ecab06f 100644 --- a/tests/PolykeyClient.test.ts +++ b/tests/PolykeyClient.test.ts @@ -19,7 +19,7 @@ import * as keysUtils from '@/keys/utils'; import * as errors from '@/errors'; import * as events from '@/events'; import * as utils from '@/utils'; -import * as testUtils from './utils'; +import * as testsUtils from './utils'; describe(PolykeyClient.name, () => { const logger = new Logger(`${PolykeyClient.name} Test`, LogLevel.WARN, [ @@ -96,6 +96,7 @@ describe(PolykeyClient.name, () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localHost, clientServiceHost: localHost, keys: { @@ -201,7 +202,7 @@ describe(PolykeyClient.name, () => { const callP = pkClient.rpcClient.methods.agentStatus({}); // Authentication error await expect(callP).rejects.toThrow(errors.ErrorPolykeyRemote); - await testUtils.expectRemoteError(callP, errors.ErrorClientAuthMissing); + await testsUtils.expectRemoteError(callP, errors.ErrorClientAuthMissing); // Correct auth runs without error await pkClient.rpcClient.methods.agentStatus({ metadata: { diff --git a/tests/client/handlers/agent.test.ts b/tests/client/handlers/agent.test.ts index b44b64e8f1..a88b680877 100644 --- a/tests/client/handlers/agent.test.ts +++ b/tests/client/handlers/agent.test.ts @@ -150,6 +150,7 @@ describe('agentStatus', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, keys: { passwordOpsLimit: keysUtils.passwordOpsLimits.min, passwordMemLimit: keysUtils.passwordMemLimits.min, @@ -268,6 +269,7 @@ describe('agentStop', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, keys: { passwordOpsLimit: keysUtils.passwordOpsLimits.min, passwordMemLimit: keysUtils.passwordMemLimits.min, diff --git a/tests/client/handlers/gestalts.test.ts b/tests/client/handlers/gestalts.test.ts index c2d4311bd8..cdbefcec42 100644 --- a/tests/client/handlers/gestalts.test.ts +++ b/tests/client/handlers/gestalts.test.ts @@ -452,6 +452,14 @@ describe('gestaltsDiscoveryByIdentity', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -631,6 +639,14 @@ describe('gestaltsDiscoveryByNode', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -807,6 +823,14 @@ describe('gestaltsDiscoveryQueue', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -1433,6 +1457,14 @@ describe('gestaltsGestaltTrustByIdentity', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -1740,6 +1772,7 @@ describe('gestaltsGestaltTrustByNode', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -1842,6 +1875,14 @@ describe('gestaltsGestaltTrustByNode', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ diff --git a/tests/client/handlers/keys.test.ts b/tests/client/handlers/keys.test.ts index eead4030a0..32ab50d639 100644 --- a/tests/client/handlers/keys.test.ts +++ b/tests/client/handlers/keys.test.ts @@ -500,6 +500,7 @@ describe('keysKeyPairRenew', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -628,6 +629,7 @@ describe('keysKeyPairReset', () => { password, options: { nodePath, + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { diff --git a/tests/client/handlers/notifications.test.ts b/tests/client/handlers/notifications.test.ts index aa52de6e4a..8142955493 100644 --- a/tests/client/handlers/notifications.test.ts +++ b/tests/client/handlers/notifications.test.ts @@ -122,6 +122,14 @@ describe('notificationsInboxClear', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -282,6 +290,14 @@ describe('notificationsInboxRead', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -681,6 +697,14 @@ describe('notificationsInboxRemove', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -844,6 +868,14 @@ describe('notificationsOutboxClear', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -1004,6 +1036,14 @@ describe('notificationsOutboxRead', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -1346,6 +1386,14 @@ describe('notificationsOutboxRemove', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -1506,6 +1554,14 @@ describe('notificationsSend', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ diff --git a/tests/discovery/Discovery.test.ts b/tests/discovery/Discovery.test.ts index 739f46dd86..25258ad1e4 100644 --- a/tests/discovery/Discovery.test.ts +++ b/tests/discovery/Discovery.test.ts @@ -23,6 +23,7 @@ import IdentitiesManager from '@/identities/IdentitiesManager'; import NodeConnectionManager from '@/nodes/NodeConnectionManager'; import NodeGraph from '@/nodes/NodeGraph'; import NodeManager from '@/nodes/NodeManager'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import KeyRing from '@/keys/KeyRing'; import ACL from '@/acl/ACL'; import Sigchain from '@/sigchain/Sigchain'; @@ -36,6 +37,7 @@ import * as testNodesUtils from '../nodes/utils'; import TestProvider from '../identities/TestProvider'; import 'ix/add/asynciterable-operators/toarray'; import { createTLSConfig } from '../utils/tls'; +import * as testsUtils from '../utils'; describe('Discovery', () => { const password = 'password'; @@ -168,6 +170,14 @@ describe('Discovery', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -183,13 +193,18 @@ describe('Discovery', () => { await nodeManager.start(); await nodeConnectionManager.start({ host: localhost as Host, - agentService: {} as AgentServerManifest, + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as AgentServerManifest, }); // Set up other gestalt nodeA = await PolykeyAgent.createPolykeyAgent({ password: password, options: { nodePath: path.join(dataDir, 'nodeA'), + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -204,6 +219,7 @@ describe('Discovery', () => { password: password, options: { nodePath: path.join(dataDir, 'nodeB'), + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { diff --git a/tests/nodes/NodeConnectionManager.test.ts b/tests/nodes/NodeConnectionManager.test.ts index cfd32f5de5..ef7c2ee225 100644 --- a/tests/nodes/NodeConnectionManager.test.ts +++ b/tests/nodes/NodeConnectionManager.test.ts @@ -3,20 +3,49 @@ import type NodeConnection from '@/nodes/NodeConnection'; import type { AgentServerManifest } from '@/nodes/agent/handlers'; import type { KeyRing } from '@/keys'; import type { NCMState } from './utils'; +import type { JSONValue, ObjectEmpty } from '@'; +import type { + AgentRPCRequestParams, + AgentRPCResponseResult, + NodesAuthenticateConnectionMessage, + SuccessMessage, +} from '@/nodes/agent/types'; +import type { ContextTimed } from '@matrixai/contexts'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { Timer } from '@matrixai/timer'; import { destroyed } from '@matrixai/async-init'; +import { UnaryHandler } from '@matrixai/rpc'; import * as keysUtils from '@/keys/utils'; import * as nodesEvents from '@/nodes/events'; import * as nodesErrors from '@/nodes/errors'; import NodeConnectionManager from '@/nodes/NodeConnectionManager'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import NodesConnectionSignalFinal from '@/nodes/agent/handlers/NodesConnectionSignalFinal'; import NodesConnectionSignalInitial from '@/nodes/agent/handlers/NodesConnectionSignalInitial'; import * as utils from '@/utils'; +import * as nodesUtils from '@/nodes/utils'; import * as nodesTestUtils from './utils'; import * as keysTestUtils from '../keys/utils'; import * as testsUtils from '../utils'; +class DummyNodesAuthenticateConnection extends UnaryHandler< + ObjectEmpty, + AgentRPCRequestParams, + AgentRPCResponseResult +> { + public handle = async ( + _input: AgentRPCRequestParams, + _cancel, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise> => { + return { + type: 'success', + success: true, + }; + }; +} + describe(`${NodeConnectionManager.name}`, () => { const logger = new Logger( `${NodeConnectionManager.name} test`, @@ -91,11 +120,17 @@ describe(`${NodeConnectionManager.name}`, () => { }, startOptions: { host: localHost, - agentService: () => dummyManifest, + agentService: (nodeConnectionManager) => { + return { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + dummyMethod: new DummyNodesAuthenticateConnection({}), + } as unknown as AgentServerManifest; + }, }, logger: logger.getChild(`${NodeConnectionManager.name}Local`), }); - ncmPeer1 = await nodesTestUtils.nodeConnectionManagerFactory({ keyRing: keysTestUtils.createDummyKeyRing(), createOptions: { @@ -103,7 +138,14 @@ describe(`${NodeConnectionManager.name}`, () => { }, startOptions: { host: localHost, - agentService: () => dummyManifest, + agentService: (nodeConnectionManager) => { + return { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + dummyMethod: new DummyNodesAuthenticateConnection({}), + } as unknown as AgentServerManifest; + }, }, logger: logger.getChild(`${NodeConnectionManager.name}Peer1`), }); @@ -281,7 +323,7 @@ describe(`${NodeConnectionManager.name}`, () => { ); const connectionPeerDestroyed = testsUtils.promFromEvent( ncmPeer1.nodeConnectionManager, - nodesEvents.EventNodeConnectionDestroyed, + nodesEvents.EventNodeConnectionManagerConnectionDestroyed, ); await ncmLocal.nodeConnectionManager.createConnection( @@ -560,7 +602,6 @@ describe(`${NodeConnectionManager.name}`, () => { expect(connection.address.port).toBe( ncmPeer1.nodeConnectionManager.port, ); - expect(connection.usageCount).toBe(0); } }); test('stopping NodeConnectionManager should destroy all connections', async () => { @@ -645,6 +686,292 @@ describe(`${NodeConnectionManager.name}`, () => { ncmLocal.nodeConnectionManager.hasConnection(ncmPeer1.nodeId), ).toBeFalse(); }); + test('can authenticate a connection', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + // Checking authentication result + await ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await ncmPeer1.nodeConnectionManager.withConnF( + ncmLocal.nodeId, + async () => { + // Do nothing + }, + ); + }); + test('forward authenticate fails on local', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + async () => { + throw Error('Failure to generate forward authentication message'); + }, + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + // Checking authentication result + const authenticationAttemptP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(authenticationAttemptP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + + const authenticationAttemptP2 = ncmPeer1.nodeConnectionManager.withConnF( + ncmLocal.nodeId, + async () => { + // Do nothing + }, + ); + await expect(authenticationAttemptP2).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + }); + test('peer sends invalid authentication', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardDefault, + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + + const authenticationAttemptP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(authenticationAttemptP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + const forwardAuthenticateP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(forwardAuthenticateP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + const reverseAuthenticateP = ncmPeer1.nodeConnectionManager.withConnF( + ncmLocal.nodeId, + async () => { + // Do nothing + }, + ); + await expect(reverseAuthenticateP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + }); + test('reverse authenticate fails on local', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseDeny, + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + + const authenticationAttemptP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(authenticationAttemptP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + }); + test('reverse authenticate fails on peer', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseDeny, + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + + const authenticationAttemptP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(authenticationAttemptP).rejects.toThrow( + nodesErrors.ErrorNodeManagerAuthenticationFailed, + ); + }); + test('non whitelisted RPC calls are prevented', async () => { + ncmLocal.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkForwardCallback( + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + 'someNetwork', + ), + ); + ncmLocal.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + ncmPeer1.nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + 'someNetwork', + ), + ); + + // Creating connection + await ncmLocal.nodeConnectionManager.createConnection( + [ncmPeer1.nodeId], + localHost, + ncmPeer1.port, + ); + + // Getting connection directly to avoid initiating authentication + const connection = ncmLocal.nodeConnectionManager.getConnection( + ncmPeer1.nodeId, + ); + expect(connection).toBeDefined(); + await expect( + connection?.connection.rpcClient.unaryCaller('dummyMethod', {}), + ).rejects.toThrow(nodesErrors.ErrorNodeConnectionManagerRPCDenied); + + const forwardAuthenticateP = ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async () => { + // Do nothing + }, + ); + await expect(forwardAuthenticateP).toResolve(); + const reverseAuthenticateP = ncmPeer1.nodeConnectionManager.withConnF( + ncmLocal.nodeId, + async () => { + // Do nothing + }, + ); + await expect(reverseAuthenticateP).toResolve(); + + // Checking RPC again + await ncmLocal.nodeConnectionManager.withConnF( + ncmPeer1.nodeId, + async (conn) => { + await expect( + conn.rpcClient.unaryCaller('dummyMethod', {}), + ).resolves.toMatchObject({ + type: 'success', + success: true, + }); + }, + ); + }); }); describe('With 2 peers', () => { let ncmLocal: NCMState; @@ -659,11 +986,16 @@ describe(`${NodeConnectionManager.name}`, () => { }, startOptions: { host: localHost, - agentService: () => dummyManifest, + agentService: (nodeConnectionManager) => { + return { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as unknown as AgentServerManifest; + }, }, logger: logger.getChild(`${NodeConnectionManager.name}Local`), }); - ncmPeer1 = await nodesTestUtils.nodeConnectionManagerFactory({ keyRing: keysTestUtils.createDummyKeyRing(), createOptions: { @@ -673,6 +1005,9 @@ describe(`${NodeConnectionManager.name}`, () => { host: localHost, agentService: (nodeConnectionManager) => ({ + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), nodesConnectionSignalFinal: new NodesConnectionSignalFinal({ nodeConnectionManager, logger, @@ -684,7 +1019,6 @@ describe(`${NodeConnectionManager.name}`, () => { }, logger: logger.getChild(`${NodeConnectionManager.name}Peer1`), }); - ncmPeer2 = await nodesTestUtils.nodeConnectionManagerFactory({ keyRing: keysTestUtils.createDummyKeyRing(), createOptions: { @@ -694,6 +1028,9 @@ describe(`${NodeConnectionManager.name}`, () => { host: localHost, agentService: (nodeConnectionManager) => ({ + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), nodesConnectionSignalFinal: new NodesConnectionSignalFinal({ nodeConnectionManager, logger, diff --git a/tests/nodes/NodeManager.test.ts b/tests/nodes/NodeManager.test.ts index 925ba431c8..aba4942ed9 100644 --- a/tests/nodes/NodeManager.test.ts +++ b/tests/nodes/NodeManager.test.ts @@ -3,6 +3,14 @@ import type { AgentServerManifest } from '@/nodes/agent/handlers'; import type nodeGraph from '@/nodes/NodeGraph'; import type { NCMState } from './utils'; import type { NodeAddress, NodeContactAddressData } from '@/nodes/types'; +import type { + AgentRPCRequestParams, + AgentRPCResponseResult, + NodesAuthenticateConnectionMessage, + SuccessMessage, +} from '@/nodes/agent/types'; +import type { JSONValue, ObjectEmpty } from '@'; +import type { ContextTimed } from '@matrixai/contexts'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -10,6 +18,8 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import { Semaphore } from '@matrixai/async-locks'; import { PromiseCancellable } from '@matrixai/async-cancellable'; +import { UnaryHandler } from '@matrixai/rpc'; +import ACL from '@/acl/ACL'; import NodeGraph from '@/nodes/NodeGraph'; import { NodesClaimsGet, @@ -18,9 +28,12 @@ import { } from '@/nodes/agent/handlers'; import * as keysUtils from '@/keys/utils'; import * as nodesErrors from '@/nodes/errors'; +import * as nodesEvents from '@/nodes/events'; import NodeConnectionManager from '@/nodes/NodeConnectionManager'; +import NodesCrossSignClaim from '@/nodes/agent/handlers/NodesCrossSignClaim'; import NodesConnectionSignalFinal from '@/nodes/agent/handlers/NodesConnectionSignalFinal'; import NodesConnectionSignalInitial from '@/nodes/agent/handlers/NodesConnectionSignalInitial'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import * as nodesUtils from '@/nodes/utils'; import { TaskManager } from '@/tasks'; import { NodeConnection, NodeManager } from '@/nodes'; @@ -31,9 +44,25 @@ import NodeConnectionQueue from '@/nodes/NodeConnectionQueue'; import * as utils from '@/utils'; import { generateNodeIdForBucket } from './utils'; import * as nodesTestUtils from './utils'; -import ACL from '../../src/acl/ACL'; import * as testsUtils from '../utils'; -import NodesCrossSignClaim from '../../src/nodes/agent/handlers/NodesCrossSignClaim'; + +class DummyNodesAuthenticateConnection extends UnaryHandler< + ObjectEmpty, + AgentRPCRequestParams, + AgentRPCResponseResult +> { + public handle = async ( + _input: AgentRPCRequestParams, + _cancel, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise> => { + return { + type: 'success', + success: true, + }; + }; +} describe(`${NodeManager.name}`, () => { const logger = new Logger(`${NodeManager.name} test`, LogLevel.WARN, [ @@ -44,6 +73,9 @@ describe(`${NodeManager.name}`, () => { const password = 'password'; const localHost = '127.0.0.1' as Host; const timeoutTime = 300; + const dummyAgentService = { + nodesAuthenticateConnection: new DummyNodesAuthenticateConnection({}), + } as AgentServerManifest; let dataDir: string; @@ -163,11 +195,19 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig: await testsUtils.createTLSConfig(keyRing.keyPair), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild(NodeConnectionManager.name), connectionConnectTimeoutTime: timeoutTime, }); await nodeConnectionManager.start({ - agentService: {} as AgentServerManifest, + agentService: dummyAgentService, host: localHost, }); taskManager = await TaskManager.createTaskManager({ @@ -443,18 +483,21 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig: await testsUtils.createTLSConfig(keyRing.keyPair), - logger: logger.getChild(NodeConnectionManager.name), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), + logger: logger.getChild(`${NodeConnectionManager.name}Local`), connectionConnectTimeoutTime: timeoutTime, }); - await nodeConnectionManager.start({ - agentService: {} as AgentServerManifest, - host: localHost, - }); taskManager = await TaskManager.createTaskManager({ db, logger: logger.getChild(TaskManager.name), }); - nodeManager = new NodeManager({ db, keyRing, @@ -466,6 +509,14 @@ describe(`${NodeManager.name}`, () => { logger: logger.getChild(NodeManager.name), }); await nodeManager.start(); + await nodeConnectionManager.start({ + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as AgentServerManifest, + host: localHost, + }); basePathPeer = path.join(dataDir, 'peer'); const keysPathPeer = path.join(basePathPeer, 'keys'); @@ -504,7 +555,15 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManagerPeer = new NodeConnectionManager({ keyRing: keyRingPeer, tlsConfig: await testsUtils.createTLSConfig(keyRingPeer.keyPair), - logger: logger.getChild(NodeConnectionManager.name), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), + logger: logger.getChild(`${NodeConnectionManager.name}Peer`), connectionConnectTimeoutTime: timeoutTime, }); taskManagerPeer = await TaskManager.createTaskManager({ @@ -532,6 +591,9 @@ describe(`${NodeManager.name}`, () => { nodeManager: nodeManagerPeer, acl: aclPeer, }), + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManagerPeer, + }), } as AgentServerManifest, host: localHost, }); @@ -834,17 +896,39 @@ describe(`${NodeManager.name}`, () => { expect(host).toBe(localHost); expect(port).toBe(nodeConnectionManagerPeer.port); }); - test('adds node to NodeGraph after successful connection', async () => { + test('adds node to NodeGraph after successful and authentication', async () => { await nodeConnectionManager.createConnection( [keyRingPeer.getNodeId()], localHost, nodeConnectionManagerPeer.port, ); // Wait for handler to add nodes to the graph - await utils.sleep(100); + await testsUtils.promFromEvent( + nodeConnectionManager, + nodesEvents.EventNodeConnectionManagerConnectionAuthenticated, + ); + // Give time for the node to be added + await utils.sleep(500); expect(await nodeGraph.nodesTotal()).toBe(1); expect(await nodeGraphPeer.nodesTotal()).toBe(1); }); + test('failure to authenticate will not add node to NodeGraph', async () => { + nodeConnectionManager.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseDeny, + ); + nodeConnectionManagerPeer.setAuthenticateNetworkReverseCallback( + nodesUtils.nodesAuthenticateConnectionReverseDeny, + ); + await nodeConnectionManager.createConnection( + [keyRingPeer.getNodeId()], + localHost, + nodeConnectionManagerPeer.port, + ); + // Give time for the node to be added + await utils.sleep(1000); + expect(await nodeGraph.nodesTotal()).toBe(0); + expect(await nodeGraphPeer.nodesTotal()).toBe(0); + }); }); describe('with 1 peer and mdns', () => { let basePath: string; @@ -907,11 +991,19 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig: await testsUtils.createTLSConfig(keyRing.keyPair), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild(NodeConnectionManager.name), connectionConnectTimeoutTime: timeoutTime, }); await nodeConnectionManager.start({ - agentService: {} as AgentServerManifest, + agentService: dummyAgentService, host: localHost, }); taskManager = await TaskManager.createTaskManager({ @@ -972,11 +1064,19 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManagerPeer = new NodeConnectionManager({ keyRing: keyRingPeer, tlsConfig: await testsUtils.createTLSConfig(keyRingPeer.keyPair), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild(NodeConnectionManager.name), connectionConnectTimeoutTime: timeoutTime, }); await nodeConnectionManagerPeer.start({ - agentService: {} as AgentServerManifest, + agentService: dummyAgentService, host: localHost, }); taskManagerPeer = await TaskManager.createTaskManager({ @@ -1154,11 +1254,23 @@ describe(`${NodeManager.name}`, () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig: await testsUtils.createTLSConfig(keyRing.keyPair), + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild(NodeConnectionManager.name), connectionConnectTimeoutTime: timeoutTime, }); await nodeConnectionManager.start({ - agentService: {} as AgentServerManifest, + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as AgentServerManifest, host: localHost, }); taskManager = await TaskManager.createTaskManager({ @@ -1225,6 +1337,9 @@ describe(`${NodeManager.name}`, () => { db, nodeGraph, }), + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), }) as AgentServerManifest, }, logger: logger.getChild(`${NodeConnectionManager.name}Peer${i}`), diff --git a/tests/nodes/agent/handlers/nodesClosestActiveConnectionsGet.test.ts b/tests/nodes/agent/handlers/nodesClosestActiveConnectionsGet.test.ts index e17428c979..1622f48f95 100644 --- a/tests/nodes/agent/handlers/nodesClosestActiveConnectionsGet.test.ts +++ b/tests/nodes/agent/handlers/nodesClosestActiveConnectionsGet.test.ts @@ -7,10 +7,11 @@ import type { NodeConnection } from '@/nodes'; import type { ActiveConnectionDataMessage } from '@/nodes/agent/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import * as keysUtils from '@/keys/utils'; +import NodeConnectionManager from '@/nodes/NodeConnectionManager'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import NodesClosestActiveConnectionsGet from '@/nodes/agent/handlers/NodesClosestActiveConnectionsGet'; import * as nodesUtils from '@/nodes/utils'; import * as testsUtils from '../../../utils'; -import NodeConnectionManager from '../../../../src/nodes/NodeConnectionManager'; describe('nodesClosestLocalNode', () => { const logger = new Logger('nodesClosestLocalNode test', LogLevel.WARN, [ @@ -43,6 +44,14 @@ describe('nodesClosestLocalNode', () => { connectionIdleTimeoutTimeMin: 1000, connectionIdleTimeoutTimeScale: 0, connectionConnectTimeoutTime: timeoutTime, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), }); const keyPairPeer1 = keysUtils.generateKeyPair(); @@ -57,15 +66,30 @@ describe('nodesClosestLocalNode', () => { logger: logger.getChild(`${NodeConnectionManager.name}Peer1`), tlsConfig: tlsConfigPeer1, connectionConnectTimeoutTime: timeoutTime, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), }); await Promise.all([ nodeConnectionManagerLocal.start({ - agentService: {} as AgentServerManifest, + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManagerLocal, + }), + } as AgentServerManifest, host: localHost, }), nodeConnectionManagerPeer1.start({ agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManagerPeer1, + }), nodesClosestActiveConnectionsGet: new NodesClosestActiveConnectionsGet({ nodeConnectionManager: nodeConnectionManagerPeer1, @@ -135,14 +159,20 @@ describe('nodesClosestLocalNode', () => { dummyConnections.set(nodeIdString, entry); } - const resultStream = - await connection.rpcClient.methods.nodesClosestActiveConnectionsGet({ - nodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId), - }); - const results: Array = []; - for await (const result of resultStream) { - results.push(result); - } + const results = await nodeConnectionManagerLocal.withConnF( + nodeIdPeer1, + async () => { + const resultStream = + await connection.rpcClient.methods.nodesClosestActiveConnectionsGet({ + nodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId), + }); + const results: Array = []; + for await (const result of resultStream) { + results.push(result); + } + return results; + }, + ); // @ts-ignore: restore existing connections nodeConnectionManagerPeer1.connections = existingConnections; diff --git a/tests/nodes/agent/handlers/notificationsSend.test.ts b/tests/nodes/agent/handlers/notificationsSend.test.ts index 8090cad140..cc34de832c 100644 --- a/tests/nodes/agent/handlers/notificationsSend.test.ts +++ b/tests/nodes/agent/handlers/notificationsSend.test.ts @@ -28,7 +28,7 @@ import * as keysUtils from '@/keys/utils'; import * as networkUtils from '@/network/utils'; import Sigchain from '@/sigchain/Sigchain'; import TaskManager from '@/tasks/TaskManager'; -import * as testUtils from '../../../utils/utils'; +import * as testsUtils from '../../../utils/utils'; import * as tlsTestsUtils from '../../../utils/tls'; import 'ix/add/asynciterable-operators/toarray'; @@ -127,6 +127,14 @@ describe('notificationsSend', () => { connectionConnectTimeoutTime: 2000, connectionIdleTimeoutTimeMin: 2000, connectionIdleTimeoutTimeScale: 0, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger: logger.getChild('NodeConnectionManager'), }); nodeManager = new NodeManager({ @@ -331,7 +339,7 @@ describe('notificationsSend', () => { isRead: false, }; const token = Token.fromPayload(notification1); - await testUtils.expectRemoteError( + await testsUtils.expectRemoteError( rpcClient.methods.notificationsSend({ signedNotificationEncoded: JSON.stringify( token.toJSON(), @@ -357,7 +365,7 @@ describe('notificationsSend', () => { notification2, senderKeyRing.keyPair, ); - await testUtils.expectRemoteError( + await testsUtils.expectRemoteError( rpcClient.methods.notificationsSend({ signedNotificationEncoded: signedNotification, }), @@ -392,7 +400,7 @@ describe('notificationsSend', () => { notification, senderKeyRing.keyPair, ); - await testUtils.expectRemoteError( + await testsUtils.expectRemoteError( rpcClient.methods.notificationsSend({ signedNotificationEncoded: signedNotification, }), diff --git a/tests/nodes/utils.ts b/tests/nodes/utils.ts index cb88f3ee26..73df775f71 100644 --- a/tests/nodes/utils.ts +++ b/tests/nodes/utils.ts @@ -298,6 +298,14 @@ async function nodeConnectionManagerFactory({ connectionHolePunchIntervalTime, rpcParserBufferSize, rpcCallTimeoutTime, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), }); await nodeConnectionManager.start({ diff --git a/tests/notifications/NotificationsManager.test.ts b/tests/notifications/NotificationsManager.test.ts index 363adee2df..7f833115cf 100644 --- a/tests/notifications/NotificationsManager.test.ts +++ b/tests/notifications/NotificationsManager.test.ts @@ -20,6 +20,7 @@ import KeyRing from '@/keys/KeyRing'; import NodeConnectionManager from '@/nodes/NodeConnectionManager'; import NodeGraph from '@/nodes/NodeGraph'; import NodeManager from '@/nodes/NodeManager'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import NotificationsManager from '@/notifications/NotificationsManager'; import * as nodesErrors from '@/nodes/errors'; import * as notificationsErrors from '@/notifications/errors'; @@ -29,7 +30,7 @@ import * as vaultsUtils from '@/vaults/utils'; import * as nodesUtils from '@/nodes/utils'; import * as keysUtils from '@/keys/utils'; import * as utils from '@/utils'; -import * as testUtils from '../utils'; +import * as testsUtils from '../utils'; import * as tlsTestsUtils from '../utils/tls'; import 'ix/add/asynciterable-operators/toarray'; @@ -120,6 +121,14 @@ describe('NotificationsManager', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger, }); nodeManager = new NodeManager({ @@ -135,13 +144,18 @@ describe('NotificationsManager', () => { await nodeManager.start(); await nodeConnectionManager.start({ host: localhost as Host, - agentService: {} as AgentServerManifest, + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as AgentServerManifest, }); // Set up node for receiving notifications receiver = await PolykeyAgent.createPolykeyAgent({ password: password, options: { nodePath: path.join(dataDir, 'receiver'), + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -407,7 +421,7 @@ describe('NotificationsManager', () => { }); await taskManager.startProcessing(); const { sendP } = await notificationsManager.sendNotification({ - nodeId: testUtils.generateRandomNodeId(), + nodeId: testsUtils.generateRandomNodeId(), data: { type: 'General', message: 'msg', diff --git a/tests/polykeyScratch.ts b/tests/polykeyScratch.ts index 3408f6966a..7e404e8b1a 100644 --- a/tests/polykeyScratch.ts +++ b/tests/polykeyScratch.ts @@ -6,6 +6,7 @@ import type { Hostname } from '../src/network/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; // Import { trackTimers } from './utils'; +import * as testsUtils from './utils'; import PolykeyAgent from '../src/PolykeyAgent'; import { sleep } from '../src/utils'; import { encodeNodeId } from '../src/ids'; @@ -25,6 +26,7 @@ async function main() { password, options: { nodePath, + network: testsUtils.testNetworkName, seedNodes, }, fresh: true, diff --git a/tests/utils/utils.ts b/tests/utils/utils.ts index 8a92f2b023..9916a2b1b3 100644 --- a/tests/utils/utils.ts +++ b/tests/utils/utils.ts @@ -146,6 +146,8 @@ function promFromEvents< return p; } +const testNetworkName = 'testNetwork'; + export { generateRandomNodeId, expectRemoteError, @@ -154,4 +156,5 @@ export { trackTimers, promFromEvent, promFromEvents, + testNetworkName, }; diff --git a/tests/vaults/VaultManager.test.ts b/tests/vaults/VaultManager.test.ts index 1385c14ea8..fb5645bbd8 100644 --- a/tests/vaults/VaultManager.test.ts +++ b/tests/vaults/VaultManager.test.ts @@ -24,6 +24,7 @@ import ACL from '@/acl/ACL'; import GestaltGraph from '@/gestalts/GestaltGraph'; import NodeManager from '@/nodes/NodeManager'; import NodeConnectionManager from '@/nodes/NodeConnectionManager'; +import NodesAuthenticateConnection from '@/nodes/agent/handlers/NodesAuthenticateConnection'; import KeyRing from '@/keys/KeyRing'; import PolykeyAgent from '@/PolykeyAgent'; import VaultManager from '@/vaults/VaultManager'; @@ -33,9 +34,11 @@ import { sleep } from '@/utils'; import * as keysUtils from '@/keys/utils'; import * as vaultsErrors from '@/vaults/errors'; import * as vaultsUtils from '@/vaults/utils'; +import * as nodesUtils from '@/nodes/utils'; import * as nodeTestUtils from '../nodes/utils'; import * as testUtils from '../utils'; import * as tlsTestsUtils from '../utils/tls'; +import * as testsUtils from '../utils'; describe('VaultManager', () => { const localhost = '127.0.0.1'; @@ -610,6 +613,7 @@ describe('VaultManager', () => { password, options: { nodePath: path.join(allDataDir, 'remoteKeynode1'), + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -625,6 +629,7 @@ describe('VaultManager', () => { password, options: { nodePath: path.join(allDataDir, 'remoteKeynode2'), + network: testsUtils.testNetworkName, agentServiceHost: localhost, clientServiceHost: localhost, keys: { @@ -704,12 +709,16 @@ describe('VaultManager', () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, tlsConfig, + authenticateNetworkForwardCallback: + nodesUtils.nodesAuthenticateConnectionForwardBasicPublicFactory( + testsUtils.testNetworkName, + ), + authenticateNetworkReverseCallback: + nodesUtils.nodesAuthenticateConnectionReverseBasicPublicFactory( + testsUtils.testNetworkName, + ), logger, }); - await nodeConnectionManager.start({ - host: localhost as Host, - agentService: {} as AgentServerManifest, - }); nodeManager = new NodeManager({ db, keyRing, @@ -721,6 +730,14 @@ describe('VaultManager', () => { logger, }); await nodeManager.start(); + await nodeConnectionManager.start({ + host: localhost as Host, + agentService: { + nodesAuthenticateConnection: new NodesAuthenticateConnection({ + nodeConnectionManager: nodeConnectionManager, + }), + } as AgentServerManifest, + }); await taskManager.startProcessing(); await nodeGraph.setNodeContactAddressData( remoteKeynode1Id,