diff --git a/common/api-review/generative-ai-server.api.md b/common/api-review/generative-ai-server.api.md
index b31a89e37..d7639ade0 100644
--- a/common/api-review/generative-ai-server.api.md
+++ b/common/api-review/generative-ai-server.api.md
@@ -4,6 +4,8 @@
```ts
+///
+
// Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
@@ -31,8 +33,6 @@ export interface BooleanSchema extends BaseSchema {
type: typeof SchemaType.BOOLEAN;
}
-///
-
// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md
index 5880a650d..302ed42c3 100644
--- a/common/api-review/generative-ai.api.md
+++ b/common/api-review/generative-ai.api.md
@@ -525,7 +525,7 @@ export class GenerativeModel {
countTokens(request: CountTokensRequest | string | Array, requestOptions?: SingleRequestOptions): Promise;
embedContent(request: EmbedContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise;
generateContent(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise;
- generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise;
+ generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions, callbacks?: StreamCallbacks): Promise;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
@@ -841,6 +841,16 @@ export interface StartChatParams extends BaseParams {
tools?: Tool[];
}
+// @public
+export interface StreamCallbacks {
+ // (undocumented)
+ onData?: (data: string) => void;
+ // (undocumented)
+ onEnd?: (data: string) => void;
+ // (undocumented)
+ onError?: (error: Error) => void;
+}
+
// @public
export type StringSchema = SimpleStringSchema | EnumStringSchema;
diff --git a/package.json b/package.json
index 54266a7ca..5242eb42f 100644
--- a/package.json
+++ b/package.json
@@ -35,9 +35,9 @@
"build": "rollup -c && npm run api-report",
"test": "npm run lint && npm run test:node:unit",
"test:web:integration": "npm run build && npx web-test-runner",
- "test:node:unit": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"src/**/*.test.ts\"",
+ "test:node:unit": "mocha --require ts-node/register --require tsconfig-paths/register \"src/**/*.test.ts\"",
"test:node:integration": "npm run build && TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"test-integration/node/**/*.test.ts\"",
- "lint": "eslint -c .eslintrc.js '**/*.ts' --ignore-path './.gitignore'",
+ "lint": "eslint -c .eslintrc.js **/*.ts --ignore-path .gitignore",
"api-report": "api-extractor run -c api-extractor.json --local --verbose && api-extractor run -c api-extractor.server.json --local --verbose",
"docs": "npm run build && npx api-documenter markdown -i ./temp/main -o ./docs/reference/main && npx api-documenter markdown -i ./temp/server -o ./docs/reference/server",
"format": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"nodenext\"}' npx ts-node scripts/run-format.ts",
@@ -61,8 +61,9 @@
"@web/dev-server-esbuild": "^1.0.1",
"@web/test-runner": "^0.18.0",
"chai": "^4.3.10",
- "chai-as-promised": "^7.1.1",
+ "chai-as-promised": "^7.1.2",
"chai-deep-equal-ignore-undefined": "^1.1.1",
+ "cross-env": "^7.0.3",
"eslint": "^8.52.0",
"eslint-plugin-import": "^2.29.0",
"eslint-plugin-unused-imports": "^3.0.0",
diff --git a/src/methods/chat-session.test.ts b/src/methods/chat-session.test.ts
index c09821030..b33050f4d 100644
--- a/src/methods/chat-session.test.ts
+++ b/src/methods/chat-session.test.ts
@@ -17,8 +17,8 @@
import { expect, use } from "chai";
import { match, restore, stub, useFakeTimers } from "sinon";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import * as generateContentMethods from "./generate-content";
import { GenerateContentStreamResult } from "../../types";
import { ChatSession } from "./chat-session";
diff --git a/src/methods/generate-content.test.ts b/src/methods/generate-content.test.ts
index 94659ba5c..2ad3594a3 100644
--- a/src/methods/generate-content.test.ts
+++ b/src/methods/generate-content.test.ts
@@ -17,8 +17,8 @@
import { assert, expect, use } from "chai";
import { match, restore, stub } from "sinon";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import { getMockResponse } from "../../test-utils/mock-response";
import * as request from "../requests/request";
import { generateContent } from "./generate-content";
diff --git a/src/models/generative-model.test.ts b/src/models/generative-model.test.ts
index e6c6fd85d..cdfb65772 100644
--- a/src/models/generative-model.test.ts
+++ b/src/models/generative-model.test.ts
@@ -16,7 +16,7 @@
*/
import { expect, use } from "chai";
import { GenerativeModel } from "./generative-model";
-import * as sinonChai from "sinon-chai";
+import sinonChai from "sinon-chai";
import {
CountTokensRequest,
FunctionCallingMode,
diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts
index 7cd3fe622..168a5bebe 100644
--- a/src/models/generative-model.ts
+++ b/src/models/generative-model.ts
@@ -50,6 +50,8 @@ import {
formatGenerateContentInput,
formatSystemInstruction,
} from "../requests/request-helpers";
+import { StreamCallbacks } from "../../types/requests";
+
/**
* Class for generative model APIs.
@@ -132,6 +134,8 @@ export class GenerativeModel {
async generateContentStream(
request: GenerateContentRequest | string | Array,
requestOptions: SingleRequestOptions = {},
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ callbacks?: StreamCallbacks
): Promise {
const formattedParams = formatGenerateContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
@@ -151,6 +155,7 @@ export class GenerativeModel {
...formattedParams,
},
generativeModelRequestOptions,
+
);
}
diff --git a/src/requests/request-helpers.test.ts b/src/requests/request-helpers.test.ts
index f3d46cd05..77eed0ca0 100644
--- a/src/requests/request-helpers.test.ts
+++ b/src/requests/request-helpers.test.ts
@@ -16,7 +16,7 @@
*/
import { expect, use } from "chai";
-import * as sinonChai from "sinon-chai";
+import sinonChai from "sinon-chai";
import chaiDeepEqualIgnoreUndefined from "chai-deep-equal-ignore-undefined";
import { Content } from "../../types";
import {
diff --git a/src/requests/request.test.ts b/src/requests/request.test.ts
index 95f7a3a47..5033077a5 100644
--- a/src/requests/request.test.ts
+++ b/src/requests/request.test.ts
@@ -17,8 +17,8 @@
import { expect, use } from "chai";
import { match, restore, stub } from "sinon";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import {
DEFAULT_API_VERSION,
DEFAULT_BASE_URL,
diff --git a/src/requests/response-helpers.test.ts b/src/requests/response-helpers.test.ts
index d58517f6c..0e102a628 100644
--- a/src/requests/response-helpers.test.ts
+++ b/src/requests/response-helpers.test.ts
@@ -18,7 +18,7 @@
import { addHelpers, formatBlockErrorMessage } from "./response-helpers";
import { expect, use } from "chai";
import { restore } from "sinon";
-import * as sinonChai from "sinon-chai";
+import sinonChai from "sinon-chai";
import {
BlockReason,
Content,
diff --git a/src/requests/stream-reader.test.ts b/src/requests/stream-reader.test.ts
index 043745544..d2d680853 100644
--- a/src/requests/stream-reader.test.ts
+++ b/src/requests/stream-reader.test.ts
@@ -22,7 +22,7 @@ import {
} from "./stream-reader";
import { expect, use } from "chai";
import { restore } from "sinon";
-import * as sinonChai from "sinon-chai";
+import sinonChai from "sinon-chai";
import {
getChunkedStream,
getErrorStream,
@@ -340,7 +340,73 @@ describe("processStream", () => {
}
expect(foundCitationMetadata).to.be.true;
});
-});
+
+
+ describe("callbacks", () => {
+ it("chunk callback were called", (done) => {
+ const fakeResponse = getMockResponseStreaming(
+ "streaming-success-citations.txt",
+ );
+ let dataCount = 0;
+ processStream(fakeResponse as Response, {
+ onData: (data: string) => {
+ dataCount++;
+ expect(data).to.not.be.empty;
+ },
+ onEnd: () => {
+ expect(dataCount).to.be.greaterThan(0);
+ done();
+ },
+ });
+ });
+
+ it("error callbacks were called", (done) => {
+ const fakeResponse = getMockResponseStreaming(
+ "streaming-failure-prompt-blocked-safety.txt",
+ );
+ processStream(fakeResponse as Response, {
+ onError: (error: Error) => {
+ expect(error).to.be.instanceOf(GoogleGenerativeAIError);
+ done();
+ },
+ onEnd:() => done(),
+ });
+ });
+
+ it("end callbacks were called", (done) => {
+ const fakeResponse = getMockResponseStreaming(
+ "streaming-success-basic-reply-short.txt",
+ );
+ processStream(fakeResponse as Response, {
+ onEnd: (data) => {
+ expect(data).to.include("Cheyenne");
+ done();
+ },
+ });
+ });
+
+ it("all callbacks were called", (done) => {
+ const fakeResponse = getMockResponseStreaming(
+ "streaming-success-basic-reply-long.txt", // Changed to long response
+ );
+ let dataCount = 0;
+ processStream(fakeResponse as Response, {
+ onData: (data: string) => {
+ dataCount++;
+ expect(data).to.not.be.empty;
+ },
+ onEnd: (data) => {
+ expect(data).to.include("**Cats:**");
+ expect(data).to.include("to their owners.");
+ expect(dataCount).to.be.greaterThan(0);
+ done();
+ },
+ });
+ });
+ });
+ });
+
+
describe("aggregateResponses", () => {
it("handles no candidates, and promptFeedback", () => {
diff --git a/src/requests/stream-reader.ts b/src/requests/stream-reader.ts
index 0d1a24e6f..d2f51babb 100644
--- a/src/requests/stream-reader.ts
+++ b/src/requests/stream-reader.ts
@@ -21,6 +21,7 @@ import {
GenerateContentResponse,
GenerateContentStreamResult,
Part,
+ StreamCallbacks,
} from "../../types";
import {
GoogleGenerativeAIAbortError,
@@ -38,7 +39,7 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
*
* @param response - Response from a fetch call
*/
-export function processStream(response: Response): GenerateContentStreamResult {
+export function processStream(response: Response, callbacks?: StreamCallbacks): GenerateContentStreamResult {
const inputStream = response.body!.pipeThrough(
new TextDecoderStream("utf8", { fatal: true }),
);
@@ -47,22 +48,31 @@ export function processStream(response: Response): GenerateContentStreamResult {
const [stream1, stream2] = responseStream.tee();
return {
stream: generateResponseSequence(stream1),
- response: getResponsePromise(stream2),
+ response: getResponsePromise(stream2,callbacks),
};
}
async function getResponsePromise(
stream: ReadableStream,
+ callbacks?: StreamCallbacks,
): Promise {
- const allResponses: GenerateContentResponse[] = [];
- const reader = stream.getReader();
+
+ try {
+ const allResponses: GenerateContentResponse[] = [];
+ const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
+ callbacks?.onEnd?.(allResponses.reduce((acc, curr) => acc + addHelpers(curr).text(), ""));
return addHelpers(aggregateResponses(allResponses));
}
allResponses.push(value);
+ callbacks?.onData?.(addHelpers(value).text());
}
+}catch(error){
+ callbacks?.onError?.(error);
+ throw error;
+}
}
async function* generateResponseSequence(
diff --git a/src/server/cache-manager.test.ts b/src/server/cache-manager.test.ts
index ba34ec67b..cf9947c08 100644
--- a/src/server/cache-manager.test.ts
+++ b/src/server/cache-manager.test.ts
@@ -16,8 +16,8 @@
*/
import { expect, use } from "chai";
import { GoogleAICacheManager } from "./cache-manager";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as request from "./request";
import { RpcTask } from "./constants";
diff --git a/src/server/file-manager.test.ts b/src/server/file-manager.test.ts
index a0531c9b0..483024027 100644
--- a/src/server/file-manager.test.ts
+++ b/src/server/file-manager.test.ts
@@ -16,8 +16,8 @@
*/
import { expect, use } from "chai";
import { GoogleAIFileManager, getUploadMetadata } from "./file-manager";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as request from "./request";
import { RpcTask } from "./constants";
diff --git a/src/server/request.test.ts b/src/server/request.test.ts
index 4a0e8299a..2a7b643ac 100644
--- a/src/server/request.test.ts
+++ b/src/server/request.test.ts
@@ -17,8 +17,8 @@
import { expect, use } from "chai";
import { match, restore, stub } from "sinon";
-import * as sinonChai from "sinon-chai";
-import * as chaiAsPromised from "chai-as-promised";
+import sinonChai from "sinon-chai";
+import chaiAsPromised from "chai-as-promised";
import { DEFAULT_API_VERSION, DEFAULT_BASE_URL } from "../requests/request";
import {
FilesRequestUrl,
diff --git a/tsconfig.json b/tsconfig.json
index 3daccfd53..088837e04 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -1,12 +1,13 @@
{
"compilerOptions": {
"noImplicitAny": true,
- "module": "es2020",
- "target": "es2020",
+ "module": "NodeNext",
+ "target": "ES2020",
"allowJs": true,
- "moduleResolution": "node",
+ "moduleResolution": "NodeNext",
"declaration": true,
"outDir": "dist",
+ "baseUrl": "./",
"allowSyntheticDefaultImports": true
}
}
diff --git a/types/requests.ts b/types/requests.ts
index 81285bc20..69d2012b3 100644
--- a/types/requests.ts
+++ b/types/requests.ts
@@ -25,6 +25,7 @@ import {
} from "./function-calling";
import { GoogleSearchRetrievalTool } from "./search-grounding";
+
/**
* Base parameters for a number of methods.
* @public
@@ -208,6 +209,15 @@ export interface RequestOptions {
*/
customHeaders?: Headers | Record;
}
+/**
+ * Stream callbacks for streaming responses.
+ * @public
+ */
+export interface StreamCallbacks {
+ onData?: (data: string) => void;
+ onEnd?: (data: string) => void;
+ onError?: (error: Error) => void;
+}
/**
* Params passed to atomic asynchronous operations.