diff --git a/packages/interface/src/errors.ts b/packages/interface/src/errors.ts index b5ff2bd497..3df9fbc773 100644 --- a/packages/interface/src/errors.ts +++ b/packages/interface/src/errors.ts @@ -228,6 +228,18 @@ export class UnsupportedProtocolError extends Error { } } +/** + * Thrown when a protocol is not negotiated properly + */ +export class ProtocolNegotiationError extends Error { + static name = 'ProtocolNegotiationError' + + constructor (message = 'Protocol negotiation error', options?: ErrorOptions) { + super(message, options) + this.name = 'ProtocolNegotiationError' + } +} + /** * An invalid or malformed message was encountered during a protocol exchange */ diff --git a/packages/libp2p/test/upgrading/upgrader.spec.ts b/packages/libp2p/test/upgrading/upgrader.spec.ts index 98ba20e73c..2a6bf9d7fc 100644 --- a/packages/libp2p/test/upgrading/upgrader.spec.ts +++ b/packages/libp2p/test/upgrading/upgrader.spec.ts @@ -193,7 +193,7 @@ describe('upgrader', () => { await expect(upgrader.upgradeOutbound(maConn, { signal: AbortSignal.timeout(100) })).to.eventually.be.rejected - .with.property('message').that.include('aborted') + .with.property('message').that.include('protocol negotiation failed') }) it('should not abort if inbound upgrade is successful', async () => { diff --git a/packages/multistream-select/src/select.ts b/packages/multistream-select/src/select.ts index d877fd4a28..8d4fc9cc4c 100644 --- a/packages/multistream-select/src/select.ts +++ b/packages/multistream-select/src/select.ts @@ -1,4 +1,4 @@ -import { UnsupportedProtocolError } from '@libp2p/interface' +import { ProtocolNegotiationError, UnsupportedProtocolError } from '@libp2p/interface' import { lpStream } from 'it-length-prefixed-stream' import pDefer from 'p-defer' import { raceSignal } from 'race-signal' @@ -79,38 +79,43 @@ export async function select (stream: Stream, prot throw new Error('At least one protocol must be specified') } - options.log.trace('select: write ["%s", "%s"]', PROTOCOL_ID, protocol) - const p1 = uint8ArrayFromString(`${PROTOCOL_ID}\n`) - const p2 = uint8ArrayFromString(`${protocol}\n`) - await multistream.writeAll(lp, [p1, p2], options) + try { + options.log.trace('select: write ["%s", "%s"]', PROTOCOL_ID, protocol) + const p1 = uint8ArrayFromString(`${PROTOCOL_ID}\n`) + const p2 = uint8ArrayFromString(`${protocol}\n`) + await multistream.writeAll(lp, [p1, p2], options) - options.log.trace('select: reading multistream-select header') - let response = await multistream.readString(lp, options) - options.log.trace('select: read "%s"', response) - - // Read the protocol response if we got the protocolId in return - if (response === PROTOCOL_ID) { - options.log.trace('select: reading protocol response') - response = await multistream.readString(lp, options) + options.log.trace('select: reading multistream-select header') + let response = await multistream.readString(lp, options) options.log.trace('select: read "%s"', response) - } - // We're done - if (response === protocol) { - return { stream: lp.unwrap(), protocol } - } - - // We haven't gotten a valid ack, try the other protocols - for (const protocol of protocols) { - options.log.trace('select: write "%s"', protocol) - await multistream.write(lp, uint8ArrayFromString(`${protocol}\n`), options) - options.log.trace('select: reading protocol response') - const response = await multistream.readString(lp, options) - options.log.trace('select: read "%s" for "%s"', response, protocol) + // Read the protocol response if we got the protocolId in return + if (response === PROTOCOL_ID) { + options.log.trace('select: reading protocol response') + response = await multistream.readString(lp, options) + options.log.trace('select: read "%s"', response) + } + // We're done if (response === protocol) { return { stream: lp.unwrap(), protocol } } + + // We haven't gotten a valid ack, try the other protocols + for (const protocol of protocols) { + options.log.trace('select: write "%s"', protocol) + await multistream.write(lp, uint8ArrayFromString(`${protocol}\n`), options) + options.log.trace('select: reading protocol response') + const response = await multistream.readString(lp, options) + options.log.trace('select: read "%s" for "%s"', response, protocol) + + if (response === protocol) { + return { stream: lp.unwrap(), protocol } + } + } + } catch (err) { + options.log.error('select: error negotiating protocol', err) + throw new ProtocolNegotiationError('protocol negotiation failed', { cause: err }) } throw new UnsupportedProtocolError('protocol selection failed')