Skip to content

Add streamCallbacks as an Optional Argument and Implement Stream Handling Logic : Fixes #322 #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

```ts


// Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
Expand Down
12 changes: 11 additions & 1 deletion common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ export class GenerativeModel {
countTokens(request: CountTokensRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
embedContent(request: EmbedContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<EmbedContentResponse>;
generateContent(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions, callbacks?: StreamCallbacks): Promise<GenerateContentStreamResult>;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
Expand Down Expand Up @@ -841,6 +841,16 @@ export interface StartChatParams extends BaseParams {
tools?: Tool[];
}

// @public
export interface StreamCallbacks {
// (undocumented)
onData?: (data: string) => void;
// (undocumented)
onEnd?: (data: string) => void;
// (undocumented)
onError?: (error: Error) => void;
}

// @public
export type StringSchema = SimpleStringSchema | EnumStringSchema;

Expand Down
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
"build": "rollup -c && npm run api-report",
"test": "npm run lint && npm run test:node:unit",
"test:web:integration": "npm run build && npx web-test-runner",
"test:node:unit": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"src/**/*.test.ts\"",
"test:node:unit": "cross-env TS_NODE_COMPILER_OPTIONS=\"{\\\"module\\\":\\\"commonjs\\\"}\" mocha \"src/**/*.test.ts\"",
"test:node:integration": "npm run build && TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' mocha \"test-integration/node/**/*.test.ts\"",
"lint": "eslint -c .eslintrc.js '**/*.ts' --ignore-path './.gitignore'",
"lint": "eslint -c .eslintrc.js \"**/*.ts\" --ignore-path .gitignore",
"api-report": "api-extractor run -c api-extractor.json --local --verbose && api-extractor run -c api-extractor.server.json --local --verbose",
"docs": "npm run build && npx api-documenter markdown -i ./temp/main -o ./docs/reference/main && npx api-documenter markdown -i ./temp/server -o ./docs/reference/server",
"format": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"nodenext\"}' npx ts-node scripts/run-format.ts",
Expand All @@ -63,6 +63,7 @@
"chai": "^4.3.10",
"chai-as-promised": "^7.1.1",
"chai-deep-equal-ignore-undefined": "^1.1.1",
"cross-env": "^7.0.3",
"eslint": "^8.52.0",
"eslint-plugin-import": "^2.29.0",
"eslint-plugin-unused-imports": "^3.0.0",
Expand Down
4 changes: 3 additions & 1 deletion src/methods/generate-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
GenerateContentResult,
GenerateContentStreamResult,
SingleRequestOptions,
StreamCallbacks,
} from "../../types";
import { Task, makeModelRequest } from "../requests/request";
import { addHelpers } from "../requests/response-helpers";
Expand All @@ -31,6 +32,7 @@ export async function generateContentStream(
model: string,
params: GenerateContentRequest,
requestOptions: SingleRequestOptions,
callbacks?: StreamCallbacks,
): Promise<GenerateContentStreamResult> {
const response = await makeModelRequest(
model,
Expand All @@ -40,7 +42,7 @@ export async function generateContentStream(
JSON.stringify(params),
requestOptions,
);
return processStream(response);
return processStream(response, callbacks);
}

export async function generateContent(
Expand Down
3 changes: 3 additions & 0 deletions src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import {
SafetySetting,
SingleRequestOptions,
StartChatParams,
StreamCallbacks,
Tool,
ToolConfig,
} from "../../types";
Expand Down Expand Up @@ -132,6 +133,7 @@ export class GenerativeModel {
async generateContentStream(
request: GenerateContentRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
callbacks?: StreamCallbacks,
): Promise<GenerateContentStreamResult> {
const formattedParams = formatGenerateContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
Expand All @@ -151,6 +153,7 @@ export class GenerativeModel {
...formattedParams,
},
generativeModelRequestOptions,
callbacks,
);
}

Expand Down
59 changes: 59 additions & 0 deletions src/requests/stream-reader.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,65 @@ describe("processStream", () => {
}
expect(foundCitationMetadata).to.be.true;
});


describe("callbacks", () => {
it("chunk callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-citations.txt",
);
processStream(fakeResponse as Response, {
onData: (data: string) => {
expect(data).to.not.be.empty;
},
onEnd:()=> done(),
});
});

it("error callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-failure-prompt-blocked-safety.txt",
);
processStream(fakeResponse as Response, {
onError: (error: Error) => {
expect(error).to.be.instanceOf(GoogleGenerativeAIError);
done();
},
onEnd:()=> done(),
});
});


it("end callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-basic-reply-short.txt",
);
processStream(fakeResponse as Response, {
onEnd:(data)=> {
expect(data).to.include("Cheyenne");
done();
}
});
});

it("all callbacks were called", (done) => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-basic-reply-long.txt",
);
processStream(fakeResponse as Response, {
onEnd:(data)=> {
expect(data).to.include("**Cats:**");
expect(data).to.include("to their owners.");
done();
},
onData:(data:string)=> {
expect(data).to.not.be.empty;
},

});
});

});
});

describe("aggregateResponses", () => {
Expand Down
29 changes: 19 additions & 10 deletions src/requests/stream-reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
GenerateContentResponse,
GenerateContentStreamResult,
Part,
StreamCallbacks,
} from "../../types";
import {
GoogleGenerativeAIAbortError,
Expand All @@ -38,7 +39,7 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
*
* @param response - Response from a fetch call
*/
export function processStream(response: Response): GenerateContentStreamResult {
export function processStream(response: Response, callbacks?: StreamCallbacks): GenerateContentStreamResult {
const inputStream = response.body!.pipeThrough(
new TextDecoderStream("utf8", { fatal: true }),
);
Expand All @@ -47,21 +48,29 @@ export function processStream(response: Response): GenerateContentStreamResult {
const [stream1, stream2] = responseStream.tee();
return {
stream: generateResponseSequence(stream1),
response: getResponsePromise(stream2),
response: getResponsePromise(stream2, callbacks),
};
}

async function getResponsePromise(
stream: ReadableStream<GenerateContentResponse>,
callbacks?: StreamCallbacks,
): Promise<EnhancedGenerateContentResponse> {
const allResponses: GenerateContentResponse[] = [];
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
return addHelpers(aggregateResponses(allResponses));
}
allResponses.push(value);
try {
const allResponses: GenerateContentResponse[] = [];
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
callbacks?.onEnd((allResponses.reduce((acc, curr) => acc + addHelpers(curr).text(), "")));
return addHelpers(aggregateResponses(allResponses));
}
allResponses.push(value);
callbacks?.onData?.(addHelpers(value).text());
}
} catch (error) {
callbacks?.onError?.(error);
throw error;
}
}

Expand Down
10 changes: 10 additions & 0 deletions types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ export interface RequestOptions {
customHeaders?: Headers | Record<string, string>;
}

/**
* Callbacks for streaming responses.
* @public
*/
export interface StreamCallbacks {
onData?: (data: string) => void;
onEnd?: (data: string) => void;
onError?: (error: Error) => void;
}

/**
* Params passed to atomic asynchronous operations.
* @public
Expand Down