diff --git a/common/api-review/generative-ai-server.api.md b/common/api-review/generative-ai-server.api.md index b31a89e37..14f040701 100644 --- a/common/api-review/generative-ai-server.api.md +++ b/common/api-review/generative-ai-server.api.md @@ -4,6 +4,7 @@ ```ts + // Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal // // @public diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md index 5880a650d..302ed42c3 100644 --- a/common/api-review/generative-ai.api.md +++ b/common/api-review/generative-ai.api.md @@ -525,7 +525,7 @@ export class GenerativeModel { countTokens(request: CountTokensRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; embedContent(request: EmbedContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; generateContent(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; - generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; + generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions, callbacks?: StreamCallbacks): Promise; // (undocumented) generationConfig: GenerationConfig; // (undocumented) @@ -841,6 +841,16 @@ export interface StartChatParams extends BaseParams { tools?: Tool[]; } +// @public +export interface StreamCallbacks { + // (undocumented) + onData?: (data: string) => void; + // (undocumented) + onEnd?: (data: string) => void; + // (undocumented) + onError?: (error: Error) => void; +} + // @public export type StringSchema = SimpleStringSchema | EnumStringSchema; diff --git a/package.json b/package.json index 54266a7ca..b04d9137f 100644 --- a/package.json +++ b/package.json @@ -35,9 +35,9 @@ "build": "rollup -c && npm run api-report", "test": "npm run lint && npm run test:node:unit", "test:web:integration": "npm run build && npx web-test-runner", - "test:node:unit": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"src/**/*.test.ts\"", + "test:node:unit": "cross-env TS_NODE_COMPILER_OPTIONS=\"{\\\"module\\\":\\\"commonjs\\\"}\" mocha \"src/**/*.test.ts\"", "test:node:integration": "npm run build && TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"test-integration/node/**/*.test.ts\"", - "lint": "eslint -c .eslintrc.js '**/*.ts' --ignore-path './.gitignore'", + "lint": "eslint -c .eslintrc.js \"**/*.ts\" --ignore-path .gitignore", "api-report": "api-extractor run -c api-extractor.json --local --verbose && api-extractor run -c api-extractor.server.json --local --verbose", "docs": "npm run build && npx api-documenter markdown -i ./temp/main -o ./docs/reference/main && npx api-documenter markdown -i ./temp/server -o ./docs/reference/server", "format": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"nodenext\"}' npx ts-node scripts/run-format.ts", @@ -63,6 +63,7 @@ "chai": "^4.3.10", "chai-as-promised": "^7.1.1", "chai-deep-equal-ignore-undefined": "^1.1.1", + "cross-env": "^7.0.3", "eslint": "^8.52.0", "eslint-plugin-import": "^2.29.0", "eslint-plugin-unused-imports": "^3.0.0", diff --git a/src/methods/generate-content.ts b/src/methods/generate-content.ts index 8c1e13934..6e37e68b0 100644 --- a/src/methods/generate-content.ts +++ b/src/methods/generate-content.ts @@ -21,6 +21,7 @@ import { GenerateContentResult, GenerateContentStreamResult, SingleRequestOptions, + StreamCallbacks, } from "../../types"; import { Task, makeModelRequest } from "../requests/request"; import { addHelpers } from "../requests/response-helpers"; @@ -31,6 +32,7 @@ export async function generateContentStream( model: string, params: GenerateContentRequest, requestOptions: SingleRequestOptions, + callbacks?: StreamCallbacks, ): Promise { const response = await makeModelRequest( model, @@ -40,7 +42,7 @@ export async function generateContentStream( JSON.stringify(params), requestOptions, ); - return processStream(response); + return processStream(response, callbacks); } export async function generateContent( diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts index 7cd3fe622..5c4aaf08e 100644 --- a/src/models/generative-model.ts +++ b/src/models/generative-model.ts @@ -38,6 +38,7 @@ import { SafetySetting, SingleRequestOptions, StartChatParams, + StreamCallbacks, Tool, ToolConfig, } from "../../types"; @@ -132,6 +133,7 @@ export class GenerativeModel { async generateContentStream( request: GenerateContentRequest | string | Array, requestOptions: SingleRequestOptions = {}, + callbacks?: StreamCallbacks, ): Promise { const formattedParams = formatGenerateContentInput(request); const generativeModelRequestOptions: SingleRequestOptions = { @@ -151,6 +153,7 @@ export class GenerativeModel { ...formattedParams, }, generativeModelRequestOptions, + callbacks, ); } diff --git a/src/requests/stream-reader.test.ts b/src/requests/stream-reader.test.ts index 043745544..a980bcf1b 100644 --- a/src/requests/stream-reader.test.ts +++ b/src/requests/stream-reader.test.ts @@ -340,6 +340,65 @@ describe("processStream", () => { } expect(foundCitationMetadata).to.be.true; }); + + + describe("callbacks", () => { + it("chunk callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-citations.txt", + ); + processStream(fakeResponse as Response, { + onData: (data: string) => { + expect(data).to.not.be.empty; + }, + onEnd:()=> done(), + }); + }); + + it("error callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-failure-prompt-blocked-safety.txt", + ); + processStream(fakeResponse as Response, { + onError: (error: Error) => { + expect(error).to.be.instanceOf(GoogleGenerativeAIError); + done(); + }, + onEnd:()=> done(), + }); + }); + + + it("end callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-basic-reply-short.txt", + ); + processStream(fakeResponse as Response, { + onEnd:(data)=> { + expect(data).to.include("Cheyenne"); + done(); + } + }); + }); + + it("all callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-basic-reply-long.txt", + ); + processStream(fakeResponse as Response, { + onEnd:(data)=> { + expect(data).to.include("**Cats:**"); + expect(data).to.include("to their owners."); + done(); + }, + onData:(data:string)=> { + expect(data).to.not.be.empty; + }, + + }); + }); + +}); }); describe("aggregateResponses", () => { diff --git a/src/requests/stream-reader.ts b/src/requests/stream-reader.ts index 0d1a24e6f..55374c222 100644 --- a/src/requests/stream-reader.ts +++ b/src/requests/stream-reader.ts @@ -21,6 +21,7 @@ import { GenerateContentResponse, GenerateContentStreamResult, Part, + StreamCallbacks, } from "../../types"; import { GoogleGenerativeAIAbortError, @@ -38,7 +39,7 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; * * @param response - Response from a fetch call */ -export function processStream(response: Response): GenerateContentStreamResult { +export function processStream(response: Response, callbacks?: StreamCallbacks): GenerateContentStreamResult { const inputStream = response.body!.pipeThrough( new TextDecoderStream("utf8", { fatal: true }), ); @@ -47,21 +48,29 @@ export function processStream(response: Response): GenerateContentStreamResult { const [stream1, stream2] = responseStream.tee(); return { stream: generateResponseSequence(stream1), - response: getResponsePromise(stream2), + response: getResponsePromise(stream2, callbacks), }; } async function getResponsePromise( stream: ReadableStream, + callbacks?: StreamCallbacks, ): Promise { - const allResponses: GenerateContentResponse[] = []; - const reader = stream.getReader(); - while (true) { - const { done, value } = await reader.read(); - if (done) { - return addHelpers(aggregateResponses(allResponses)); - } - allResponses.push(value); + try { + const allResponses: GenerateContentResponse[] = []; + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) { + callbacks?.onEnd((allResponses.reduce((acc, curr) => acc + addHelpers(curr).text(), ""))); + return addHelpers(aggregateResponses(allResponses)); + } + allResponses.push(value); + callbacks?.onData?.(addHelpers(value).text()); + } + } catch (error) { + callbacks?.onError?.(error); + throw error; } } diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..d1dfec93e 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -209,6 +209,16 @@ export interface RequestOptions { customHeaders?: Headers | Record; } +/** + * Callbacks for streaming responses. + * @public + */ +export interface StreamCallbacks { + onData?: (data: string) => void; + onEnd?: (data: string) => void; + onError?: (error: Error) => void; +} + /** * Params passed to atomic asynchronous operations. * @public