Skip to content

Added StreamCallbacks parameter to generateContentStream() (#322) #446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

```ts

/// <reference types="node" />

// Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
Expand Down Expand Up @@ -31,8 +33,6 @@ export interface BooleanSchema extends BaseSchema {
type: typeof SchemaType.BOOLEAN;
}

/// <reference types="node" />

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down
12 changes: 11 additions & 1 deletion common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ export class GenerativeModel {
countTokens(request: CountTokensRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
embedContent(request: EmbedContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<EmbedContentResponse>;
generateContent(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions, callbacks?: StreamCallbacks): Promise<GenerateContentStreamResult>;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
Expand Down Expand Up @@ -841,6 +841,16 @@ export interface StartChatParams extends BaseParams {
tools?: Tool[];
}

// @public
export interface StreamCallbacks {
// (undocumented)
onData?: (data: string) => void;
// (undocumented)
onEnd?: (data: string) => void;
// (undocumented)
onError?: (error: Error) => void;
}

// @public
export type StringSchema = SimpleStringSchema | EnumStringSchema;

Expand Down
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
"build": "rollup -c && npm run api-report",
"test": "npm run lint && npm run test:node:unit",
"test:web:integration": "npm run build && npx web-test-runner",
"test:node:unit": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"src/**/*.test.ts\"",
"test:node:unit": "mocha --require ts-node/register --require tsconfig-paths/register \"src/**/*.test.ts\"",
"test:node:integration": "npm run build && TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"test-integration/node/**/*.test.ts\"",
"lint": "eslint -c .eslintrc.js '**/*.ts' --ignore-path './.gitignore'",
"lint": "eslint -c .eslintrc.js **/*.ts --ignore-path .gitignore",
"api-report": "api-extractor run -c api-extractor.json --local --verbose && api-extractor run -c api-extractor.server.json --local --verbose",
"docs": "npm run build && npx api-documenter markdown -i ./temp/main -o ./docs/reference/main && npx api-documenter markdown -i ./temp/server -o ./docs/reference/server",
"format": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"nodenext\"}' npx ts-node scripts/run-format.ts",
Expand All @@ -61,8 +61,9 @@
"@web/dev-server-esbuild": "^1.0.1",
"@web/test-runner": "^0.18.0",
"chai": "^4.3.10",
"chai-as-promised": "^7.1.1",
"chai-as-promised": "^7.1.2",
"chai-deep-equal-ignore-undefined": "^1.1.1",
"cross-env": "^7.0.3",
"eslint": "^8.52.0",
"eslint-plugin-import": "^2.29.0",
"eslint-plugin-unused-imports": "^3.0.0",
Expand Down
4 changes: 2 additions & 2 deletions src/methods/chat-session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import { expect, use } from "chai";
import { match, restore, stub, useFakeTimers } from "sinon";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import * as generateContentMethods from "./generate-content";
import { GenerateContentStreamResult } from "../../types";
import { ChatSession } from "./chat-session";
Expand Down
4 changes: 2 additions & 2 deletions src/methods/generate-content.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import { assert, expect, use } from "chai";
import { match, restore, stub } from "sinon";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import { getMockResponse } from "../../test-utils/mock-response";
import * as request from "../requests/request";
import { generateContent } from "./generate-content";
Expand Down
2 changes: 1 addition & 1 deletion src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import { expect, use } from "chai";
import { GenerativeModel } from "./generative-model";
import * as sinonChai from "sinon-chai";
import sinonChai from "sinon-chai";
import {
CountTokensRequest,
FunctionCallingMode,
Expand Down
5 changes: 5 additions & 0 deletions src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import {
formatGenerateContentInput,
formatSystemInstruction,
} from "../requests/request-helpers";
import { StreamCallbacks } from "../../types/requests";


/**
* Class for generative model APIs.
Expand Down Expand Up @@ -132,6 +134,8 @@ export class GenerativeModel {
async generateContentStream(
request: GenerateContentRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
// eslint-disable-next-line @typescript-eslint/no-unused-vars
callbacks?: StreamCallbacks
): Promise<GenerateContentStreamResult> {
const formattedParams = formatGenerateContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
Expand All @@ -151,6 +155,7 @@ export class GenerativeModel {
...formattedParams,
},
generativeModelRequestOptions,

);
}

Expand Down
2 changes: 1 addition & 1 deletion src/requests/request-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import { expect, use } from "chai";
import * as sinonChai from "sinon-chai";
import sinonChai from "sinon-chai";
import chaiDeepEqualIgnoreUndefined from "chai-deep-equal-ignore-undefined";
import { Content } from "../../types";
import {
Expand Down
4 changes: 2 additions & 2 deletions src/requests/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import { expect, use } from "chai";
import { match, restore, stub } from "sinon";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import {
DEFAULT_API_VERSION,
DEFAULT_BASE_URL,
Expand Down
2 changes: 1 addition & 1 deletion src/requests/response-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import { addHelpers, formatBlockErrorMessage } from "./response-helpers";
import { expect, use } from "chai";
import { restore } from "sinon";
import * as sinonChai from "sinon-chai";
import sinonChai from "sinon-chai";
import {
BlockReason,
Content,
Expand Down
70 changes: 68 additions & 2 deletions src/requests/stream-reader.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
} from "./stream-reader";
import { expect, use } from "chai";
import { restore } from "sinon";
import * as sinonChai from "sinon-chai";
import sinonChai from "sinon-chai";
import {
getChunkedStream,
getErrorStream,
Expand Down Expand Up @@ -340,7 +340,73 @@ describe("processStream", () => {
}
expect(foundCitationMetadata).to.be.true;
});
});


describe("callbacks", () => {
it("chunk callback were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-citations.txt",
);
let dataCount = 0;
processStream(fakeResponse as Response, {
onData: (data: string) => {
dataCount++;
expect(data).to.not.be.empty;
},
onEnd: () => {
expect(dataCount).to.be.greaterThan(0);
done();
},
});
});

it("error callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-failure-prompt-blocked-safety.txt",
);
processStream(fakeResponse as Response, {
onError: (error: Error) => {
expect(error).to.be.instanceOf(GoogleGenerativeAIError);
done();
},
onEnd:() => done(),
});
});

it("end callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-basic-reply-short.txt",
);
processStream(fakeResponse as Response, {
onEnd: (data) => {
expect(data).to.include("Cheyenne");
done();
},
});
});

it("all callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-basic-reply-long.txt", // Changed to long response
);
let dataCount = 0;
processStream(fakeResponse as Response, {
onData: (data: string) => {
dataCount++;
expect(data).to.not.be.empty;
},
onEnd: (data) => {
expect(data).to.include("**Cats:**");
expect(data).to.include("to their owners.");
expect(dataCount).to.be.greaterThan(0);
done();
},
});
});
});
});



describe("aggregateResponses", () => {
it("handles no candidates, and promptFeedback", () => {
Expand Down
18 changes: 14 additions & 4 deletions src/requests/stream-reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
GenerateContentResponse,
GenerateContentStreamResult,
Part,
StreamCallbacks,
} from "../../types";
import {
GoogleGenerativeAIAbortError,
Expand All @@ -38,7 +39,7 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
*
* @param response - Response from a fetch call
*/
export function processStream(response: Response): GenerateContentStreamResult {
export function processStream(response: Response, callbacks?: StreamCallbacks): GenerateContentStreamResult {
const inputStream = response.body!.pipeThrough(
new TextDecoderStream("utf8", { fatal: true }),
);
Expand All @@ -47,22 +48,31 @@ export function processStream(response: Response): GenerateContentStreamResult {
const [stream1, stream2] = responseStream.tee();
return {
stream: generateResponseSequence(stream1),
response: getResponsePromise(stream2),
response: getResponsePromise(stream2,callbacks),
};
}

async function getResponsePromise(
stream: ReadableStream<GenerateContentResponse>,
callbacks?: StreamCallbacks,
): Promise<EnhancedGenerateContentResponse> {
const allResponses: GenerateContentResponse[] = [];
const reader = stream.getReader();

try {
const allResponses: GenerateContentResponse[] = [];
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
callbacks?.onEnd?.(allResponses.reduce((acc, curr) => acc + addHelpers(curr).text(), ""));
return addHelpers(aggregateResponses(allResponses));
}
allResponses.push(value);
callbacks?.onData?.(addHelpers(value).text());
}
}catch(error){
callbacks?.onError?.(error);
throw error;
}
}

async function* generateResponseSequence(
Expand Down
4 changes: 2 additions & 2 deletions src/server/cache-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
import { expect, use } from "chai";
import { GoogleAICacheManager } from "./cache-manager";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as request from "./request";
import { RpcTask } from "./constants";
Expand Down
4 changes: 2 additions & 2 deletions src/server/file-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
import { expect, use } from "chai";
import { GoogleAIFileManager, getUploadMetadata } from "./file-manager";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as request from "./request";
import { RpcTask } from "./constants";
Expand Down
4 changes: 2 additions & 2 deletions src/server/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import { expect, use } from "chai";
import { match, restore, stub } from "sinon";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import sinonChai from "sinon-chai";
import chaiAsPromised from "chai-as-promised";
import { DEFAULT_API_VERSION, DEFAULT_BASE_URL } from "../requests/request";
import {
FilesRequestUrl,
Expand Down
7 changes: 4 additions & 3 deletions tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
{
"compilerOptions": {
"noImplicitAny": true,
"module": "es2020",
"target": "es2020",
"module": "NodeNext",
"target": "ES2020",
"allowJs": true,
"moduleResolution": "node",
"moduleResolution": "NodeNext",
"declaration": true,
"outDir": "dist",
"baseUrl": "./",
"allowSyntheticDefaultImports": true
}
}
10 changes: 10 additions & 0 deletions types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
} from "./function-calling";
import { GoogleSearchRetrievalTool } from "./search-grounding";


/**
* Base parameters for a number of methods.
* @public
Expand Down Expand Up @@ -208,6 +209,15 @@ export interface RequestOptions {
*/
customHeaders?: Headers | Record<string, string>;
}
/**
* Stream callbacks for streaming responses.
* @public
*/
export interface StreamCallbacks {
onData?: (data: string) => void;
onEnd?: (data: string) => void;
onError?: (error: Error) => void;
}

/**
* Params passed to atomic asynchronous operations.
Expand Down