diff --git a/src/client.ts b/src/client.ts index bb972d88..33c491e4 100644 --- a/src/client.ts +++ b/src/client.ts @@ -704,7 +704,10 @@ export class Client { } static async connect(params: ClientParams, logger?: Logger): Promise { - return new Client(logger ?? new NullLogger(), params).start() + return new Client(logger ?? new NullLogger(), { + ...params, + vhost: getVhostOrDefault(params.vhost), + }).start() } } @@ -837,3 +840,5 @@ const extractConsumerId = (extendedConsumerId: string) => { const extractPublisherId = (extendedPublisherId: string) => { return parseInt(extendedPublisherId.split("@").shift() ?? "0") } + +const getVhostOrDefault = (vhost: string) => vhost ?? "/" diff --git a/src/connection.ts b/src/connection.ts index 19180f46..4b8e3246 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -152,7 +152,8 @@ export class Connection { mechanism: this.params.mechanism ?? "PLAIN", }) const { heartbeat } = await this.tune(this.params.heartbeat ?? 0) - await this.open({ virtualHost: this.params.vhost }) + const connectionOpened = await this.open({ virtualHost: this.params.vhost }) + if (!connectionOpened.ok) return rej(connectionOpened.error) if (!this.heartbeat.started) this.heartbeat.start(heartbeat) await this.exchangeCommandVersions() this.setupCompleted = true @@ -461,12 +462,25 @@ export class Connection { private async open(params: { virtualHost: string }) { this.logger.debug(`Open ...`) + if (this.virtualHostIsNotValid(params.virtualHost)) { + const errorMessage = `[ERROR]: VirtualHost '${params.virtualHost}' is not valid` + this.logger.error(errorMessage) + return { ok: false, error: new Error(errorMessage) } + } const res = await this.sendAndWait(new OpenRequest(params)) this.logger.debug(`Open response: ${res.ok} - '${inspect(res.properties)}'`) const advertisedHost = res.properties["advertised_host"] ?? "" const advertisedPort = parseInt(res.properties["advertised_port"] ?? "5552") this.serverEndpoint = { host: advertisedHost, port: advertisedPort } - return res + return { ok: true, response: res } + } + + private virtualHostIsNotValid(virtualHost: string) { + if (!virtualHost || virtualHost.split("/").length !== 2) { + return true + } + + return false } private async tune(heartbeatInterval: number): Promise<{ heartbeat: number }> { diff --git a/test/e2e/connect.test.ts b/test/e2e/connect.test.ts index 14232ef7..d44e3151 100644 --- a/test/e2e/connect.test.ts +++ b/test/e2e/connect.test.ts @@ -2,7 +2,7 @@ import { expect } from "chai" import { Client, connect } from "../../src" import { createClient } from "../support/fake_data" import { Rabbit } from "../support/rabbit" -import { eventually, username, password, getTestNodesFromEnv } from "../support/util" +import { eventually, username, password, getTestNodesFromEnv, expectToThrowAsync } from "../support/util" import { Version } from "../../src/versions" import { randomUUID } from "node:crypto" import { readFile } from "node:fs/promises" @@ -46,6 +46,26 @@ describe("connect", () => { }, 5000) }).timeout(10000) + it("throw exception if vhost is not valid", async () => { + const [firstNode] = getTestNodesFromEnv() + + await expectToThrowAsync( + async () => { + client = await connect({ + hostname: firstNode.host, + port: firstNode.port, + username, + password, + vhost: "", + frameMax: 0, + heartbeat: 0, + }) + }, + Error, + `[ERROR]: VirtualHost '' is not valid` + ) + }).timeout(10000) + it("using EXTERNAL auth", async () => { client = await createTlsClient()