From 0c0d0d4661ecd72c968d47f7b1a7b923b3a89b54 Mon Sep 17 00:00:00 2001 From: krishna agrawal Date: Sun, 30 Mar 2025 10:03:39 +0530 Subject: [PATCH] - --- src/gen-ai.test.ts | 62 +++++++++++++++++++ src/gen-ai.ts | 62 ++++++++++++++++++- src/methods/fine-tuning.ts | 124 +++++++++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 3 deletions(-) create mode 100644 src/methods/fine-tuning.ts diff --git a/src/gen-ai.test.ts b/src/gen-ai.test.ts index fd17aa4ef..832df2340 100644 --- a/src/gen-ai.test.ts +++ b/src/gen-ai.test.ts @@ -17,6 +17,8 @@ import { ModelParams } from "../types"; import { GenerativeModel, GoogleGenerativeAI } from "./gen-ai"; import { expect } from "chai"; +import * as sinon from "sinon"; +import * as fineTuningMethods from "./methods/fine-tuning"; const fakeContents = [{ role: "user", parts: [{ text: "hello" }] }]; @@ -119,3 +121,63 @@ describe("GoogleGenerativeAI", () => { ); }); }); + +// ------------------ Added Fine-Tuning Test Cases ------------------ + +describe("GoogleGenerativeAI Fine-Tuning API Methods", () => { + let genAI: GoogleGenerativeAI; + let sandbox: sinon.SinonSandbox; + + beforeEach(() => { + sandbox = sinon.createSandbox(); + genAI = new GoogleGenerativeAI("apikey"); + }); + + afterEach(() => { + sandbox.restore(); + }); + + it("listTunedModels returns a list of tuned models", async () => { + const fakeResponse = { tunedModels: [{ name: "model-1" }, { name: "model-2" }] }; + const listStub = sandbox + .stub(fineTuningMethods, "listTunedModels") + .resolves(fakeResponse); + const response = await genAI.listTunedModels(2); + expect(response).to.deep.equal(fakeResponse); + expect(listStub.calledOnceWith("apikey", 2)).to.be.true; + }); + + it("createTunedModel returns a tuned model creation response", async () => { + const displayName = "Test Model"; + const trainingData = [{ input: "example", output: "response" }]; + const fakeResponse = { name: "tuned-model-1" }; + const createStub = sandbox + .stub(fineTuningMethods, "createTunedModel") + .resolves(fakeResponse); + const response = await genAI.createTunedModel(displayName, trainingData); + expect(response).to.deep.equal(fakeResponse); + expect(createStub.calledOnceWith("apikey", displayName, trainingData)).to.be.true; + }); + + it("checkTuningStatus returns the tuning status", async () => { + const operationName = "operation-123"; + const fakeResponse = { metadata: { completedPercent: 50 } }; + const checkStub = sandbox + .stub(fineTuningMethods, "checkTuningStatus") + .resolves(fakeResponse); + const response = await genAI.checkTuningStatus(operationName); + expect(response).to.deep.equal(fakeResponse); + expect(checkStub.calledOnceWith("apikey", operationName)).to.be.true; + }); + + it("deleteTunedModel returns the delete tuned model response", async () => { + const modelName = "tuned-model-1"; + const fakeResponse = { success: true }; + const deleteStub = sandbox + .stub(fineTuningMethods, "deleteTunedModel") + .resolves(fakeResponse); + const response = await genAI.deleteTunedModel(modelName); + expect(response).to.deep.equal(fakeResponse); + expect(deleteStub.calledOnceWith("apikey", modelName)).to.be.true; + }); +}); diff --git a/src/gen-ai.ts b/src/gen-ai.ts index f65f489db..052064210 100644 --- a/src/gen-ai.ts +++ b/src/gen-ai.ts @@ -22,6 +22,17 @@ import { import { CachedContent, ModelParams, RequestOptions } from "../types"; import { GenerativeModel } from "./models/generative-model"; +import { + CheckTuningStatusResponse, + CreateTunedModelResponse, + DeleteTunedModelResponse, + ListTunedModelsResponse, + checkTuningStatus, + createTunedModel, + deleteTunedModel, + listTunedModels, +} from "./methods/fine-tuning"; + export { ChatSession } from "./methods/chat-session"; export { GenerativeModel }; @@ -30,7 +41,7 @@ export { GenerativeModel }; * @public */ export class GoogleGenerativeAI { - constructor(public apiKey: string) {} + constructor(public apiKey: string) { } /** * Gets a {@link GenerativeModel} instance for the provided model name. @@ -42,7 +53,7 @@ export class GoogleGenerativeAI { if (!modelParams.model) { throw new GoogleGenerativeAIError( `Must provide a model name. ` + - `Example: genai.getGenerativeModel({ model: 'my-model-name' })`, + `Example: genai.getGenerativeModel({ model: 'my-model-name' })`, ); } return new GenerativeModel(this.apiKey, modelParams, requestOptions); @@ -93,7 +104,7 @@ export class GoogleGenerativeAI { } throw new GoogleGenerativeAIRequestInputError( `Different value for "${key}" specified in modelParams` + - ` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`, + ` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`, ); } } @@ -112,4 +123,49 @@ export class GoogleGenerativeAI { requestOptions, ); } + + /** + * Lists tuned models. + * @param pageSize - Optional number of models to list. Default is 5. + * @returns A promise that resolves to a {@link ListTunedModelsResponse}. + */ + async listTunedModels(pageSize = 5): Promise { + return listTunedModels(this.apiKey, pageSize); + } + + /** + * Creates a tuned model with the specified display name and training data. + * @param displayName - The name to display for the tuned model. + * @param trainingData - The training dataset. + * @returns A promise that resolves to a {@link CreateTunedModelResponse}. + */ + async createTunedModel( + displayName: string, + trainingData: unknown + ): Promise { + return createTunedModel(this.apiKey, displayName, trainingData); + } + + /** + * Checks the tuning status of a fine-tuning operation. + * @param operationName - The operation ID to check. + * @returns A promise that resolves to a {@link CheckTuningStatusResponse}. + */ + async checkTuningStatus( + operationName: string + ): Promise { + return checkTuningStatus(this.apiKey, operationName); + } + + /** + * Deletes a tuned model by name. + * @param modelName - The name of the tuned model to delete. + * @returns A promise that resolves to a {@link DeleteTunedModelResponse}. + */ + async deleteTunedModel( + modelName: string + ): Promise { + return deleteTunedModel(this.apiKey, modelName); + } + } diff --git a/src/methods/fine-tuning.ts b/src/methods/fine-tuning.ts new file mode 100644 index 000000000..8705d2e83 --- /dev/null +++ b/src/methods/fine-tuning.ts @@ -0,0 +1,124 @@ +/** + * Example interfaces for your fine-tuning API responses. + * Adjust the fields to match the real API responses. + */ +export interface ListTunedModelsResponse { + tunedModels: Array<{ name: string }>; +} + +export interface CreateTunedModelResponse { + name: string; +} + +export interface CheckTuningStatusResponse { + metadata: { + completedPercent: number; + // Add more fields if needed + }; +} + +export interface DeleteTunedModelResponse { + success: boolean; +} + +/** + * A simple fetchWithRetry helper. (No changes needed here) + */ +export async function fetchWithRetry( + url: string, + options: RequestInit, + retries = 3, + delay = 1000 +): Promise { + let lastError: unknown; + for (let i = 0; i < retries; i++) { + try { + const response = await fetch(url, options); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + return response; + } catch (error) { + lastError = error; + await new Promise((resolve) => setTimeout(resolve, delay)); + } + } + throw lastError; +} + +/** + * Lists tuned models. + */ +export async function listTunedModels( + apiKey: string, + pageSize = 5 +): Promise { + const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels?page_size=${pageSize}&key=${apiKey}`; + const response = await fetchWithRetry(url, { + method: "GET", + headers: { "Content-Type": "application/json" }, + }); + return response.json() as Promise; +} + +/** + * Creates a tuned model. + */ +export async function createTunedModel( + apiKey: string, + displayName: string, + trainingData: unknown +): Promise { + const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels?key=${apiKey}`; + const payload = { + display_name: displayName, + base_model: "models/gemini-1.5-flash-001-tuning", + tuning_task: { + hyperparameters: { + batch_size: 2, + learning_rate: 0.001, + epoch_count: 5, + }, + training_data: { + examples: trainingData, + }, + }, + }; + + const response = await fetchWithRetry(url, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + return response.json() as Promise; +} + +/** + * Checks the tuning status of a fine-tuning operation. + */ +export async function checkTuningStatus( + apiKey: string, + operationName: string +): Promise { + const url = `https://generativelanguage.googleapis.com/v1beta/${operationName}?key=${apiKey}`; + const response = await fetchWithRetry(url, { + method: "GET", + headers: { "Content-Type": "application/json" }, + }); + return response.json() as Promise; +} + +/** + * Deletes a tuned model. + */ +export async function deleteTunedModel( + apiKey: string, + modelName: string +): Promise { + const url = `https://generativelanguage.googleapis.com/v1beta/tunedModels/${modelName}?key=${apiKey}`; + const response = await fetchWithRetry(url, { + method: "DELETE", + headers: { "Content-Type": "application/json" }, + }); + return response.json() as Promise; +} \ No newline at end of file