From 839244cc3f725294c1130656185c3e62129aa021 Mon Sep 17 00:00:00 2001 From: Deyaaeldeen Almahallawi Date: Wed, 22 Jan 2025 15:04:58 -0800 Subject: [PATCH] [Azure] Support Realtime API --- README.md | 27 ++++++- examples/{azure.ts => azure/chat.ts} | 3 +- examples/azure/websocket.ts | 61 ++++++++++++++++ examples/azure/ws.ts | 68 +++++++++++++++++ examples/package.json | 1 + src/beta/realtime/websocket.ts | 105 ++++++++++++++++++++++++++- src/beta/realtime/ws.ts | 100 ++++++++++++++++++++++++- src/index.ts | 10 +-- 8 files changed, 366 insertions(+), 9 deletions(-) rename examples/{azure.ts => azure/chat.ts} (91%) create mode 100644 examples/azure/websocket.ts create mode 100644 examples/azure/ws.ts diff --git a/README.md b/README.md index 3bd386e99..f81ff4207 100644 --- a/README.md +++ b/README.md @@ -499,7 +499,7 @@ const credential = new DefaultAzureCredential(); const scope = 'https://cognitiveservices.azure.com/.default'; const azureADTokenProvider = getBearerTokenProvider(credential, scope); -const openai = new AzureOpenAI({ azureADTokenProvider }); +const openai = new AzureOpenAI({ azureADTokenProvider, apiVersion: "<The API version, e.g. 2024-10-01-preview>" }); const result = await openai.chat.completions.create({ model: 'gpt-4o', @@ -509,6 +509,31 @@ const result = await openai.chat.completions.create({ console.log(result.choices[0]!.message?.content); ``` +### Realtime API +This SDK provides real-time streaming capabilities for Azure OpenAI through the `AzureOpenAIRealtimeWS` and `AzureOpenAIRealtimeWebSocket` classes. These classes parallel the `OpenAIRealtimeWS` and `OpenAIRealtimeWebSocket` clients described previously, but they are specifically adapted for Azure OpenAI endpoints. + +To utilize the real-time features, begin by creating a fully configured `AzureOpenAI` client and passing it into either `AzureOpenAIRealtimeWS` or `AzureOpenAIRealtimeWebSocket`. 
For example: + +```ts +const cred = new DefaultAzureCredential(); +const scope = 'https://cognitiveservices.azure.com/.default'; +const deploymentName = 'gpt-4o-realtime-preview-1001'; +const azureADTokenProvider = getBearerTokenProvider(cred, scope); +const client = new AzureOpenAI({ + azureADTokenProvider, + apiVersion: '2024-10-01-preview', + deployment: deploymentName, +}); +const rt = new AzureOpenAIRealtimeWS(client); +``` + +Once the real-time client has been created, open its underlying WebSocket connection by invoking the `open` method: +```ts +await rt.open(); +``` + +With the connection established, you can then begin sending requests and receiving streaming responses in real time. + ### Retries Certain errors will be automatically retried 2 times by default, with a short exponential backoff. diff --git a/examples/azure.ts b/examples/azure/chat.ts similarity index 91% rename from examples/azure.ts rename to examples/azure/chat.ts index 5fe1718fa..46df820f8 100755 --- a/examples/azure.ts +++ b/examples/azure/chat.ts @@ -2,6 +2,7 @@ import { AzureOpenAI } from 'openai'; import { getBearerTokenProvider, DefaultAzureCredential } from '@azure/identity'; +import 'dotenv/config'; // Corresponds to your Model deployment within your OpenAI resource, e.g. gpt-4-1106-preview // Navigate to the Azure OpenAI Studio to deploy a model. @@ -13,7 +14,7 @@ const azureADTokenProvider = getBearerTokenProvider(credential, scope); // Make sure to set AZURE_OPENAI_ENDPOINT with the endpoint of your Azure resource. // You can find it in the Azure Portal. 
-const openai = new AzureOpenAI({ azureADTokenProvider }); +const openai = new AzureOpenAI({ azureADTokenProvider, apiVersion: '2024-10-01-preview' }); async function main() { console.log('Non-streaming:'); diff --git a/examples/azure/websocket.ts b/examples/azure/websocket.ts new file mode 100644 index 000000000..7543b64c0 --- /dev/null +++ b/examples/azure/websocket.ts @@ -0,0 +1,61 @@ +import { AzureOpenAIRealtimeWebSocket } from 'openai/beta/realtime/websocket'; +import { AzureOpenAI } from 'openai'; +import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity'; +import 'dotenv/config'; + +async function main() { + const cred = new DefaultAzureCredential(); + const scope = 'https://cognitiveservices.azure.com/.default'; + const deploymentName = 'gpt-4o-realtime-preview-1001'; + const azureADTokenProvider = getBearerTokenProvider(cred, scope); + const client = new AzureOpenAI({ + azureADTokenProvider, + apiVersion: '2024-10-01-preview', + deployment: deploymentName, + }); + const rt = new AzureOpenAIRealtimeWebSocket(client); + await rt.open(); + + // access the underlying standard `WebSocket` instance + rt.socket.addEventListener('open', () => { + console.log('Connection opened!'); + rt.send({ + type: 'session.update', + session: { + modalities: ['text'], + model: 'gpt-4o-realtime-preview', + }, + }); + + rt.send({ + type: 'conversation.item.create', + item: { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'Say a couple paragraphs!' 
}], + }, + }); + + rt.send({ type: 'response.create' }); + }); + + rt.on('error', (err) => { + // in a real world scenario this should be logged somewhere as you + // likely want to continue processing events regardless of any errors + throw err; + }); + + rt.on('session.created', (event) => { + console.log('session created!', event.session); + console.log(); + }); + + rt.on('response.text.delta', (event) => process.stdout.write(event.delta)); + rt.on('response.text.done', () => console.log()); + + rt.on('response.done', () => rt.close()); + + rt.socket.addEventListener('close', () => console.log('\nConnection closed!')); +} + +main(); diff --git a/examples/azure/ws.ts b/examples/azure/ws.ts new file mode 100644 index 000000000..786243487 --- /dev/null +++ b/examples/azure/ws.ts @@ -0,0 +1,68 @@ +import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity'; +import { AzureOpenAIRealtimeWS } from 'openai/beta/realtime/ws'; +import { AzureOpenAI } from 'openai'; +import 'dotenv/config'; + +async function main() { + const cred = new DefaultAzureCredential(); + const scope = 'https://cognitiveservices.azure.com/.default'; + const deploymentName = 'gpt-4o-realtime-preview-1001'; + const azureADTokenProvider = getBearerTokenProvider(cred, scope); + const client = new AzureOpenAI({ + azureADTokenProvider, + apiVersion: '2024-10-01-preview', + deployment: deploymentName, + }); + const rt = new AzureOpenAIRealtimeWS(client); + await rt.open(); + + // access the underlying `ws.WebSocket` instance + rt.socket.on('open', () => { + console.log('Connection opened!'); + rt.send({ + type: 'session.update', + session: { + modalities: ['text'], + model: 'gpt-4o-realtime-preview', + }, + }); + rt.send({ + type: 'session.update', + session: { + modalities: ['text'], + model: 'gpt-4o-realtime-preview', + }, + }); + + rt.send({ + type: 'conversation.item.create', + item: { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'Say a couple 
paragraphs!' }], + }, + }); + + rt.send({ type: 'response.create' }); + }); + + rt.on('error', (err) => { + // in a real world scenario this should be logged somewhere as you + // likely want to continue processing events regardless of any errors + throw err; + }); + + rt.on('session.created', (event) => { + console.log('session created!', event.session); + console.log(); + }); + + rt.on('response.text.delta', (event) => process.stdout.write(event.delta)); + rt.on('response.text.done', () => console.log()); + + rt.on('response.done', () => rt.close()); + + rt.socket.on('close', () => console.log('\nConnection closed!')); +} + +main(); diff --git a/examples/package.json b/examples/package.json index b8c34ac45..70ec2c523 100644 --- a/examples/package.json +++ b/examples/package.json @@ -7,6 +7,7 @@ "private": true, "dependencies": { "@azure/identity": "^4.2.0", + "dotenv": "^16.4.7", "express": "^4.18.2", "next": "^14.1.1", "openai": "file:..", diff --git a/src/beta/realtime/websocket.ts b/src/beta/realtime/websocket.ts index e0853779d..a36bcb993 100644 --- a/src/beta/realtime/websocket.ts +++ b/src/beta/realtime/websocket.ts @@ -1,4 +1,4 @@ -import { OpenAI } from '../../index'; +import { AzureOpenAI, OpenAI } from '../../index'; import { OpenAIError } from '../../error'; import * as Core from '../../core'; import type { RealtimeClientEvent, RealtimeServerEvent } from '../../resources/beta/realtime/realtime'; @@ -95,3 +95,106 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { } } } + +export class AzureOpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { + socket: _WebSocket; + + constructor( + private client: AzureOpenAI, + private options: { + deploymentName?: string; + } = {}, + ) { + super(); + } + + async open(): Promise { + async function getUrl({ + apiVersion, + baseURL, + deploymentName, + apiKey, + token, + }: { + baseURL: string; + deploymentName: string; + apiVersion: string; + apiKey: string; + token: string | undefined; + }): 
Promise { + const path = '/realtime'; + const url = new URL(baseURL + (baseURL.endsWith('/') ? path.slice(1) : path)); + url.protocol = 'wss'; + url.searchParams.set('api-version', apiVersion); + url.searchParams.set('deployment', deploymentName); + if (apiKey !== '') { + url.searchParams.set('api-key', apiKey); + } else { + if (token) { + url.searchParams.set('Authorization', `Bearer ${token}`); + } else { + throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); + } + } + return url; + } + const deploymentName = this.client.deploymentName ?? this.options.deploymentName; + if (!deploymentName) { + throw new Error('No deployment name provided'); + } + const url = await getUrl({ + apiVersion: this.client.apiVersion, + baseURL: this.client.baseURL, + deploymentName, + apiKey: this.client.apiKey, + token: await this.client.getAzureADToken(), + }); + // @ts-ignore + this.socket = new WebSocket(url, ['realtime', 'openai-beta.realtime-v1']); + + this.socket.addEventListener('message', (websocketEvent: MessageEvent) => { + const event = (() => { + try { + return JSON.parse(websocketEvent.data.toString()) as RealtimeServerEvent; + } catch (err) { + this._onError(null, 'could not parse websocket event', err); + return null; + } + })(); + + if (event) { + this._emit('event', event); + + if (event.type === 'error') { + this._onError(event); + } else { + // @ts-expect-error TS isn't smart enough to get the relationship right here + this._emit(event.type, event); + } + } + }); + + this.socket.addEventListener('error', (event: any) => { + this._onError(null, event.message, null); + }); + } + + send(event: RealtimeClientEvent) { + if (!this.socket) { + throw new Error('Socket is not open, call open() first'); + } + try { + this.socket.send(JSON.stringify(event)); + } catch (err) { + this._onError(null, 'could not send data', err); + } + } + + close(props?: { code: number; reason: string }) { + try { + this.socket?.close(props?.code ?? 
1000, props?.reason ?? 'OK'); + } catch (err) { + this._onError(null, 'could not close the connection', err); + } + } +} diff --git a/src/beta/realtime/ws.ts b/src/beta/realtime/ws.ts index 631a36cd2..9e33a102f 100644 --- a/src/beta/realtime/ws.ts +++ b/src/beta/realtime/ws.ts @@ -1,5 +1,5 @@ import * as WS from 'ws'; -import { OpenAI } from '../../index'; +import { AzureOpenAI, OpenAI } from '../../index'; import type { RealtimeClientEvent, RealtimeServerEvent } from '../../resources/beta/realtime/realtime'; import { OpenAIRealtimeEmitter, buildRealtimeURL } from './internal-base'; @@ -67,3 +67,101 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { } } } + +export class AzureOpenAIRealtimeWS extends OpenAIRealtimeEmitter { + url: URL; + socket: WS.WebSocket; + + constructor( + private client: AzureOpenAI, + private props: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {}, + ) { + super(); + const path = '/realtime'; + const baseURL = client.baseURL; + const url = new URL(baseURL + (baseURL.endsWith('/') ? path.slice(1) : path)); + url.protocol = 'wss'; + url.searchParams.set('api-version', client.apiVersion); + const deploymentName = props.deploymentName ?? 
client.deploymentName; + if (!deploymentName) { + throw new Error('AzureOpenAIRealtimeWS requires a deployment name'); + } + url.searchParams.set('deployment', deploymentName); + this.url = url; + this.socket = undefined as any; + } + + async open(): Promise { + const headers = { + ...this.props.options?.headers, + 'OpenAI-Beta': 'realtime=v1', + }; + if (this.client.apiKey !== '') { + this.socket = new WS.WebSocket(this.url, { + ...this.props.options, + headers: { + ...headers, + 'api-key': this.client.apiKey, + }, + }); + } else { + const token = await this.client.getAzureADToken(); + if (token) { + this.socket = new WS.WebSocket(this.url, { + ...this.props.options, + headers: { + ...headers, + Authorization: `Bearer ${token}`, + }, + }); + } else { + throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); + } + } + + this.socket.on('message', (wsEvent) => { + const event = (() => { + try { + return JSON.parse(wsEvent.toString()) as RealtimeServerEvent; + } catch (err) { + this._onError(null, 'could not parse websocket event', err); + return null; + } + })(); + + if (event) { + this._emit('event', event); + + if (event.type === 'error') { + this._onError(event); + } else { + // @ts-expect-error TS isn't smart enough to get the relationship right here + this._emit(event.type, event); + } + } + }); + + this.socket.on('error', (err) => { + this._onError(null, err.message, err); + }); + } + + send(event: RealtimeClientEvent) { + if (!this.socket) { + throw new Error('Socket is not open, call open() first'); + } + try { + this.socket.send(JSON.stringify(event)); + } catch (err) { + this._onError(null, 'could not send data', err); + } + } + + close(props?: { code: number; reason: string }) { + try { + this.socket?.close(props?.code ?? 1000, props?.reason ?? 
'OK'); + } catch (err) { + this._onError(null, 'could not close the connection', err); + } + } +} diff --git a/src/index.ts b/src/index.ts index 944def00f..6f0a9dbb0 100644 --- a/src/index.ts +++ b/src/index.ts @@ -491,7 +491,7 @@ export interface AzureClientOptions extends ClientOptions { /** API Client for interfacing with the Azure OpenAI API. */ export class AzureOpenAI extends OpenAI { private _azureADTokenProvider: (() => Promise) | undefined; - private _deployment: string | undefined; + deploymentName: string | undefined; apiVersion: string = ''; /** * API Client for interfacing with the Azure OpenAI API. @@ -574,7 +574,7 @@ export class AzureOpenAI extends OpenAI { this._azureADTokenProvider = azureADTokenProvider; this.apiVersion = apiVersion; - this._deployment = deployment; + this.deploymentName = deployment; } override buildRequest( @@ -589,7 +589,7 @@ export class AzureOpenAI extends OpenAI { if (!Core.isObj(options.body)) { throw new Error('Expected request body to be an object'); } - const model = this._deployment || options.body['model']; + const model = this.deploymentName || options.body['model']; if (model !== undefined && !this.baseURL.includes('/deployments')) { options.path = `/deployments/${model}${options.path}`; } @@ -597,7 +597,7 @@ export class AzureOpenAI extends OpenAI { return super.buildRequest(options, props); } - private async _getAzureADToken(): Promise { + async getAzureADToken(): Promise { if (typeof this._azureADTokenProvider === 'function') { const token = await this._azureADTokenProvider(); if (!token || typeof token !== 'string') { @@ -624,7 +624,7 @@ export class AzureOpenAI extends OpenAI { if (opts.headers?.['api-key']) { return super.prepareOptions(opts); } - const token = await this._getAzureADToken(); + const token = await this.getAzureADToken(); opts.headers ??= {}; if (token) { opts.headers['Authorization'] = `Bearer ${token}`;