From a7840580f1f38155ddb7e8bb659f33913875e120 Mon Sep 17 00:00:00 2001 From: sakshi pimpale Date: Tue, 25 Mar 2025 13:27:05 +0530 Subject: [PATCH] Added StreamCallbacks parameter to generateContentStream() --- common/api-review/generative-ai-server.api.md | 4 +- common/api-review/generative-ai.api.md | 12 +++- package.json | 7 +- src/methods/chat-session.test.ts | 4 +- src/methods/generate-content.test.ts | 4 +- src/models/generative-model.test.ts | 2 +- src/models/generative-model.ts | 5 ++ src/requests/request-helpers.test.ts | 2 +- src/requests/request.test.ts | 4 +- src/requests/response-helpers.test.ts | 2 +- src/requests/stream-reader.test.ts | 70 ++++++++++++++++++- src/requests/stream-reader.ts | 18 +++-- src/server/cache-manager.test.ts | 4 +- src/server/file-manager.test.ts | 4 +- src/server/request.test.ts | 4 +- tsconfig.json | 7 +- types/requests.ts | 10 +++ 17 files changed, 133 insertions(+), 30 deletions(-) diff --git a/common/api-review/generative-ai-server.api.md b/common/api-review/generative-ai-server.api.md index b31a89e37..d7639ade0 100644 --- a/common/api-review/generative-ai-server.api.md +++ b/common/api-review/generative-ai-server.api.md @@ -4,6 +4,8 @@ ```ts +/// + // Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal // // @public @@ -31,8 +33,6 @@ export interface BooleanSchema extends BaseSchema { type: typeof SchemaType.BOOLEAN; } -/// - // @public export interface CachedContent extends CachedContentBase { createTime?: string; diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md index 5880a650d..302ed42c3 100644 --- a/common/api-review/generative-ai.api.md +++ b/common/api-review/generative-ai.api.md @@ -525,7 +525,7 @@ export class GenerativeModel { countTokens(request: CountTokensRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; embedContent(request: EmbedContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; generateContent(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; - generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; + generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions, callbacks?: StreamCallbacks): Promise; // (undocumented) generationConfig: GenerationConfig; // (undocumented) @@ -841,6 +841,16 @@ export interface StartChatParams extends BaseParams { tools?: Tool[]; } +// @public +export interface StreamCallbacks { + // (undocumented) + onData?: (data: string) => void; + // (undocumented) + onEnd?: (data: string) => void; + // (undocumented) + onError?: (error: Error) => void; +} + // @public export type StringSchema = SimpleStringSchema | EnumStringSchema; diff --git a/package.json b/package.json index 54266a7ca..5242eb42f 100644 --- a/package.json +++ b/package.json @@ -35,9 +35,9 @@ "build": "rollup -c && npm run api-report", "test": "npm run lint && npm run test:node:unit", "test:web:integration": "npm run build && npx web-test-runner", - "test:node:unit": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"src/**/*.test.ts\"", + "test:node:unit": "mocha --require ts-node/register --require tsconfig-paths/register \"src/**/*.test.ts\"", "test:node:integration": "npm run build && TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"test-integration/node/**/*.test.ts\"", - "lint": "eslint -c .eslintrc.js '**/*.ts' --ignore-path './.gitignore'", + "lint": "eslint -c .eslintrc.js **/*.ts --ignore-path .gitignore", "api-report": "api-extractor run -c api-extractor.json --local --verbose && api-extractor run -c api-extractor.server.json --local --verbose", "docs": "npm run build && npx api-documenter markdown -i ./temp/main -o ./docs/reference/main && npx api-documenter markdown -i ./temp/server -o ./docs/reference/server", "format": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"nodenext\"}' npx ts-node scripts/run-format.ts", @@ -61,8 +61,9 @@ "@web/dev-server-esbuild": "^1.0.1", "@web/test-runner": "^0.18.0", "chai": "^4.3.10", - "chai-as-promised": "^7.1.1", + "chai-as-promised": "^7.1.2", "chai-deep-equal-ignore-undefined": "^1.1.1", + "cross-env": "^7.0.3", "eslint": "^8.52.0", "eslint-plugin-import": "^2.29.0", "eslint-plugin-unused-imports": "^3.0.0", diff --git a/src/methods/chat-session.test.ts b/src/methods/chat-session.test.ts index c09821030..b33050f4d 100644 --- a/src/methods/chat-session.test.ts +++ b/src/methods/chat-session.test.ts @@ -17,8 +17,8 @@ import { expect, use } from "chai"; import { match, restore, stub, useFakeTimers } from "sinon"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import * as generateContentMethods from "./generate-content"; import { GenerateContentStreamResult } from "../../types"; import { ChatSession } from "./chat-session"; diff --git a/src/methods/generate-content.test.ts b/src/methods/generate-content.test.ts index 94659ba5c..2ad3594a3 100644 --- a/src/methods/generate-content.test.ts +++ b/src/methods/generate-content.test.ts @@ -17,8 +17,8 @@ import { assert, expect, use } from "chai"; import { match, restore, stub } from "sinon"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import { getMockResponse } from "../../test-utils/mock-response"; import * as request from "../requests/request"; import { generateContent } from "./generate-content"; diff --git a/src/models/generative-model.test.ts b/src/models/generative-model.test.ts index e6c6fd85d..cdfb65772 100644 --- a/src/models/generative-model.test.ts +++ b/src/models/generative-model.test.ts @@ -16,7 +16,7 @@ */ import { expect, use } from "chai"; import { GenerativeModel } from "./generative-model"; -import * as sinonChai from "sinon-chai"; +import sinonChai from "sinon-chai"; import { CountTokensRequest, FunctionCallingMode, diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts index 7cd3fe622..168a5bebe 100644 --- a/src/models/generative-model.ts +++ b/src/models/generative-model.ts @@ -50,6 +50,8 @@ import { formatGenerateContentInput, formatSystemInstruction, } from "../requests/request-helpers"; +import { StreamCallbacks } from "../../types/requests"; + /** * Class for generative model APIs. @@ -132,6 +134,8 @@ export class GenerativeModel { async generateContentStream( request: GenerateContentRequest | string | Array, requestOptions: SingleRequestOptions = {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + callbacks?: StreamCallbacks ): Promise { const formattedParams = formatGenerateContentInput(request); const generativeModelRequestOptions: SingleRequestOptions = { @@ -151,6 +155,7 @@ export class GenerativeModel { ...formattedParams, }, generativeModelRequestOptions, + ); } diff --git a/src/requests/request-helpers.test.ts b/src/requests/request-helpers.test.ts index f3d46cd05..77eed0ca0 100644 --- a/src/requests/request-helpers.test.ts +++ b/src/requests/request-helpers.test.ts @@ -16,7 +16,7 @@ */ import { expect, use } from "chai"; -import * as sinonChai from "sinon-chai"; +import sinonChai from "sinon-chai"; import chaiDeepEqualIgnoreUndefined from "chai-deep-equal-ignore-undefined"; import { Content } from "../../types"; import { diff --git a/src/requests/request.test.ts b/src/requests/request.test.ts index 95f7a3a47..5033077a5 100644 --- a/src/requests/request.test.ts +++ b/src/requests/request.test.ts @@ -17,8 +17,8 @@ import { expect, use } from "chai"; import { match, restore, stub } from "sinon"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import { DEFAULT_API_VERSION, DEFAULT_BASE_URL, diff --git a/src/requests/response-helpers.test.ts b/src/requests/response-helpers.test.ts index d58517f6c..0e102a628 100644 --- a/src/requests/response-helpers.test.ts +++ b/src/requests/response-helpers.test.ts @@ -18,7 +18,7 @@ import { addHelpers, formatBlockErrorMessage } from "./response-helpers"; import { expect, use } from "chai"; import { restore } from "sinon"; -import * as sinonChai from "sinon-chai"; +import sinonChai from "sinon-chai"; import { BlockReason, Content, diff --git a/src/requests/stream-reader.test.ts b/src/requests/stream-reader.test.ts index 043745544..d2d680853 100644 --- a/src/requests/stream-reader.test.ts +++ b/src/requests/stream-reader.test.ts @@ -22,7 +22,7 @@ import { } from "./stream-reader"; import { expect, use } from "chai"; import { restore } from "sinon"; -import * as sinonChai from "sinon-chai"; +import sinonChai from "sinon-chai"; import { getChunkedStream, getErrorStream, @@ -340,7 +340,73 @@ describe("processStream", () => { } expect(foundCitationMetadata).to.be.true; }); -}); + + + describe("callbacks", () => { + it("chunk callback were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-citations.txt", + ); + let dataCount = 0; + processStream(fakeResponse as Response, { + onData: (data: string) => { + dataCount++; + expect(data).to.not.be.empty; + }, + onEnd: () => { + expect(dataCount).to.be.greaterThan(0); + done(); + }, + }); + }); + + it("error callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-failure-prompt-blocked-safety.txt", + ); + processStream(fakeResponse as Response, { + onError: (error: Error) => { + expect(error).to.be.instanceOf(GoogleGenerativeAIError); + done(); + }, + onEnd:() => done(), + }); + }); + + it("end callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-basic-reply-short.txt", + ); + processStream(fakeResponse as Response, { + onEnd: (data) => { + expect(data).to.include("Cheyenne"); + done(); + }, + }); + }); + + it("all callbacks were called", (done) => { + const fakeResponse = getMockResponseStreaming( + "streaming-success-basic-reply-long.txt", // Changed to long response + ); + let dataCount = 0; + processStream(fakeResponse as Response, { + onData: (data: string) => { + dataCount++; + expect(data).to.not.be.empty; + }, + onEnd: (data) => { + expect(data).to.include("**Cats:**"); + expect(data).to.include("to their owners."); + expect(dataCount).to.be.greaterThan(0); + done(); + }, + }); + }); + }); + }); + + describe("aggregateResponses", () => { it("handles no candidates, and promptFeedback", () => { diff --git a/src/requests/stream-reader.ts b/src/requests/stream-reader.ts index 0d1a24e6f..d2f51babb 100644 --- a/src/requests/stream-reader.ts +++ b/src/requests/stream-reader.ts @@ -21,6 +21,7 @@ import { GenerateContentResponse, GenerateContentStreamResult, Part, + StreamCallbacks, } from "../../types"; import { GoogleGenerativeAIAbortError, @@ -38,7 +39,7 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; * * @param response - Response from a fetch call */ -export function processStream(response: Response): GenerateContentStreamResult { +export function processStream(response: Response, callbacks?: StreamCallbacks): GenerateContentStreamResult { const inputStream = response.body!.pipeThrough( new TextDecoderStream("utf8", { fatal: true }), ); @@ -47,22 +48,31 @@ export function processStream(response: Response): GenerateContentStreamResult { const [stream1, stream2] = responseStream.tee(); return { stream: generateResponseSequence(stream1), - response: getResponsePromise(stream2), + response: getResponsePromise(stream2,callbacks), }; } async function getResponsePromise( stream: ReadableStream, + callbacks?: StreamCallbacks, ): Promise { - const allResponses: GenerateContentResponse[] = []; - const reader = stream.getReader(); + + try { + const allResponses: GenerateContentResponse[] = []; + const reader = stream.getReader(); while (true) { const { done, value } = await reader.read(); if (done) { + callbacks?.onEnd?.(allResponses.reduce((acc, curr) => acc + addHelpers(curr).text(), "")); return addHelpers(aggregateResponses(allResponses)); } allResponses.push(value); + callbacks?.onData?.(addHelpers(value).text()); } +}catch(error){ + callbacks?.onError?.(error); + throw error; +} } async function* generateResponseSequence( diff --git a/src/server/cache-manager.test.ts b/src/server/cache-manager.test.ts index ba34ec67b..cf9947c08 100644 --- a/src/server/cache-manager.test.ts +++ b/src/server/cache-manager.test.ts @@ -16,8 +16,8 @@ */ import { expect, use } from "chai"; import { GoogleAICacheManager } from "./cache-manager"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import { restore, stub } from "sinon"; import * as request from "./request"; import { RpcTask } from "./constants"; diff --git a/src/server/file-manager.test.ts b/src/server/file-manager.test.ts index a0531c9b0..483024027 100644 --- a/src/server/file-manager.test.ts +++ b/src/server/file-manager.test.ts @@ -16,8 +16,8 @@ */ import { expect, use } from "chai"; import { GoogleAIFileManager, getUploadMetadata } from "./file-manager"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import { restore, stub } from "sinon"; import * as request from "./request"; import { RpcTask } from "./constants"; diff --git a/src/server/request.test.ts b/src/server/request.test.ts index 4a0e8299a..2a7b643ac 100644 --- a/src/server/request.test.ts +++ b/src/server/request.test.ts @@ -17,8 +17,8 @@ import { expect, use } from "chai"; import { match, restore, stub } from "sinon"; -import * as sinonChai from "sinon-chai"; -import * as chaiAsPromised from "chai-as-promised"; +import sinonChai from "sinon-chai"; +import chaiAsPromised from "chai-as-promised"; import { DEFAULT_API_VERSION, DEFAULT_BASE_URL } from "../requests/request"; import { FilesRequestUrl, diff --git a/tsconfig.json b/tsconfig.json index 3daccfd53..088837e04 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,12 +1,13 @@ { "compilerOptions": { "noImplicitAny": true, - "module": "es2020", - "target": "es2020", + "module": "NodeNext", + "target": "ES2020", "allowJs": true, - "moduleResolution": "node", + "moduleResolution": "NodeNext", "declaration": true, "outDir": "dist", + "baseUrl": "./", "allowSyntheticDefaultImports": true } } diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..69d2012b3 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -25,6 +25,7 @@ import { } from "./function-calling"; import { GoogleSearchRetrievalTool } from "./search-grounding"; + /** * Base parameters for a number of methods. * @public @@ -208,6 +209,15 @@ export interface RequestOptions { */ customHeaders?: Headers | Record; } +/** + * Stream callbacks for streaming responses. + * @public + */ +export interface StreamCallbacks { + onData?: (data: string) => void; + onEnd?: (data: string) => void; + onError?: (error: Error) => void; +} /** * Params passed to atomic asynchronous operations.