Skip to content

Support for Fine-tuning APIs #448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/gen-ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import { ModelParams } from "../types";
import { GenerativeModel, GoogleGenerativeAI } from "./gen-ai";
import { expect } from "chai";
import * as sinon from "sinon";
import * as fineTuningMethods from "./methods/fine-tuning";

const fakeContents = [{ role: "user", parts: [{ text: "hello" }] }];

Expand Down Expand Up @@ -119,3 +121,63 @@ describe("GoogleGenerativeAI", () => {
);
});
});

// ------------------ Added Fine-Tuning Test Cases ------------------

describe("GoogleGenerativeAI Fine-Tuning API Methods", () => {
let genAI: GoogleGenerativeAI;
let sandbox: sinon.SinonSandbox;

beforeEach(() => {
sandbox = sinon.createSandbox();
genAI = new GoogleGenerativeAI("apikey");
});

afterEach(() => {
sandbox.restore();
});

it("listTunedModels returns a list of tuned models", async () => {
const fakeResponse = { tunedModels: [{ name: "model-1" }, { name: "model-2" }] };
const listStub = sandbox
.stub(fineTuningMethods, "listTunedModels")
.resolves(fakeResponse);
const response = await genAI.listTunedModels(2);
expect(response).to.deep.equal(fakeResponse);
expect(listStub.calledOnceWith("apikey", 2)).to.be.true;
});

it("createTunedModel returns a tuned model creation response", async () => {
const displayName = "Test Model";
const trainingData = [{ input: "example", output: "response" }];
const fakeResponse = { name: "tuned-model-1" };
const createStub = sandbox
.stub(fineTuningMethods, "createTunedModel")
.resolves(fakeResponse);
const response = await genAI.createTunedModel(displayName, trainingData);
expect(response).to.deep.equal(fakeResponse);
expect(createStub.calledOnceWith("apikey", displayName, trainingData)).to.be.true;
});

it("checkTuningStatus returns the tuning status", async () => {
const operationName = "operation-123";
const fakeResponse = { metadata: { completedPercent: 50 } };
const checkStub = sandbox
.stub(fineTuningMethods, "checkTuningStatus")
.resolves(fakeResponse);
const response = await genAI.checkTuningStatus(operationName);
expect(response).to.deep.equal(fakeResponse);
expect(checkStub.calledOnceWith("apikey", operationName)).to.be.true;
});

it("deleteTunedModel returns the delete tuned model response", async () => {
const modelName = "tuned-model-1";
const fakeResponse = { success: true };
const deleteStub = sandbox
.stub(fineTuningMethods, "deleteTunedModel")
.resolves(fakeResponse);
const response = await genAI.deleteTunedModel(modelName);
expect(response).to.deep.equal(fakeResponse);
expect(deleteStub.calledOnceWith("apikey", modelName)).to.be.true;
});
});
62 changes: 59 additions & 3 deletions src/gen-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ import {
import { CachedContent, ModelParams, RequestOptions } from "../types";
import { GenerativeModel } from "./models/generative-model";

import {
CheckTuningStatusResponse,
CreateTunedModelResponse,
DeleteTunedModelResponse,
ListTunedModelsResponse,
checkTuningStatus,
createTunedModel,
deleteTunedModel,
listTunedModels,
} from "./methods/fine-tuning";

export { ChatSession } from "./methods/chat-session";
export { GenerativeModel };

Expand All @@ -30,7 +41,7 @@ export { GenerativeModel };
* @public
*/
export class GoogleGenerativeAI {
constructor(public apiKey: string) {}
constructor(public apiKey: string) { }

/**
* Gets a {@link GenerativeModel} instance for the provided model name.
Expand All @@ -42,7 +53,7 @@ export class GoogleGenerativeAI {
if (!modelParams.model) {
throw new GoogleGenerativeAIError(
`Must provide a model name. ` +
`Example: genai.getGenerativeModel({ model: 'my-model-name' })`,
`Example: genai.getGenerativeModel({ model: 'my-model-name' })`,
);
}
return new GenerativeModel(this.apiKey, modelParams, requestOptions);
Expand Down Expand Up @@ -93,7 +104,7 @@ export class GoogleGenerativeAI {
}
throw new GoogleGenerativeAIRequestInputError(
`Different value for "${key}" specified in modelParams` +
` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`,
` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`,
);
}
}
Expand All @@ -112,4 +123,49 @@ export class GoogleGenerativeAI {
requestOptions,
);
}

/**
* Lists tuned models.
* @param pageSize - Optional number of models to list. Default is 5.
* @returns A promise that resolves to a {@link ListTunedModelsResponse}.
*/
async listTunedModels(pageSize = 5): Promise<ListTunedModelsResponse> {
return listTunedModels(this.apiKey, pageSize);
}

/**
* Creates a tuned model with the specified display name and training data.
* @param displayName - The name to display for the tuned model.
* @param trainingData - The training dataset.
* @returns A promise that resolves to a {@link CreateTunedModelResponse}.
*/
async createTunedModel(
displayName: string,
trainingData: unknown
): Promise<CreateTunedModelResponse> {
return createTunedModel(this.apiKey, displayName, trainingData);
}

/**
* Checks the tuning status of a fine-tuning operation.
* @param operationName - The operation ID to check.
* @returns A promise that resolves to a {@link CheckTuningStatusResponse}.
*/
async checkTuningStatus(
operationName: string
): Promise<CheckTuningStatusResponse> {
return checkTuningStatus(this.apiKey, operationName);
}

/**
* Deletes a tuned model by name.
* @param modelName - The name of the tuned model to delete.
* @returns A promise that resolves to a {@link DeleteTunedModelResponse}.
*/
async deleteTunedModel(
modelName: string
): Promise<DeleteTunedModelResponse> {
return deleteTunedModel(this.apiKey, modelName);
}

}
124 changes: 124 additions & 0 deletions src/methods/fine-tuning.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/**
* Example interfaces for your fine-tuning API responses.
* Adjust the fields to match the real API responses.
*/
export interface ListTunedModelsResponse {
tunedModels: Array<{ name: string }>;
}

export interface CreateTunedModelResponse {
name: string;
}

export interface CheckTuningStatusResponse {
metadata: {
completedPercent: number;
// Add more fields if needed
};
}

export interface DeleteTunedModelResponse {
success: boolean;
}

/**
* A simple fetchWithRetry helper. (No changes needed here)
*/
export async function fetchWithRetry(
url: string,
options: RequestInit,
retries = 3,
delay = 1000
): Promise<Response> {
let lastError: unknown;
for (let i = 0; i < retries; i++) {
try {
const response = await fetch(url, options);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response;
} catch (error) {
lastError = error;
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
throw lastError;
}

/**
* Lists tuned models.
*/
export async function listTunedModels(
apiKey: string,
pageSize = 5
): Promise<ListTunedModelsResponse> {
const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels?page_size=${pageSize}&key=${apiKey}`;
const response = await fetchWithRetry(url, {
method: "GET",
headers: { "Content-Type": "application/json" },
});
return response.json() as Promise<ListTunedModelsResponse>;
}

/**
* Creates a tuned model.
*/
export async function createTunedModel(
apiKey: string,
displayName: string,
trainingData: unknown
): Promise<CreateTunedModelResponse> {
const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels?key=${apiKey}`;
const payload = {
display_name: displayName,
base_model: "models/gemini-1.5-flash-001-tuning",
tuning_task: {
hyperparameters: {
batch_size: 2,
learning_rate: 0.001,
epoch_count: 5,
},
training_data: {
examples: trainingData,
},
},
};

const response = await fetchWithRetry(url, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
return response.json() as Promise<CreateTunedModelResponse>;
}

/**
* Checks the tuning status of a fine-tuning operation.
*/
export async function checkTuningStatus(
apiKey: string,
operationName: string
): Promise<CheckTuningStatusResponse> {
const url = `https://generativelanguage.googleapis.com/v1beta/${operationName}?key=${apiKey}`;
const response = await fetchWithRetry(url, {
method: "GET",
headers: { "Content-Type": "application/json" },
});
return response.json() as Promise<CheckTuningStatusResponse>;
}

/**
* Deletes a tuned model.
*/
export async function deleteTunedModel(
apiKey: string,
modelName: string
): Promise<DeleteTunedModelResponse> {
const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels/${modelName}?key=${apiKey}`;
const response = await fetchWithRetry(url, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
});
return response.json() as Promise<DeleteTunedModelResponse>;
}