diff --git a/README.md b/README.md index 906701af6..13423bd3e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ for complete code. npm install @google/generative-ai ``` -1. Initialize the model +2. Initialize the model ```js const { GoogleGenerativeAI } = require("@google/generative-ai"); @@ -44,7 +44,7 @@ const genAI = new GoogleGenerativeAI(process.env.API_KEY); const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); ``` -1. Run a prompt +3. Run a prompt ```js const prompt = "Does this look store-bought or homemade?"; @@ -59,6 +59,24 @@ const result = await model.generateContent([prompt, image]); console.log(result.response.text()); ``` +## Elastic Embedding Sizes + +The SDK supports elastic embedding sizes for text embedding models. You can specify the dimension size when creating embeddings: + +```js +const model = genAI.getGenerativeModel({ model: "text-embedding-004" }); + +// Get an embedding with 128 dimensions instead of the default 768 +const result = await model.embedContent({ + content: { role: "user", parts: [{ text: "Hello world!" }] }, + dimensions: 128 +}); + +console.log("Embedding size:", result.embedding.values.length); // 128 +``` + +Supported dimension sizes are: 128, 256, 384, 512, and 768 (default). + ## Try out a sample app This repository contains sample Node and web apps demonstrating how the SDK can @@ -69,17 +87,17 @@ access and utilize the Gemini model for various use cases. 1. Check out this repository. \ `git clone https://github.com/google/generative-ai-js` -1. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with +2. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with the Google AI SDKs. -2. cd into the `samples` folder and run `npm install`. +3. cd into the `samples` folder and run `npm install`. -3. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`. +4. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`. -4. Open the sample file you're interested in. Example: `text_generation.js`. +5. Open the sample file you're interested in. Example: `text_generation.js`. In the `runAll()` function, comment out any samples you don't want to run. -5. Run the sample file. Example: `node text_generation.js`. +6. Run the sample file. Example: `node text_generation.js`. ## Documentation diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md index 5880a650d..977043e8f 100644 --- a/common/api-review/generative-ai.api.md +++ b/common/api-review/generative-ai.api.md @@ -201,6 +201,8 @@ export interface EmbedContentRequest { // (undocumented) content: Content; // (undocumented) + dimensions?: number; + // (undocumented) taskType?: TaskType; // (undocumented) title?: string; @@ -525,7 +527,7 @@ export class GenerativeModel { countTokens(request: CountTokensRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; embedContent(request: EmbedContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; generateContent(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; - generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions): Promise; + generateContentStream(request: GenerateContentRequest | string | Array, requestOptions?: SingleRequestOptions, streamCallbacks?: StreamCallbacks): Promise; // (undocumented) generationConfig: GenerationConfig; // (undocumented) @@ -841,6 +843,12 @@ export interface StartChatParams extends BaseParams { tools?: Tool[]; } +// @public +export interface StreamCallbacks { + onData?: (chunk: string) => void; + onDone?: (fullText: string) => void; +} + // @public export type StringSchema = SimpleStringSchema | EnumStringSchema; diff --git a/samples/elastic_embeddings.js b/samples/elastic_embeddings.js new file mode 100644 index 000000000..8ff05dc1b --- /dev/null +++ b/samples/elastic_embeddings.js @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GoogleGenerativeAI } from "@google/generative-ai"; + +async function embedContentWithDimensions() { + + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + + const result = await model.embedContent({ + content: { role: "user", parts: [{ text: "Hello world!" }] }, + dimensions: 128 + }); + + console.log("Embedding size:", result.embedding.values.length); + console.log("First 5 dimensions:", result.embedding.values.slice(0, 5)); +} + +async function compareEmbeddingSizes() { + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + const text = "The quick brown fox jumps over the lazy dog"; + + + const dimensions = [128, 256, 384, 512, 768]; + + console.log(`Comparing embedding sizes for text: "${text}"`); + + for (const dim of dimensions) { + const result = await model.embedContent({ + content: { role: "user", parts: [{ text }] }, + dimensions: dim + }); + + console.log(`Dimensions: ${dim}, Actual size: ${result.embedding.values.length}`); + } +} + +async function batchEmbedContentsWithDimensions() { + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + function textToRequest(text, dimensions) { + return { + content: { role: "user", parts: [{ text }] }, + dimensions + }; + } + + const result = await model.batchEmbedContents({ + requests: [ + textToRequest("What is the meaning of life?", 128), + textToRequest("How much wood would a woodchuck chuck?", 256), + textToRequest("How does the brain work?", 384), + ], + }); + + for (let i = 0; i < result.embeddings.length; i++) { + console.log(`Embedding ${i+1} size: ${result.embeddings[i].values.length}`); + } +} + +async function runAll() { + try { + console.log("=== Embedding with dimensions ==="); + await embedContentWithDimensions(); + + console.log("\n=== Comparing embedding sizes ==="); + await compareEmbeddingSizes(); + + console.log("\n=== Batch embeddings with dimensions ==="); + await batchEmbedContentsWithDimensions(); + } catch (error) { + console.error("Error:", error); + } +} + +runAll(); \ No newline at end of file diff --git a/samples/stream_callbacks.js b/samples/stream_callbacks.js new file mode 100644 index 000000000..ec1b46a16 --- /dev/null +++ b/samples/stream_callbacks.js @@ -0,0 +1,59 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +const { GoogleGenerativeAI } = require("@google/generative-ai"); + +// This sample demonstrates how to use streamCallbacks for receiving +// streaming responses without manually handling Node.js streams. + +// Access your API key as an environment variable +const genAI = new GoogleGenerativeAI(process.env.API_KEY); + +// For text-only input, use the gemini-pro model +async function runWithCallbacks() { + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); + + console.log("Generating response with callbacks..."); + + await model.generateContentStream("Tell me a joke", {}, { + onData: (chunk) => process.stdout.write(chunk), + onDone: (fullText) => console.log("\n\nFull response:\n", fullText), + }); +} + +// Alternative usage with only onDone callback +async function runWithOnlyDoneCallback() { + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); + + console.log("\nGenerating response with only onDone callback..."); + + await model.generateContentStream("Tell me another joke", {}, { + onDone: (fullText) => console.log("Full response:\n", fullText), + }); +} + +// Run the demos +async function main() { + try { + await runWithCallbacks(); + await runWithOnlyDoneCallback(); + } catch (error) { + console.error("Error:", error); + } +} + +main(); \ No newline at end of file diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts index 7cd3fe622..b2d18540b 100644 --- a/src/models/generative-model.ts +++ b/src/models/generative-model.ts @@ -40,6 +40,7 @@ import { StartChatParams, Tool, ToolConfig, + StreamCallbacks, } from "../../types"; import { ChatSession } from "../methods/chat-session"; import { countTokens } from "../methods/count-tokens"; @@ -128,17 +129,23 @@ export class GenerativeModel { * Fields set in the optional {@link SingleRequestOptions} parameter will * take precedence over the {@link RequestOptions} values provided to * {@link GoogleGenerativeAI.getGenerativeModel }. + * + * The optional {@link StreamCallbacks} parameter allows receiving text + * chunks via callbacks without manually handling Node.js streams. + * - onData: Called with each chunk of text as it arrives + * - onDone: Called with the full text when streaming is complete */ async generateContentStream( request: GenerateContentRequest | string | Array, requestOptions: SingleRequestOptions = {}, + streamCallbacks?: StreamCallbacks ): Promise { const formattedParams = formatGenerateContentInput(request); const generativeModelRequestOptions: SingleRequestOptions = { ...this._requestOptions, ...requestOptions, }; - return generateContentStream( + const result = await generateContentStream( this.apiKey, this.model, { @@ -152,6 +159,34 @@ export class GenerativeModel { }, generativeModelRequestOptions, ); + + // If streamCallbacks are provided, set up the handlers + if (streamCallbacks?.onData || streamCallbacks?.onDone) { + // Handle onData callback for each chunk + if (streamCallbacks.onData) { + const originalStream = result.stream; + result.stream = (async function* () { + let fullText = ''; + for await (const chunk of originalStream) { + const text = chunk.text(); + fullText += text; + streamCallbacks.onData?.(text); + yield chunk; + } + // Call onDone with the full text when complete + if (streamCallbacks.onDone) { + streamCallbacks.onDone(fullText); + } + })(); + } else if (streamCallbacks.onDone) { + // If only onDone is provided, collect the full text + result.response.then(response => { + streamCallbacks.onDone?.(response.text()); + }); + } + } + + return result; } /** diff --git a/src/models/stream-callbacks.test.ts b/src/models/stream-callbacks.test.ts new file mode 100644 index 000000000..972e653ea --- /dev/null +++ b/src/models/stream-callbacks.test.ts @@ -0,0 +1,159 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from "chai"; +import { describe, it, beforeEach, afterEach } from "mocha"; +import { SinonStub, stub, useFakeTimers, SinonFakeTimers } from "sinon"; +import * as generateContentModule from "../methods/generate-content"; +import { GenerativeModel } from "./generative-model"; +import { StreamCallbacks } from "../../types"; + +describe("GenerativeModel streamCallbacks", () => { + let generateContentStreamStub: SinonStub; + let mockStream: AsyncGenerator; + let mockResponse: Promise; + let clock: SinonFakeTimers; + + beforeEach(() => { + clock = useFakeTimers(); + + // Mock the response and stream + const mockText = "This is a test response"; + const textChunks = ["This ", "is ", "a ", "test ", "response"]; + + mockStream = (async function* () { + for (const chunk of textChunks) { + yield { + text: () => chunk, + candidates: [{ content: { parts: [{ text: chunk }] } }] + }; + } + })(); + + mockResponse = Promise.resolve({ + text: () => mockText, + candidates: [{ content: { parts: [{ text: mockText }] } }] + }); + + // Stub the generateContentStream method + generateContentStreamStub = stub( + generateContentModule, + "generateContentStream" + ).resolves({ + stream: mockStream, + response: mockResponse + }); + }); + + afterEach(() => { + generateContentStreamStub.restore(); + clock.restore(); + }); + + it("should call onData for each chunk", async () => { + const model = new GenerativeModel({ + model: "gemini-pro", + apiKey: "test-api-key" + }); + + const chunks: string[] = []; + const streamCallbacks: StreamCallbacks = { + onData: (chunk) => chunks.push(chunk) + }; + + const result = await model.generateContentStream( + "Test prompt", + {}, + streamCallbacks + ); + + // Consume the stream + for await (const _ of result.stream) { + // Do nothing, just consume + } + + expect(chunks).to.deep.equal(["This ", "is ", "a ", "test ", "response"]); + }); + + it("should call onDone with full text when streaming completes", async () => { + const model = new GenerativeModel({ + model: "gemini-pro", + apiKey: "test-api-key" + }); + + let doneText = ""; + const streamCallbacks: StreamCallbacks = { + onData: () => {}, + onDone: (fullText) => { doneText = fullText; } + }; + + const result = await model.generateContentStream( + "Test prompt", + {}, + streamCallbacks + ); + + // Consume the stream + for await (const _ of result.stream) { + // Do nothing, just consume + } + + expect(doneText).to.equal("This is a test response"); + }); + + it("should call only onDone when onData is not provided", async () => { + const model = new GenerativeModel({ + model: "gemini-pro", + apiKey: "test-api-key" + }); + + let doneText = ""; + const streamCallbacks: StreamCallbacks = { + onDone: (fullText) => { doneText = fullText; } + }; + + await model.generateContentStream( + "Test prompt", + {}, + streamCallbacks + ); + + // Resolve the response promise + await mockResponse; + + expect(doneText).to.equal("This is a test response"); + }); + + it("should not modify the result when streamCallbacks are not provided", async () => { + const model = new GenerativeModel({ + model: "gemini-pro", + apiKey: "test-api-key" + }); + + const result = await model.generateContentStream("Test prompt"); + + // Verify that the result has the expected structure + expect(result).to.have.property("stream"); + expect(result).to.have.property("response"); + + // Verify that the generateContentStream was called with expected parameters + expect(generateContentStreamStub.callCount).to.equal(1); + expect(generateContentStreamStub.firstCall.args[2]).to.deep.include({ + contents: [{ role: "user", parts: [{ text: "Test prompt" }] }] + }); + }); +}); \ No newline at end of file diff --git a/src/requests/request-helpers.test.ts b/src/requests/request-helpers.test.ts index f3d46cd05..2851eafc8 100644 --- a/src/requests/request-helpers.test.ts +++ b/src/requests/request-helpers.test.ts @@ -22,6 +22,7 @@ import { Content } from "../../types"; import { formatCountTokensInput, formatGenerateContentInput, + formatEmbedContentInput, } from "./request-helpers"; use(sinonChai); @@ -275,4 +276,37 @@ describe("request formatting methods", () => { }); }); }); + describe("formatEmbedContentInput", () => { + it("handles dimensions parameter", () => { + const result = formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: 128 + }); + expect(result).to.deep.equal({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: 128 + }); + }); + it("validates dimensions with valid values", () => { + const validDimensions = [128, 256, 384, 512, 768]; + + for (const dim of validDimensions) { + const result = formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: dim + }); + expect(result.dimensions).to.equal(dim); + } + }); + it("throws error for invalid dimensions", () => { + const invalidDimensions = [100, 200, 300, 400, 600, 800]; + + for (const dim of invalidDimensions) { + expect(() => formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: dim + })).to.throw(/Invalid dimensions/); + } + }); + }); }); diff --git a/src/requests/request-helpers.ts b/src/requests/request-helpers.ts index 58232e057..f42607897 100644 --- a/src/requests/request-helpers.ts +++ b/src/requests/request-helpers.ts @@ -168,12 +168,35 @@ export function formatGenerateContentInput( return formattedRequest; } +/** + * + * @param params + * @returns + */ export function formatEmbedContentInput( params: EmbedContentRequest | string | Array, ): EmbedContentRequest { - if (typeof params === "string" || Array.isArray(params)) { - const content = formatNewContent(params); - return { content }; + if (typeof params === "string") { + return { + content: formatNewContent(params), + }; + } else if (Array.isArray(params)) { + return { + content: formatNewContent(params), + }; + } else { + + const result = { ...params }; + + if (result.dimensions !== undefined) { + const validDimensions = [128, 256, 384, 512, 768]; + if (!validDimensions.includes(result.dimensions)) { + throw new GoogleGenerativeAIRequestInputError( + `Invalid dimensions value: ${result.dimensions}. Valid values are: 128, 256, 384, 512, and 768.` + ); + } + } + + return result; } - return params; } diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..15d6f025f 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -62,7 +62,7 @@ export interface GenerateContentRequest extends BaseParams { } /** - * Request sent to `generateContent` endpoint. + * Internal version of the request that includes a model name. * @internal */ export interface _GenerateContentRequestInternal @@ -169,7 +169,8 @@ export interface _CountTokensRequestInternal { export interface EmbedContentRequest { content: Content; taskType?: TaskType; - title?: string; + title?: string; + dimensions?: number; } /** @@ -225,6 +226,17 @@ export interface SingleRequestOptions extends RequestOptions { signal?: AbortSignal; } +/** + * Callbacks for handling streaming responses without managing Node.js streams directly. + * @public + */ +export interface StreamCallbacks { + /** Called for each chunk of text as it arrives */ + onData?: (chunk: string) => void; + /** Called with the full text when streaming is complete */ + onDone?: (fullText: string) => void; +} + /** * Defines a tool that model can call to access external knowledge. * @public