diff --git a/package.json b/package.json index 39455ef1161..6b39f9c4b90 100644 --- a/package.json +++ b/package.json @@ -86,6 +86,7 @@ "@types/sinon-chai": "3.2.12", "@types/tmp": "0.2.6", "@types/trusted-types": "2.0.7", + "@types/ws": "8.18.1", "@types/yargs": "17.0.33", "@typescript-eslint/eslint-plugin": "7.18.0", "@typescript-eslint/eslint-plugin-tslint": "7.0.2", @@ -158,6 +159,7 @@ "typescript": "5.5.4", "watch": "1.0.2", "webpack": "5.98.0", + "ws": "8.18.3", "yargs": "17.7.2" } } diff --git a/packages/ai/package.json b/packages/ai/package.json index 97186afb1e1..d2ca2142b67 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -39,6 +39,7 @@ "test:ci": "yarn testsetup && node ../../scripts/run_tests_in_ci.js -s test", "test:skip-clone": "karma start", "test:browser": "yarn testsetup && karma start", + "test:node": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha --require ts-node/register --require src/index.node.ts 'src/**/*.test.ts' --config ../../config/mocharc.node.js", "test:integration": "karma start --integration", "api-report": "api-extractor run --local --verbose", "typings:public": "node ../../scripts/build/use_typings.js ./dist/ai-public.d.ts", diff --git a/packages/ai/rollup.config.js b/packages/ai/rollup.config.js index 7ebbff4f2f5..016698824fb 100644 --- a/packages/ai/rollup.config.js +++ b/packages/ai/rollup.config.js @@ -15,6 +15,7 @@ * limitations under the License. */ +import alias from '@rollup/plugin-alias'; import json from '@rollup/plugin-json'; import typescriptPlugin from 'rollup-plugin-typescript2'; import replace from 'rollup-plugin-replace'; @@ -23,6 +24,7 @@ import pkg from './package.json'; import tsconfig from './tsconfig.json'; import { generateBuildTargetReplaceConfig } from '../../scripts/build/rollup_replace_build_target'; import { emitModulePackageFile } from '../../scripts/build/rollup_emit_module_package_file'; +import { generateAliasConfig } from '../../scripts/build/rollup_generate_alias_config'; const deps = Object.keys( Object.assign({}, pkg.peerDependencies, pkg.dependencies) @@ -55,14 +57,16 @@ const browserBuilds = [ sourcemap: true }, plugins: [ + alias(generateAliasConfig('browser')), ...buildPlugins, replace({ ...generateBuildTargetReplaceConfig('esm', 2020), - __PACKAGE_VERSION__: pkg.version + '__PACKAGE_VERSION__': pkg.version }), emitModulePackageFile() ], - external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`)) + external: id => + id === 'ws' || deps.some(dep => id === dep || id.startsWith(`${dep}/`)) }, { input: 'src/index.ts', @@ -72,13 +76,15 @@ const browserBuilds = [ sourcemap: true }, plugins: [ + alias(generateAliasConfig('browser')), ...buildPlugins, replace({ ...generateBuildTargetReplaceConfig('cjs', 2020), - __PACKAGE_VERSION__: pkg.version + '__PACKAGE_VERSION__': pkg.version }) ], - external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`)) + external: id => + id === 'ws' || deps.some(dep => id === dep || id.startsWith(`${dep}/`)) } ]; @@ -91,12 +97,14 @@ const nodeBuilds = [ sourcemap: true }, plugins: [ + alias(generateAliasConfig('node')), ...buildPlugins, replace({ ...generateBuildTargetReplaceConfig('esm', 2020) }) ], - external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`)) + external: id => + id === 'ws' || deps.some(dep => id === dep || id.startsWith(`${dep}/`)) }, { input: 'src/index.node.ts', @@ -106,12 +114,14 @@ const nodeBuilds = [ sourcemap: true }, plugins: [ + alias(generateAliasConfig('node')), ...buildPlugins, replace({ ...generateBuildTargetReplaceConfig('cjs', 2020) }) ], - external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`)) + external: id => + id === 'ws' || deps.some(dep => id === dep || id.startsWith(`${dep}/`)) } ]; diff --git a/packages/ai/src/platform/browser/websocket.test.ts b/packages/ai/src/platform/browser/websocket.test.ts new file mode 100644 index 00000000000..5f211363576 --- /dev/null +++ b/packages/ai/src/platform/browser/websocket.test.ts @@ -0,0 +1,273 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect, use } from 'chai'; +import sinon, { SinonFakeTimers, SinonStub } from 'sinon'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { isBrowser } from '@firebase/util'; +import { BrowserWebSocketHandler } from './websocket'; +import { AIError } from '../../errors'; + +use(sinonChai); +use(chaiAsPromised); + +class MockBrowserWebSocket { + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + readyState: number = MockBrowserWebSocket.CONNECTING; + sentMessages: Array = []; + url: string; + private listeners: Map> = new Map(); + + constructor(url: string) { + this.url = url; + } + + send(data: string | ArrayBuffer): void { + if (this.readyState !== MockBrowserWebSocket.OPEN) { + throw new Error('WebSocket is not in OPEN state'); + } + this.sentMessages.push(data); + } + + close(): void { + if ( + this.readyState === MockBrowserWebSocket.CLOSED || + this.readyState === MockBrowserWebSocket.CLOSING + ) { + return; + } + this.readyState = MockBrowserWebSocket.CLOSING; + setTimeout(() => { + this.readyState = MockBrowserWebSocket.CLOSED; + this.dispatchEvent(new Event('close')); + }, 10); + } + + addEventListener(type: string, listener: EventListener): void { + if (!this.listeners.has(type)) { + this.listeners.set(type, new Set()); + } + this.listeners.get(type)!.add(listener); + } + + removeEventListener(type: string, listener: EventListener): void { + this.listeners.get(type)?.delete(listener); + } + + dispatchEvent(event: Event): void { + this.listeners.get(event.type)?.forEach(listener => listener(event)); + } + + triggerOpen(): void { + this.readyState = MockBrowserWebSocket.OPEN; + this.dispatchEvent(new Event('open')); + } + + triggerMessage(data: any): void { + this.dispatchEvent(new MessageEvent('message', { data })); + } + + triggerError(): void { + this.dispatchEvent(new Event('error')); + } +} + +describe('BrowserWebSocketHandler', () => { + let handler: BrowserWebSocketHandler; + let mockWebSocket: MockBrowserWebSocket; + let clock: SinonFakeTimers; + let webSocketStub: SinonStub; + + // Only run these tests in a browser environment + if (!isBrowser()) { + return; + } + + beforeEach(() => { + webSocketStub = sinon.stub(window, 'WebSocket').callsFake((url: string) => { + mockWebSocket = new MockBrowserWebSocket(url); + return mockWebSocket as any; + }); + clock = sinon.useFakeTimers(); + handler = new BrowserWebSocketHandler(); + }); + + afterEach(() => { + sinon.restore(); + clock.restore(); + }); + + describe('connect()', () => { + it('should resolve on open event', async () => { + const connectPromise = handler.connect('ws://test-url'); + expect(webSocketStub).to.have.been.calledWith('ws://test-url'); + + await clock.tickAsync(1); + mockWebSocket.triggerOpen(); + + await expect(connectPromise).to.be.fulfilled; + }); + + it('should reject on error event', async () => { + const connectPromise = handler.connect('ws://test-url'); + await clock.tickAsync(1); + mockWebSocket.triggerError(); + + await expect(connectPromise).to.be.rejectedWith( + AIError, + /Failed to establish WebSocket connection/ + ); + }); + }); + + describe('listen()', () => { + beforeEach(async () => { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + }); + + it('should yield multiple messages as they arrive', async () => { + const generator = handler.listen(); + + const received: unknown[] = []; + const listenPromise = (async () => { + for await (const msg of generator) { + received.push(msg); + } + })(); + + // Use tickAsync to allow the consumer to start listening + await clock.tickAsync(1); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 1 })])); + + await clock.tickAsync(10); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 2 })])); + + await clock.tickAsync(5); + mockWebSocket.close(); + await clock.runAllAsync(); // Let timers finish + + await listenPromise; // Wait for the consumer to finish + + expect(received).to.deep.equal([ + { + foo: 1 + }, + { + foo: 2 + } + ]); + }); + + it('should buffer messages that arrive before the consumer calls .next()', async () => { + const generator = handler.listen(); + + // Create a promise that will consume the generator in a separate async context + const received: unknown[] = []; + const consumptionPromise = (async () => { + for await (const message of generator) { + received.push(message); + } + })(); + + await clock.tickAsync(1); + + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 1 })])); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 2 })])); + + await clock.tickAsync(1); + mockWebSocket.close(); + await clock.runAllAsync(); + + await consumptionPromise; + + expect(received).to.deep.equal([ + { + foo: 1 + }, + { + foo: 2 + } + ]); + }); + }); + + describe('close()', () => { + it('should be idempotent and not throw if called multiple times', async () => { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + const closePromise1 = handler.close(); + await clock.runAllAsync(); + await closePromise1; + + await expect(handler.close()).to.be.fulfilled; + }); + + it('should wait for the onclose event before resolving', async () => { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + let closed = false; + const closePromise = handler.close().then(() => { + closed = true; + }); + + // The promise should not have resolved yet + await clock.tickAsync(5); + expect(closed).to.be.false; + + // Now, let the mock's setTimeout for closing run, which triggers onclose + await clock.tickAsync(10); + + await expect(closePromise).to.be.fulfilled; + expect(closed).to.be.true; + }); + }); + + describe('Interaction between listen() and close()', () => { + it('should allow close() to take precedence and resolve correctly, while also terminating the listener', async () => { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + const generator = handler.listen(); + const listenPromise = (async () => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of generator) { + } + })(); + + const closePromise = handler.close(); + + await clock.runAllAsync(); + + await expect(closePromise).to.be.fulfilled; + await expect(listenPromise).to.be.fulfilled; + + expect(mockWebSocket.readyState).to.equal(MockBrowserWebSocket.CLOSED); + }); + }); +}); diff --git a/packages/ai/src/platform/browser/websocket.ts b/packages/ai/src/platform/browser/websocket.ts new file mode 100644 index 00000000000..16e5bd843a9 --- /dev/null +++ b/packages/ai/src/platform/browser/websocket.ts @@ -0,0 +1,188 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIError } from '../../errors'; +import { AIErrorCode } from '../../types'; +import { WebSocketHandler } from '../websocket'; + +export function createWebSocketHandler(): WebSocketHandler { + if (typeof WebSocket === 'undefined') { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'The WebSocket API is not available in this browser-like environment. ' + + 'The "Live" feature is not supported here. It is supported in ' + + 'modern browser windows, Web Workers with WebSocket support, and Node >= 22.' + ); + } + + return new BrowserWebSocketHandler(); +} + +/** + * A WebSocketHandler implementation for the browser environment. + * It uses the native `WebSocket`. + * + * @internal + */ +export class BrowserWebSocketHandler implements WebSocketHandler { + private ws?: WebSocket; + + connect(url: string): Promise { + return new Promise((resolve, reject) => { + this.ws = new WebSocket(url); + this.ws.addEventListener('open', () => resolve(), { once: true }); + this.ws.addEventListener( + 'error', + () => + reject( + new AIError( + AIErrorCode.FETCH_ERROR, + 'Failed to establish WebSocket connection' + ) + ), + { once: true } + ); + }); + } + + send(data: string | ArrayBuffer): void { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new AIError(AIErrorCode.REQUEST_ERROR, 'WebSocket is not open.'); + } + this.ws.send(data); + } + + async *listen(): AsyncGenerator { + if (!this.ws) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'WebSocket is not connected.' + ); + } + + const messageQueue: unknown[] = []; + const errorQueue: Error[] = []; + let resolvePromise: (() => void) | null = null; + let isClosed = false; + + const messageListener = async (event: MessageEvent): Promise => { + let data: string; + if (event.data instanceof Blob) { + data = await event.data.text(); + } else if (typeof event.data === 'string') { + data = event.data; + } else { + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Failed to parse WebSocket response. Expected data to be a Blob or string, but was ${typeof event.data}.` + ) + ); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + return; + } + + try { + const obj = JSON.parse(data) as unknown; + messageQueue.push(obj); + } catch (e) { + const err = e as Error; + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Error parsing WebSocket message to JSON: ${err.message}` + ) + ); + } + + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const errorListener = (): void => { + errorQueue.push( + new AIError(AIErrorCode.FETCH_ERROR, 'WebSocket connection error.') + ); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const closeListener = (): void => { + isClosed = true; + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + // Clean up listeners to prevent memory leaks + this.ws?.removeEventListener('message', messageListener); + this.ws?.removeEventListener('close', closeListener); + this.ws?.removeEventListener('error', errorListener); + }; + + this.ws.addEventListener('message', messageListener); + this.ws.addEventListener('close', closeListener); + this.ws.addEventListener('error', errorListener); + + while (!isClosed) { + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + if (messageQueue.length > 0) { + yield messageQueue.shift()!; + } else { + await new Promise(resolve => { + resolvePromise = resolve; + }); + } + } + + // If the loop terminated because isClosed is true, check for any final errors + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + } + + close(code?: number, reason?: string): Promise { + return new Promise(resolve => { + if (!this.ws) { + return resolve(); + } + + this.ws.addEventListener('close', () => resolve(), { once: true }); + // Calling 'close' during these states results in an error. + if ( + this.ws.readyState === WebSocket.CLOSED || + this.ws.readyState === WebSocket.CONNECTING + ) { + return resolve(); + } + + if (this.ws.readyState !== WebSocket.CLOSING) { + this.ws.close(code, reason); + } + }); + } +} diff --git a/packages/ai/src/platform/node/websocket.test.ts b/packages/ai/src/platform/node/websocket.test.ts new file mode 100644 index 00000000000..6fc6fc49c70 --- /dev/null +++ b/packages/ai/src/platform/node/websocket.test.ts @@ -0,0 +1,143 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect, use } from 'chai'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { isNode } from '@firebase/util'; +import { TextEncoder } from 'util'; +import { MockWebSocketServer } from '../../../test-utils/mock-websocket-server'; +import { WebSocketHandler } from '../websocket'; +import { NodeWebSocketHandler } from './websocket'; + +use(sinonChai); +use(chaiAsPromised); + +const TEST_PORT = 9003; +const TEST_URL = `ws://localhost:${TEST_PORT}`; + +describe('NodeWebSocketHandler', () => { + let server: MockWebSocketServer; + let handler: WebSocketHandler; + + // Only run these tests in a Node environment + if (!isNode()) { + return; + } + + before(async () => { + server = new MockWebSocketServer(TEST_PORT); + }); + + after(async () => { + await server.close(); + }); + + beforeEach(() => { + handler = new NodeWebSocketHandler(); + server.reset(); + }); + + afterEach(async () => { + await handler.close().catch(() => {}); + }); + + describe('connect()', () => { + it('should successfully connect to a running server', async () => { + await handler.connect(TEST_URL); + // Allow a brief moment for the server to register the connection + await new Promise(r => setTimeout(r, 50)); + expect(server.connectionCount).to.equal(1); + expect(server.clients.size).to.equal(1); + }); + + it('should reject if the connection fails', async () => { + const wrongPortUrl = `ws://wrongUrl:9000`; + await expect(handler.connect(wrongPortUrl)).to.be.rejected; + }); + }); + + describe('listen()', () => { + beforeEach(async () => { + await handler.connect(TEST_URL); + // Wait for server to see the connection + await new Promise(r => setTimeout(r, 50)); + }); + + it('should yield parsed JSON objects from string data sent by the server', async () => { + const generator = handler.listen(); + const messageObj = { id: 1, text: 'test' }; + + const received: unknown[] = []; + const consumerPromise = (async () => { + for await (const msg of generator) { + received.push(msg); + } + })(); + + // Wait for the listener to be attached + await new Promise(r => setTimeout(r, 50)); + server.broadcast(JSON.stringify(messageObj)); + await new Promise(r => setTimeout(r, 50)); + await handler.close(); // Close client to terminate the loop + + await consumerPromise; + expect(received).to.deep.equal([messageObj]); + }); + + it('should correctly decode UTF-8 binary data sent by the server', async () => { + const generator = handler.listen(); + const messageObj = { text: '你好, 世界 🌍' }; + const encoder = new TextEncoder(); + const bufferData = encoder.encode(JSON.stringify(messageObj)); + + const received: unknown[] = []; + const consumerPromise = (async () => { + for await (const msg of generator) { + received.push(msg); + } + })(); + + await new Promise(r => setTimeout(r, 50)); + // The server's `send` method can handle Buffers/Uint8Arrays + server.clients.forEach(client => client.send(bufferData)); + await new Promise(r => setTimeout(r, 50)); + await handler.close(); + + await consumerPromise; + expect(received).to.deep.equal([messageObj]); + }); + + it('should terminate the generator when the server closes the connection', async () => { + const generator = handler.listen(); + const consumerPromise = (async () => { + // This loop should finish without error when the server closes + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of generator) { + } + })(); + + await new Promise(r => setTimeout(r, 50)); + + await server.close(); + server = new MockWebSocketServer(TEST_PORT); + + // The consumer promise should resolve without timing out + await expect(consumerPromise).to.be.fulfilled; + }); + }); +}); diff --git a/packages/ai/src/platform/node/websocket.ts b/packages/ai/src/platform/node/websocket.ts new file mode 100644 index 00000000000..156ee988785 --- /dev/null +++ b/packages/ai/src/platform/node/websocket.ts @@ -0,0 +1,205 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIError } from '../../errors'; +import { AIErrorCode } from '../../types'; +import { WebSocketHandler } from '../websocket'; + +export function createWebSocketHandler(): WebSocketHandler { + if (typeof process === 'object' && process.versions?.node) { + const [major] = process.versions.node.split('.').map(Number); + if (major < 22) { + throw new AIError( + AIErrorCode.UNSUPPORTED, + `The "Live" feature is being used in a Node environment, but the ` + + `runtime version is ${process.versions.node}. This feature requires Node >= 22 ` + + `for native WebSocket support.` + ); + } else if (typeof WebSocket === 'undefined') { + throw new AIError( + AIErrorCode.UNSUPPORTED, + `The "Live" feature is being used in a Node environment that does not offer the ` + + `'WebSocket' API in the global scope.` + ); + } + + return new NodeWebSocketHandler(); + } else { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'The "Live" feature is not supported in this Node-like environment. It is supported in ' + + 'modern browser windows, Web Workers with WebSocket support, and Node >= 22.' + ); + } +} + +/** + * A WebSocketHandler implementation for Node >= 22. + * + * Node 22 is the minimum version that offers the built-in global `WebSocket` API. + * + * @internal + */ +export class NodeWebSocketHandler implements WebSocketHandler { + private ws?: WebSocket; + + async connect(url: string): Promise { + return new Promise(async (resolve, reject) => { + this.ws = new WebSocket(url); + this.ws.binaryType = 'blob'; + this.ws!.addEventListener('open', () => resolve(), { once: true }); + this.ws!.addEventListener( + 'error', + () => + reject( + new AIError( + AIErrorCode.FETCH_ERROR, + 'Failed to establish WebSocket connection' + ) + ), + { once: true } + ); + }); + } + + send(data: string | ArrayBuffer): void { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new AIError(AIErrorCode.REQUEST_ERROR, 'WebSocket is not open.'); + } + this.ws.send(data); + } + + async *listen(): AsyncGenerator { + if (!this.ws) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'WebSocket is not connected.' + ); + } + + const messageQueue: unknown[] = []; + const errorQueue: Error[] = []; + let resolvePromise: (() => void) | null = null; + let isClosed = false; + + const messageListener = async (event: MessageEvent): Promise => { + let data: string; + if (event.data instanceof Blob) { + data = await event.data.text(); + } else if (typeof event.data === 'string') { + data = event.data; + } else { + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Failed to parse WebSocket response. Expected data to be a Blob or string, but was ${typeof event.data}.` + ) + ); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + return; + } + + try { + const obj = JSON.parse(data) as unknown; + messageQueue.push(obj); + } catch (e) { + const err = e as Error; + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Error parsing WebSocket message to JSON: ${err.message}` + ) + ); + } + + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const errorListener = (): void => { + errorQueue.push( + new AIError(AIErrorCode.FETCH_ERROR, 'WebSocket connection error.') + ); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const closeListener = (): void => { + isClosed = true; + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + // Clean up listeners to prevent memory leaks. + this.ws?.removeEventListener('message', messageListener); + this.ws?.removeEventListener('close', closeListener); + this.ws?.removeEventListener('error', errorListener); + }; + + this.ws.addEventListener('message', messageListener); + this.ws.addEventListener('close', closeListener); + this.ws.addEventListener('error', errorListener); + + while (!isClosed) { + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + if (messageQueue.length > 0) { + yield messageQueue.shift()!; + } else { + await new Promise(resolve => { + resolvePromise = resolve; + }); + } + } + + // If the loop terminated because isClosed is true, check for any final errors + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + } + + close(code?: number, reason?: string): Promise { + return new Promise(resolve => { + if (!this.ws) { + return resolve(); + } + + this.ws.addEventListener('close', () => resolve(), { once: true }); + // Calling 'close' during these states results in an error + if ( + this.ws.readyState === WebSocket.CLOSED || + this.ws.readyState === WebSocket.CONNECTING + ) { + return resolve(); + } + + if (this.ws.readyState !== WebSocket.CLOSING) { + this.ws.close(code, reason); + } + }); + } +} diff --git a/packages/ai/src/platform/websocket.ts b/packages/ai/src/platform/websocket.ts new file mode 100644 index 00000000000..65669c24809 --- /dev/null +++ b/packages/ai/src/platform/websocket.ts @@ -0,0 +1,86 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { NodeWebSocketHandler } from './node/websocket'; + +/** + * A standardized interface for interacting with a WebSocket connection. + * This abstraction allows the SDK to use the appropriate WebSocket implementation + * for the current JS environment (Browser vs. Node) without + * changing the core logic of the `LiveSession`. + * @internal + */ +export interface WebSocketHandler { + /** + * Establishes a connection to the given URL. + * + * @param url The WebSocket URL (e.g., wss://...). + * @returns A promise that resolves on successful connection or rejects on failure. + */ + connect(url: string): Promise; + + /** + * Sends data over the WebSocket. + * + * @param data The string or binary data to send. + */ + send(data: string | ArrayBuffer): void; + + /** + * Returns an async generator that yields parsed JSON objects from the server. + * The yielded type is `unknown` because the handler cannot guarantee the shape of the data. + * The consumer is responsible for type validation. + * The generator terminates when the connection is closed. + * + * @returns A generator that allows consumers to pull messages using a `for await...of` loop. + */ + listen(): AsyncGenerator; + + /** + * Closes the WebSocket connection. + * + * @param code - A numeric status code explaining why the connection is closing. + * @param reason - A human-readable string explaining why the connection is closing. + */ + close(code?: number, reason?: string): Promise; +} + +/** + * NOTE: Imports to this these APIs are renamed to either `platform/browser/websocket.ts` or + * `platform/node/websocket.ts` during build time. + * + * The types are still useful for type-checking during development. + * These are only used during the Node tests, which are ran against non-bundled code. + */ + +/** + * Factory function to create the appropriate WebSocketHandler for the current environment. + * + * This is only a stub for tests. See the real definitions in `./browser/websocket.ts` and + * `./node/websocket.ts`. + * + * @internal + */ +export function createWebSocketHandler(): WebSocketHandler { + if (typeof WebSocket === 'undefined') { + throw Error( + 'WebSocket API is not available. Make sure tests are being ran in Node >= 22.' + ); + } + + return new NodeWebSocketHandler(); +} diff --git a/packages/ai/test-utils/convert-mocks.ts b/packages/ai/test-utils/convert-mocks.ts index 4bac70d1d10..34233a73ace 100644 --- a/packages/ai/test-utils/convert-mocks.ts +++ b/packages/ai/test-utils/convert-mocks.ts @@ -19,6 +19,8 @@ const { readdirSync, readFileSync, writeFileSync } = require('node:fs'); const { join } = require('node:path'); +type BackendName = import('./types').BackendName; // Import type without triggering ES module detection + const MOCK_RESPONSES_DIR_PATH = join( __dirname, 'vertexai-sdk-test-data', @@ -26,8 +28,6 @@ const MOCK_RESPONSES_DIR_PATH = join( ); const MOCK_LOOKUP_OUTPUT_PATH = join(__dirname, 'mocks-lookup.ts'); -type BackendName = 'vertexAI' | 'googleAI'; - const mockDirs: Record = { vertexAI: join(MOCK_RESPONSES_DIR_PATH, 'vertexai'), googleAI: join(MOCK_RESPONSES_DIR_PATH, 'googleai') diff --git a/packages/ai/test-utils/mock-response.ts b/packages/ai/test-utils/mock-response.ts index 5128ddabe74..4963bcbb193 100644 --- a/packages/ai/test-utils/mock-response.ts +++ b/packages/ai/test-utils/mock-response.ts @@ -15,6 +15,7 @@ * limitations under the License. */ +import { BackendName } from './types'; import { vertexAIMocksLookup, googleAIMocksLookup } from './mocks-lookup'; const mockSetMaps: Record> = { diff --git a/packages/ai/test-utils/mock-websocket-server.ts b/packages/ai/test-utils/mock-websocket-server.ts new file mode 100644 index 00000000000..e207c95b2be --- /dev/null +++ b/packages/ai/test-utils/mock-websocket-server.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { WebSocketServer, WebSocket } from 'ws'; + +/** + * A mock WebSocket server for running integration tests against the + * `NodeWebSocketHandler`. It listens on a specified port, accepts connections, + * logs messages, and can broadcast messages to clients. + * + * This should only be used in a Node environment. + * + * @internal + */ +export class MockWebSocketServer { + private wss: WebSocketServer; + clients: Set = new Set(); + receivedMessages: string[] = []; + connectionCount = 0; + + constructor(public port: number) { + this.wss = new WebSocketServer({ port }); + + this.wss.on('connection', ws => { + this.connectionCount++; + this.clients.add(ws); + + ws.on('message', message => { + this.receivedMessages.push(message.toString()); + }); + + ws.on('close', () => { + this.clients.delete(ws); + }); + }); + } + + broadcast(message: string | Buffer): void { + for (const client of this.clients) { + if (client.readyState === WebSocket.OPEN) { + client.send(message, { binary: true }); + } + } + } + + close(): Promise { + return new Promise(resolve => { + for (const client of this.clients) { + client.terminate(); + } + this.wss.close(() => { + this.reset(); + resolve(); + }); + }); + } + + reset(): void { + this.receivedMessages = []; + this.connectionCount = 0; + } +} diff --git a/packages/ai/test-utils/types.ts b/packages/ai/test-utils/types.ts new file mode 100644 index 00000000000..00b99eef55a --- /dev/null +++ b/packages/ai/test-utils/types.ts @@ -0,0 +1,18 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export type BackendName = 'vertexAI' | 'googleAI'; diff --git a/scripts/build/rollup_generate_alias_config.js b/scripts/build/rollup_generate_alias_config.js new file mode 100644 index 00000000000..95c435b9fa4 --- /dev/null +++ b/scripts/build/rollup_generate_alias_config.js @@ -0,0 +1,27 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export function generateAliasConfig(platform) { + return { + entries: [ + { + find: /^(.*)\/platform\/([^.\/]*)(\.ts)?$/, + replacement: `$1\/platform/${platform}/$2.ts` + } + ] + }; +} diff --git a/yarn.lock b/yarn.lock index fe69e44aead..cfc5a206b1f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3321,6 +3321,13 @@ tapable "^2.2.0" webpack "^5" +"@types/ws@8.18.1": + version "8.18.1" + resolved "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz#48464e4bf2ddfd17db13d845467f6070ffea4aa9" + integrity sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg== + dependencies: + "@types/node" "*" + "@types/yargs-parser@*": version "21.0.3" resolved "https://registry.npmjs.org/@types/yargs-parser/-/yargs-parser-21.0.3.tgz#815e30b786d2e8f0dcd85fd5bcf5e1a04d008f15" @@ -15139,7 +15146,7 @@ string-argv@~0.3.1: resolved "https://registry.npmjs.org/string-argv/-/string-argv-0.3.2.tgz#2b6d0ef24b656274d957d54e0a4bbf6153dc02b6" integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q== -"string-width-cjs@npm:string-width@^4.2.0": +"string-width-cjs@npm:string-width@^4.2.0", "string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3: version "4.2.3" resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== @@ -15157,15 +15164,6 @@ string-width@^1.0.1, string-width@^1.0.2: is-fullwidth-code-point "^1.0.0" strip-ansi "^3.0.0" -"string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3: - version "4.2.3" - resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" - integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== - dependencies: - emoji-regex "^8.0.0" - is-fullwidth-code-point "^3.0.0" - strip-ansi "^6.0.1" - string-width@^2.1.1: version "2.1.1" resolved "https://registry.npmjs.org/string-width/-/string-width-2.1.1.tgz#ab93f27a8dc13d28cac815c462143a6d9012ae9e" @@ -15229,7 +15227,7 @@ string_decoder@~1.1.1: dependencies: safe-buffer "~5.1.0" -"strip-ansi-cjs@npm:strip-ansi@^6.0.1": +"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1: version "6.0.1" resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== @@ -15250,13 +15248,6 @@ strip-ansi@^4.0.0: dependencies: ansi-regex "^3.0.0" -strip-ansi@^6.0.0, strip-ansi@^6.0.1: - version "6.0.1" - resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" - integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== - dependencies: - ansi-regex "^5.0.1" - strip-ansi@^7.0.1: version "7.1.0" resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz#d5b6568ca689d8561370b0707685d22434faff45" @@ -16929,7 +16920,7 @@ workerpool@6.2.0: resolved "https://registry.npmjs.org/workerpool/-/workerpool-6.2.0.tgz#827d93c9ba23ee2019c3ffaff5c27fccea289e8b" integrity sha512-Rsk5qQHJ9eowMH28Jwhe8HEbmdYDX4lwoMWshiCXugjtHqMD9ZbiqSDLxcsfdqsETPzVUtX5s1Z5kStiIM6l4A== -"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0": +"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0: version "7.0.0" resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== @@ -16963,15 +16954,6 @@ wrap-ansi@^6.0.1, wrap-ansi@^6.2.0: string-width "^4.1.0" strip-ansi "^6.0.0" -wrap-ansi@^7.0.0: - version "7.0.0" - resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" - integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - wrap-ansi@^8.1.0: version "8.1.0" resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz#56dc22368ee570face1b49819975d9b9a5ead214" @@ -17055,6 +17037,11 @@ write-pkg@^4.0.0: type-fest "^0.4.1" write-json-file "^3.2.0" +ws@8.18.3: + version "8.18.3" + resolved "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz#b56b88abffde62791c639170400c93dcb0c95472" + integrity sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg== + ws@^7.5.10: version "7.5.10" resolved "https://registry.npmjs.org/ws/-/ws-7.5.10.tgz#58b5c20dc281633f6c19113f39b349bd8bd558d9"