From 5a8bc4980d34500021426abc1719172b58171880 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 8 Mar 2025 17:40:17 -0500 Subject: [PATCH] Add support for elastic embedding sizes under 768 dimensions --- README.md | 32 +++++++-- samples/elastic_embeddings.js | 101 +++++++++++++++++++++++++++ src/requests/request-helpers.test.ts | 34 +++++++++ src/requests/request-helpers.ts | 31 ++++++-- types/requests.ts | 4 +- 5 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 samples/elastic_embeddings.js diff --git a/README.md b/README.md index 906701af6..13423bd3e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ for complete code. npm install @google/generative-ai ``` -1. Initialize the model +2. Initialize the model ```js const { GoogleGenerativeAI } = require("@google/generative-ai"); @@ -44,7 +44,7 @@ const genAI = new GoogleGenerativeAI(process.env.API_KEY); const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); ``` -1. Run a prompt +3. Run a prompt ```js const prompt = "Does this look store-bought or homemade?"; @@ -59,6 +59,24 @@ const result = await model.generateContent([prompt, image]); console.log(result.response.text()); ``` +## Elastic Embedding Sizes + +The SDK supports elastic embedding sizes for text embedding models. You can specify the dimension size when creating embeddings: + +```js +const model = genAI.getGenerativeModel({ model: "text-embedding-004" }); + +// Get an embedding with 128 dimensions instead of the default 768 +const result = await model.embedContent({ + content: { role: "user", parts: [{ text: "Hello world!" }] }, + dimensions: 128 +}); + +console.log("Embedding size:", result.embedding.values.length); // 128 +``` + +Supported dimension sizes are: 128, 256, 384, 512, and 768 (default). + ## Try out a sample app This repository contains sample Node and web apps demonstrating how the SDK can @@ -69,17 +87,17 @@ access and utilize the Gemini model for various use cases. 1. Check out this repository. \ `git clone https://github.com/google/generative-ai-js` -1. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with +2. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with the Google AI SDKs. -2. cd into the `samples` folder and run `npm install`. +3. cd into the `samples` folder and run `npm install`. -3. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`. +4. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`. -4. Open the sample file you're interested in. Example: `text_generation.js`. +5. Open the sample file you're interested in. Example: `text_generation.js`. In the `runAll()` function, comment out any samples you don't want to run. -5. Run the sample file. Example: `node text_generation.js`. +6. Run the sample file. Example: `node text_generation.js`. ## Documentation diff --git a/samples/elastic_embeddings.js b/samples/elastic_embeddings.js new file mode 100644 index 000000000..8ff05dc1b --- /dev/null +++ b/samples/elastic_embeddings.js @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GoogleGenerativeAI } from "@google/generative-ai"; + +async function embedContentWithDimensions() { + + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + + const result = await model.embedContent({ + content: { role: "user", parts: [{ text: "Hello world!" }] }, + dimensions: 128 + }); + + console.log("Embedding size:", result.embedding.values.length); + console.log("First 5 dimensions:", result.embedding.values.slice(0, 5)); +} + +async function compareEmbeddingSizes() { + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + const text = "The quick brown fox jumps over the lazy dog"; + + + const dimensions = [128, 256, 384, 512, 768]; + + console.log(`Comparing embedding sizes for text: "${text}"`); + + for (const dim of dimensions) { + const result = await model.embedContent({ + content: { role: "user", parts: [{ text }] }, + dimensions: dim + }); + + console.log(`Dimensions: ${dim}, Actual size: ${result.embedding.values.length}`); + } +} + +async function batchEmbedContentsWithDimensions() { + const genAI = new GoogleGenerativeAI(process.env.API_KEY); + const model = genAI.getGenerativeModel({ + model: "text-embedding-004", + }); + + function textToRequest(text, dimensions) { + return { + content: { role: "user", parts: [{ text }] }, + dimensions + }; + } + + const result = await model.batchEmbedContents({ + requests: [ + textToRequest("What is the meaning of life?", 128), + textToRequest("How much wood would a woodchuck chuck?", 256), + textToRequest("How does the brain work?", 384), + ], + }); + + for (let i = 0; i < result.embeddings.length; i++) { + console.log(`Embedding ${i+1} size: ${result.embeddings[i].values.length}`); + } +} + +async function runAll() { + try { + console.log("=== Embedding with dimensions ==="); + await embedContentWithDimensions(); + + console.log("\n=== Comparing embedding sizes ==="); + await compareEmbeddingSizes(); + + console.log("\n=== Batch embeddings with dimensions ==="); + await batchEmbedContentsWithDimensions(); + } catch (error) { + console.error("Error:", error); + } +} + +runAll(); \ No newline at end of file diff --git a/src/requests/request-helpers.test.ts b/src/requests/request-helpers.test.ts index f3d46cd05..2851eafc8 100644 --- a/src/requests/request-helpers.test.ts +++ b/src/requests/request-helpers.test.ts @@ -22,6 +22,7 @@ import { Content } from "../../types"; import { formatCountTokensInput, formatGenerateContentInput, + formatEmbedContentInput, } from "./request-helpers"; use(sinonChai); @@ -275,4 +276,37 @@ describe("request formatting methods", () => { }); }); }); + describe("formatEmbedContentInput", () => { + it("handles dimensions parameter", () => { + const result = formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: 128 + }); + expect(result).to.deep.equal({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: 128 + }); + }); + it("validates dimensions with valid values", () => { + const validDimensions = [128, 256, 384, 512, 768]; + + for (const dim of validDimensions) { + const result = formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: dim + }); + expect(result.dimensions).to.equal(dim); + } + }); + it("throws error for invalid dimensions", () => { + const invalidDimensions = [100, 200, 300, 400, 600, 800]; + + for (const dim of invalidDimensions) { + expect(() => formatEmbedContentInput({ + content: { role: "user", parts: [{ text: "foo" }] }, + dimensions: dim + })).to.throw(/Invalid dimensions/); + } + }); + }); }); diff --git a/src/requests/request-helpers.ts b/src/requests/request-helpers.ts index 58232e057..f42607897 100644 --- a/src/requests/request-helpers.ts +++ b/src/requests/request-helpers.ts @@ -168,12 +168,35 @@ export function formatGenerateContentInput( return formattedRequest; } +/** + * + * @param params + * @returns + */ export function formatEmbedContentInput( params: EmbedContentRequest | string | Array, ): EmbedContentRequest { - if (typeof params === "string" || Array.isArray(params)) { - const content = formatNewContent(params); - return { content }; + if (typeof params === "string") { + return { + content: formatNewContent(params), + }; + } else if (Array.isArray(params)) { + return { + content: formatNewContent(params), + }; + } else { + + const result = { ...params }; + + if (result.dimensions !== undefined) { + const validDimensions = [128, 256, 384, 512, 768]; + if (!validDimensions.includes(result.dimensions)) { + throw new GoogleGenerativeAIRequestInputError( + `Invalid dimensions value: ${result.dimensions}. Valid values are: 128, 256, 384, 512, and 768.` + ); + } + } + + return result; } - return params; } diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..639dda5b6 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -62,7 +62,7 @@ export interface GenerateContentRequest extends BaseParams { } /** - * Request sent to `generateContent` endpoint. + * Internal version of the request that includes a model name. * @internal */ export interface _GenerateContentRequestInternal @@ -170,6 +170,8 @@ export interface EmbedContentRequest { content: Content; taskType?: TaskType; title?: string; + + dimensions?: number; } /**