diff --git a/.github/dictionary.txt b/.github/dictionary.txt index 71e5ed28d0..6b5fa94fea 100644 --- a/.github/dictionary.txt +++ b/.github/dictionary.txt @@ -14,3 +14,4 @@ additionals SECG Certicom RSAES +unuse diff --git a/packages/interface-compliance-tests/src/mocks/registrar.ts b/packages/interface-compliance-tests/src/mocks/registrar.ts index 1b2aa21bd8..6f6676f085 100644 --- a/packages/interface-compliance-tests/src/mocks/registrar.ts +++ b/packages/interface-compliance-tests/src/mocks/registrar.ts @@ -1,10 +1,11 @@ import { mergeOptions } from '@libp2p/utils/merge-options' -import type { Connection, PeerId, Topology, IncomingStreamData, StreamHandler, StreamHandlerOptions, StreamHandlerRecord } from '@libp2p/interface' +import type { Connection, PeerId, Topology, IncomingStreamData, StreamHandler, StreamHandlerOptions, StreamHandlerRecord, StreamMiddleware } from '@libp2p/interface' import type { Registrar } from '@libp2p/interface-internal' export class MockRegistrar implements Registrar { private readonly topologies = new Map>() private readonly handlers = new Map() + private readonly middleware = new Map() getProtocols (): string[] { return Array.from(this.handlers.keys()).sort() @@ -69,6 +70,18 @@ export class MockRegistrar implements Registrar { getTopologies (protocol: string): Topology[] { return (this.topologies.get(protocol) ?? []).map(t => t.topology) } + + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } } export function mockRegistrar (): Registrar { diff --git a/packages/interface-internal/src/registrar.ts b/packages/interface-internal/src/registrar.ts index dc96e3877f..69332a78ae 100644 --- a/packages/interface-internal/src/registrar.ts +++ b/packages/interface-internal/src/registrar.ts @@ -1,4 +1,4 @@ -import type { StreamHandler, StreamHandlerOptions, StreamHandlerRecord, Topology, IncomingStreamData } from '@libp2p/interface' +import type { StreamHandler, StreamHandlerOptions, StreamHandlerRecord, Topology, IncomingStreamData, StreamMiddleware } from '@libp2p/interface' import type { AbortOptions } from '@multiformats/multiaddr' export type { @@ -69,6 +69,30 @@ export interface Registrar { */ getHandler(protocol: string): StreamHandlerRecord + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + use(protocol: string, middleware: StreamMiddleware[]): void + + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + unuse(protocol: string): void + + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + getMiddleware(protocol: string): StreamMiddleware[] + /** * Register a topology handler for a protocol - the topology will be * invoked when peers are discovered on the network that support the diff --git a/packages/interface/src/index.ts b/packages/interface/src/index.ts index def07d8b36..e5026e7341 100644 --- a/packages/interface/src/index.ts +++ b/packages/interface/src/index.ts @@ -23,7 +23,7 @@ import type { PeerInfo } from './peer-info.js' import type { PeerRouting } from './peer-routing.js' import type { Address, Peer, PeerStore } from './peer-store.js' import type { Startable } from './startable.js' -import type { StreamHandler, StreamHandlerOptions } from './stream-handler.js' +import type { StreamHandler, StreamHandlerOptions, StreamMiddleware } from './stream-handler.js' import type { Topology } from './topology.js' import type { Listener, OutboundConnectionUpgradeEvents } from './transport.js' import type { Multiaddr } from '@multiformats/multiaddr' @@ -720,6 +720,33 @@ export interface Libp2p extends Startable, Ty */ unregister(id: string): void + /** + * Registers one or more middleware implementations that will be invoked for + * incoming and outgoing protocol streams that match the passed protocol. + * + * @example + * + * ```TypeScript + * libp2p.use('/my/protocol/1.0.0', (stream, connection, next) => { + * // do something with stream and/or connection + * next(stream, connection) + * }) + * ``` + */ + use (protocol: string, middleware: StreamMiddleware | StreamMiddleware[]): void + + /** + * Deregisters all middleware for the passed protocol. + * + * @example + * + * ```TypeScript + * libp2p.unuse('/my/protocol/1.0.0') + * // any previously registered middleware will no longer be invoked + * ``` + */ + unuse (protocol: string): void + /** * Returns the public key for the passed PeerId. If the PeerId is of the 'RSA' * type this may mean searching the routing if the peer's key is not present diff --git a/packages/interface/src/stream-handler.ts b/packages/interface/src/stream-handler.ts index 39a087a5f9..219b513565 100644 --- a/packages/interface/src/stream-handler.ts +++ b/packages/interface/src/stream-handler.ts @@ -20,6 +20,13 @@ export interface StreamHandler { (data: IncomingStreamData): void } +/** + * Stream middleware allows accessing stream data outside of the stream handler + */ +export interface StreamMiddleware { + (stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void): void +} + export interface StreamHandlerOptions extends AbortOptions { /** * How many incoming streams can be open for this protocol at the same time on each connection @@ -46,6 +53,11 @@ export interface StreamHandlerOptions extends AbortOptions { * protocol(s), the existing handler will be discarded. */ force?: true + + /** + * Middleware allows accessing stream data outside of the stream handler + */ + middleware?: StreamMiddleware[] } export interface StreamHandlerRecord { diff --git a/packages/libp2p/src/libp2p.ts b/packages/libp2p/src/libp2p.ts index 5dad2a0520..a4e0235b4b 100644 --- a/packages/libp2p/src/libp2p.ts +++ b/packages/libp2p/src/libp2p.ts @@ -24,7 +24,7 @@ import { userAgent } from './user-agent.js' import * as pkg from './version.js' import type { Components } from './components.js' import type { Libp2p as Libp2pInterface, Libp2pInit } from './index.js' -import type { PeerRouting, ContentRouting, Libp2pEvents, PendingDial, ServiceMap, AbortOptions, ComponentLogger, Logger, Connection, NewStreamOptions, Stream, Metrics, PeerId, PeerInfo, PeerStore, Topology, Libp2pStatus, IsDialableOptions, DialOptions, PublicKey, Ed25519PeerId, Secp256k1PeerId, RSAPublicKey, RSAPeerId, URLPeerId, Ed25519PublicKey, Secp256k1PublicKey, StreamHandler, StreamHandlerOptions } from '@libp2p/interface' +import type { PeerRouting, ContentRouting, Libp2pEvents, PendingDial, ServiceMap, AbortOptions, ComponentLogger, Logger, Connection, NewStreamOptions, Stream, Metrics, PeerId, PeerInfo, PeerStore, Topology, Libp2pStatus, IsDialableOptions, DialOptions, PublicKey, Ed25519PeerId, Secp256k1PeerId, RSAPublicKey, RSAPeerId, URLPeerId, Ed25519PublicKey, Secp256k1PublicKey, StreamHandler, StreamHandlerOptions, StreamMiddleware } from '@libp2p/interface' import type { Multiaddr } from '@multiformats/multiaddr' export class Libp2p extends TypedEventEmitter implements Libp2pInterface { @@ -402,6 +402,14 @@ export class Libp2p extends TypedEventEmitter this.components.registrar.unregister(id) } + use (protocol: string, middleware: StreamMiddleware | StreamMiddleware[]): void { + this.components.registrar.use(protocol, Array.isArray(middleware) ? middleware : [middleware]) + } + + unuse (protocol: string): void { + this.components.registrar.unuse(protocol) + } + async isDialable (multiaddr: Multiaddr, options: IsDialableOptions = {}): Promise { return this.components.connectionManager.isDialable(multiaddr, options) } diff --git a/packages/libp2p/src/registrar.ts b/packages/libp2p/src/registrar.ts index c4f10cb111..d628ee047b 100644 --- a/packages/libp2p/src/registrar.ts +++ b/packages/libp2p/src/registrar.ts @@ -2,7 +2,7 @@ import { InvalidParametersError } from '@libp2p/interface' import { mergeOptions } from '@libp2p/utils/merge-options' import { trackedMap } from '@libp2p/utils/tracked-map' import * as errorsJs from './errors.js' -import type { IdentifyResult, Libp2pEvents, Logger, PeerUpdate, PeerId, PeerStore, Topology, StreamHandler, StreamHandlerRecord, StreamHandlerOptions, AbortOptions, Metrics } from '@libp2p/interface' +import type { IdentifyResult, Libp2pEvents, Logger, PeerUpdate, PeerId, PeerStore, Topology, StreamHandler, StreamHandlerRecord, StreamHandlerOptions, AbortOptions, Metrics, StreamMiddleware } from '@libp2p/interface' import type { Registrar as RegistrarInterface } from '@libp2p/interface-internal' import type { ComponentLogger } from '@libp2p/logger' import type { TypedEventTarget } from 'main-event' @@ -26,10 +26,12 @@ export class Registrar implements RegistrarInterface { private readonly topologies: Map> private readonly handlers: Map private readonly components: RegistrarComponents + private readonly middleware: Map constructor (components: RegistrarComponents) { this.components = components this.log = components.logger.forComponent('libp2p:registrar') + this.middleware = new Map() this.topologies = new Map() components.metrics?.registerMetricGroup('libp2p_registrar_topologies', { calculate: () => { @@ -165,6 +167,18 @@ export class Registrar implements RegistrarInterface { } } + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } + /** * Remove a disconnected peer from the record */ diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index 833e6e6872..5eda4d67af 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -395,7 +395,7 @@ export class Upgrader implements UpgraderInterface { let muxer: StreamMuxer | undefined let newStream: ((multicodecs: string[], options?: AbortOptions) => Promise) | undefined - let connection: Connection // eslint-disable-line prefer-const + let connection: Connection if (muxerFactory != null) { // Create the muxer @@ -488,7 +488,7 @@ export class Upgrader implements UpgraderInterface { } connection.log.trace('starting new stream for protocols %s', protocols) - const muxedStream = await muxer.newStream() + let muxedStream = await muxer.newStream() connection.log.trace('started new stream %s for protocols %s', muxedStream.id, protocols) try { @@ -556,6 +556,23 @@ export class Upgrader implements UpgraderInterface { this.components.metrics?.trackProtocolStream(muxedStream, connection) + const middleware = this.components.registrar.getMiddleware(protocol) + + middleware.push((stream, connection, next) => { + next(stream, connection) + }) + + let i = 0 + + while (i < middleware.length) { + // eslint-disable-next-line no-loop-func + middleware[i](muxedStream, connection, (s, c) => { + muxedStream = s + connection = c + i++ + }) + } + return muxedStream } catch (err: any) { connection.log.error('could not create new outbound stream on connection %s %a for protocols %s - %e', direction === 'inbound' ? 'from' : 'to', opts.maConn.remoteAddr, protocols, err) @@ -652,14 +669,30 @@ export class Upgrader implements UpgraderInterface { * Routes incoming streams to the correct handler */ _onStream (opts: OnStreamOptions): void { - const { connection, stream, protocol } = opts + let { connection, stream, protocol } = opts const { handler, options } = this.components.registrar.getHandler(protocol) if (connection.limits != null && options.runOnLimitedConnection !== true) { throw new LimitedConnectionError('Cannot open protocol stream on limited connection') } - handler({ connection, stream }) + const middleware = this.components.registrar.getMiddleware(protocol) + + middleware.push((stream, connection, next) => { + handler({ connection, stream }) + next(stream, connection) + }) + + let i = 0 + + while (i < middleware.length) { + // eslint-disable-next-line no-loop-func + middleware[i](stream, connection, (s, c) => { + stream = s + connection = c + i++ + }) + } } /** diff --git a/packages/libp2p/test/upgrading/upgrader.spec.ts b/packages/libp2p/test/upgrading/upgrader.spec.ts index 5fd14e8cc8..09289e3e67 100644 --- a/packages/libp2p/test/upgrading/upgrader.spec.ts +++ b/packages/libp2p/test/upgrading/upgrader.spec.ts @@ -25,6 +25,7 @@ describe('upgrader', () => { let init: UpgraderInit const encrypterProtocol = '/test-encrypter' const muxerProtocol = '/test-muxer' + const streamProtocol = '/test/protocol' let remotePeer: PeerId let remoteAddr: Multiaddr let maConn: MultiaddrConnection @@ -36,6 +37,38 @@ describe('upgrader', () => { async secureOutbound (): Promise { throw new Error('Boom') } } + function stubMuxerFactory (protocol: string = streamProtocol, onInit?: (init: StreamMuxerInit) => void): StreamMuxerFactory { + return stubInterface({ + protocol: muxerProtocol, + createStreamMuxer: (init = {}) => { + const streamMuxer = stubInterface({ + protocol: muxerProtocol, + sink: async (source) => drain(source), + source: (async function * () {})(), + streams: [], + newStream: () => { + const outgoingStream = stubInterface({ + id: 'stream-id', + log: logger('test-stream'), + direction: 'outbound', + sink: async (source) => drain(source), + source: map((async function * () { + yield '/multistream/1.0.0\n' + yield `${protocol}\n` + })(), str => encode.single(uint8ArrayFromString(str))) + }) + + streamMuxer?.streams.push(outgoingStream) + return outgoingStream + } + }) + + onInit?.(init) + return streamMuxer + } + }) + } + beforeEach(async () => { remotePeer = peerIdFromPrivateKey(await generateKeyPair('Ed25519')) remoteAddr = multiaddr(`/ip4/123.123.123.123/tcp/1234/p2p/${remotePeer}`) @@ -435,7 +468,8 @@ describe('upgrader', () => { }, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -503,7 +537,8 @@ describe('upgrader', () => { }, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -566,7 +601,8 @@ describe('upgrader', () => { options: {}, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -625,6 +661,115 @@ describe('upgrader', () => { .with.property('name', 'TooManyOutboundProtocolStreamsError') }) + it('should support outgoing stream middleware', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory() + ] + }) + + const connectionPromise = pEvent<'connection:open', CustomEvent>(components.events, 'connection:open') + + await upgrader.upgradeInbound(maConn, { + signal: AbortSignal.timeout(5_000) + }) + + const event = await connectionPromise + const conn = event.detail + + expect(conn.streams).to.have.lengthOf(0) + + await conn.newStream(streamProtocol) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.true() + }) + + it('should support incoming stream middleware', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const streamMuxerInitPromise = Promise.withResolvers() + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory(muxerProtocol, (init) => { + streamMuxerInitPromise.resolve(init) + }) + ] + }) + + const conn = await upgrader.upgradeOutbound(maConn, { + signal: AbortSignal.timeout(5_000) + }) + + const { onIncomingStream } = await streamMuxerInitPromise.promise + + expect(conn.streams).to.have.lengthOf(0) + + const incomingStream = stubInterface({ + id: 'stream-id', + log: logger('test-stream'), + direction: 'outbound', + sink: async (source) => drain(source), + source: map((async function * () { + yield '/multistream/1.0.0\n' + yield `${streamProtocol}\n` + })(), str => encode.single(uint8ArrayFromString(str))) + }) + + onIncomingStream?.(incomingStream) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.true() + }) + describe('early muxer selection', () => { let earlyMuxerProtocol: string let streamMuxerFactory: StreamMuxerFactory diff --git a/packages/pubsub/test/utils/index.ts b/packages/pubsub/test/utils/index.ts index 453981c567..44c7be5747 100644 --- a/packages/pubsub/test/utils/index.ts +++ b/packages/pubsub/test/utils/index.ts @@ -1,7 +1,7 @@ import { duplexPair } from 'it-pair/duplex' import { PubSubBaseProtocol } from '../../src/index.js' import { RPC } from '../message/rpc.js' -import type { Connection, PeerId, PublishResult, PubSubRPC, PubSubRPCMessage, Topology, IncomingStreamData, StreamHandler, StreamHandlerRecord } from '@libp2p/interface' +import type { Connection, PeerId, PublishResult, PubSubRPC, PubSubRPCMessage, Topology, IncomingStreamData, StreamHandler, StreamHandlerRecord, StreamMiddleware } from '@libp2p/interface' import type { Registrar } from '@libp2p/interface-internal' export class PubsubImplementation extends PubSubBaseProtocol { @@ -31,6 +31,7 @@ export class PubsubImplementation extends PubSubBaseProtocol { export class MockRegistrar implements Registrar { private readonly topologies = new Map() private readonly handlers = new Map() + private readonly middleware = new Map() getProtocols (): string[] { const protocols = new Set() @@ -114,6 +115,18 @@ export class MockRegistrar implements Registrar { throw new Error(`No topologies registered for protocol ${protocol}`) } + + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } } export const ConnectionPair = (): [Connection, Connection] => {