Skip to content

Add support for elastic embedding sizes under 768 dimensions #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ for complete code.
npm install @google/generative-ai
```

1. Initialize the model
2. Initialize the model

```js
const { GoogleGenerativeAI } = require("@google/generative-ai");
Expand All @@ -44,7 +44,7 @@ const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" });
```

1. Run a prompt
3. Run a prompt

```js
const prompt = "Does this look store-bought or homemade?";
Expand All @@ -59,6 +59,24 @@ const result = await model.generateContent([prompt, image]);
console.log(result.response.text());
```

## Elastic Embedding Sizes

The SDK supports elastic embedding sizes for text embedding models. You can specify the dimension size when creating embeddings:

```js
const model = genAI.getGenerativeModel({ model: "text-embedding-004" });

// Get an embedding with 128 dimensions instead of the default 768
const result = await model.embedContent({
content: { role: "user", parts: [{ text: "Hello world!" }] },
dimensions: 128
});

console.log("Embedding size:", result.embedding.values.length); // 128
```

Supported dimension sizes are: 128, 256, 384, 512, and 768 (default).

## Try out a sample app

This repository contains sample Node and web apps demonstrating how the SDK can
Expand All @@ -69,17 +87,17 @@ access and utilize the Gemini model for various use cases.
1. Check out this repository. \
`git clone https://github.com/google/generative-ai-js`

1. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with
2. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with
the Google AI SDKs.

2. cd into the `samples` folder and run `npm install`.
3. cd into the `samples` folder and run `npm install`.

3. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`.
4. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`.

4. Open the sample file you're interested in. Example: `text_generation.js`.
5. Open the sample file you're interested in. Example: `text_generation.js`.
In the `runAll()` function, comment out any samples you don't want to run.

5. Run the sample file. Example: `node text_generation.js`.
6. Run the sample file. Example: `node text_generation.js`.

## Documentation

Expand Down
101 changes: 101 additions & 0 deletions samples/elastic_embeddings.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { GoogleGenerativeAI } from "@google/generative-ai";

async function embedContentWithDimensions() {

const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "text-embedding-004",
});


const result = await model.embedContent({
content: { role: "user", parts: [{ text: "Hello world!" }] },
dimensions: 128
});

console.log("Embedding size:", result.embedding.values.length);
console.log("First 5 dimensions:", result.embedding.values.slice(0, 5));
}

async function compareEmbeddingSizes() {
const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "text-embedding-004",
});

const text = "The quick brown fox jumps over the lazy dog";


const dimensions = [128, 256, 384, 512, 768];

console.log(`Comparing embedding sizes for text: "${text}"`);

for (const dim of dimensions) {
const result = await model.embedContent({
content: { role: "user", parts: [{ text }] },
dimensions: dim
});

console.log(`Dimensions: ${dim}, Actual size: ${result.embedding.values.length}`);
}
}

async function batchEmbedContentsWithDimensions() {
const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "text-embedding-004",
});

function textToRequest(text, dimensions) {
return {
content: { role: "user", parts: [{ text }] },
dimensions
};
}

const result = await model.batchEmbedContents({
requests: [
textToRequest("What is the meaning of life?", 128),
textToRequest("How much wood would a woodchuck chuck?", 256),
textToRequest("How does the brain work?", 384),
],
});

for (let i = 0; i < result.embeddings.length; i++) {
console.log(`Embedding ${i+1} size: ${result.embeddings[i].values.length}`);
}
}

async function runAll() {
try {
console.log("=== Embedding with dimensions ===");
await embedContentWithDimensions();

console.log("\n=== Comparing embedding sizes ===");
await compareEmbeddingSizes();

console.log("\n=== Batch embeddings with dimensions ===");
await batchEmbedContentsWithDimensions();
} catch (error) {
console.error("Error:", error);
}
}

runAll();
34 changes: 34 additions & 0 deletions src/requests/request-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { Content } from "../../types";
import {
formatCountTokensInput,
formatGenerateContentInput,
formatEmbedContentInput,
} from "./request-helpers";

use(sinonChai);
Expand Down Expand Up @@ -275,4 +276,37 @@ describe("request formatting methods", () => {
});
});
});
describe("formatEmbedContentInput", () => {
it("handles dimensions parameter", () => {
const result = formatEmbedContentInput({
content: { role: "user", parts: [{ text: "foo" }] },
dimensions: 128
});
expect(result).to.deep.equal({
content: { role: "user", parts: [{ text: "foo" }] },
dimensions: 128
});
});
it("validates dimensions with valid values", () => {
const validDimensions = [128, 256, 384, 512, 768];

for (const dim of validDimensions) {
const result = formatEmbedContentInput({
content: { role: "user", parts: [{ text: "foo" }] },
dimensions: dim
});
expect(result.dimensions).to.equal(dim);
}
});
it("throws error for invalid dimensions", () => {
const invalidDimensions = [100, 200, 300, 400, 600, 800];

for (const dim of invalidDimensions) {
expect(() => formatEmbedContentInput({
content: { role: "user", parts: [{ text: "foo" }] },
dimensions: dim
})).to.throw(/Invalid dimensions/);
}
});
});
});
31 changes: 27 additions & 4 deletions src/requests/request-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,35 @@ export function formatGenerateContentInput(
return formattedRequest;
}

/**
*
* @param params
* @returns
*/
export function formatEmbedContentInput(
params: EmbedContentRequest | string | Array<string | Part>,
): EmbedContentRequest {
if (typeof params === "string" || Array.isArray(params)) {
const content = formatNewContent(params);
return { content };
if (typeof params === "string") {
return {
content: formatNewContent(params),
};
} else if (Array.isArray(params)) {
return {
content: formatNewContent(params),
};
} else {

const result = { ...params };

if (result.dimensions !== undefined) {
const validDimensions = [128, 256, 384, 512, 768];
if (!validDimensions.includes(result.dimensions)) {
throw new GoogleGenerativeAIRequestInputError(
`Invalid dimensions value: ${result.dimensions}. Valid values are: 128, 256, 384, 512, and 768.`
);
}
}

return result;
}
return params;
}
4 changes: 3 additions & 1 deletion types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export interface GenerateContentRequest extends BaseParams {
}

/**
* Request sent to `generateContent` endpoint.
* Internal version of the request that includes a model name.
* @internal
*/
export interface _GenerateContentRequestInternal
Expand Down Expand Up @@ -170,6 +170,8 @@ export interface EmbedContentRequest {
content: Content;
taskType?: TaskType;
title?: string;

dimensions?: number;
}

/**
Expand Down