diff --git a/package-lock.json b/package-lock.json index a11f272ff..a1f826652 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@google/generative-ai", - "version": "0.21.0", + "version": "0.24.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@google/generative-ai", - "version": "0.21.0", + "version": "0.24.0", "license": "Apache-2.0", "devDependencies": { "@changesets/cli": "^2.27.1", diff --git a/src/gen-ai.ts b/src/gen-ai.ts index f65f489db..3f5ca401b 100644 --- a/src/gen-ai.ts +++ b/src/gen-ai.ts @@ -21,6 +21,7 @@ import { } from "./errors"; import { CachedContent, ModelParams, RequestOptions } from "../types"; import { GenerativeModel } from "./models/generative-model"; +import { ImageModel } from "./models/image-model"; export { ChatSession } from "./methods/chat-session"; export { GenerativeModel }; @@ -48,6 +49,20 @@ export class GoogleGenerativeAI { return new GenerativeModel(this.apiKey, modelParams, requestOptions); } + getImageModel( + modelParams: ModelParams, + requestOptions?: RequestOptions, + ): ImageModel { + if (!modelParams.model) { + throw new GoogleGenerativeAIError( + `Must provide a model name. ` + + `Example: genai.getGenerativeModel({ model: 'my-model-name' })`, + ); + } + return new ImageModel(this.apiKey, modelParams, requestOptions); + } + + /** * Creates a {@link GenerativeModel} instance from provided content cache. */ diff --git a/src/methods/generate-image.ts b/src/methods/generate-image.ts new file mode 100644 index 000000000..71b266e0b --- /dev/null +++ b/src/methods/generate-image.ts @@ -0,0 +1,20 @@ +import { GenerateImageRequest, GenerateImagesResult, SingleRequestOptions } from "../../types"; +import { Task, makeModelRequest } from "../requests/request"; + + export async function generateImages( + apiKey: string, + model: string, + params: GenerateImageRequest, + requestOptions: SingleRequestOptions, + ): Promise { + const response = await makeModelRequest( + model, + Task.GENERATE_IMAGES, + apiKey, + /* stream */ false, + JSON.stringify(params), + requestOptions, + ); + return response.json(); + } + \ No newline at end of file diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts index 7cd3fe622..90bca767c 100644 --- a/src/models/generative-model.ts +++ b/src/models/generative-model.ts @@ -32,6 +32,7 @@ import { GenerateContentResult, GenerateContentStreamResult, GenerationConfig, + ImageModelParams, ModelParams, Part, RequestOptions, @@ -85,7 +86,8 @@ export class GenerativeModel { ); this.cachedContent = modelParams.cachedContent; } - + + /** * Makes a single non-streaming call to the model * and returns an object containing a single {@link GenerateContentResponse}. diff --git a/src/models/image-model.ts b/src/models/image-model.ts new file mode 100644 index 000000000..9fad847b2 --- /dev/null +++ b/src/models/image-model.ts @@ -0,0 +1,45 @@ +import { BaseImageParams, GenerateImagesResult, ImageModelParams, SingleRequestOptions } from "../../types"; +import { RequestOptions } from "../server"; +import {generateImages} from "../methods/generate-image" +export class ImageModel{ + model:string; + modelParams:ImageModelParams; + + constructor( + public apiKey: string, + modelParams: ImageModelParams, + private _requestOptions: RequestOptions = {}, + ) + { + if (modelParams.model.includes("/")) { + // Models may be named "models/model-name" or "tunedModels/model-name" + this.model = modelParams.model; + } else { + // If path is not included, assume it's a non-tuned model. + this.model = `models/${modelParams.model}`; + } + this.modelParams=modelParams; + } + + + async generateImages(prompt:string,requestConfig?:BaseImageParams,requestOptions?:SingleRequestOptions + ):Promise{ + const Params: ImageModelParams={ + model:this.model, + ...requestConfig, + } + + const generativeModelRequestOptions: SingleRequestOptions = { + ...this._requestOptions, + ...requestOptions, + }; + + return generateImages(this.apiKey,this.model,{ + instances:[{prompt}], + paramaters: { + ...this.modelParams, + ...Params, + }, + },generativeModelRequestOptions); + } +} diff --git a/src/requests/request-helpers.ts b/src/requests/request-helpers.ts index 58232e057..0cee1a783 100644 --- a/src/requests/request-helpers.ts +++ b/src/requests/request-helpers.ts @@ -29,6 +29,7 @@ import { GoogleGenerativeAIError, GoogleGenerativeAIRequestInputError, } from "../errors"; +import { generateImages } from "../methods/generate-image"; export function formatSystemInstruction( input?: string | Part | Content, diff --git a/src/requests/request.ts b/src/requests/request.ts index 64c3703f9..2fbc2d7d2 100644 --- a/src/requests/request.ts +++ b/src/requests/request.ts @@ -40,6 +40,8 @@ export enum Task { COUNT_TOKENS = "countTokens", EMBED_CONTENT = "embedContent", BATCH_EMBED_CONTENTS = "batchEmbedContents", + GENERATE_IMAGES="predict", + } export class RequestUrl { diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..73538c431 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -34,6 +34,44 @@ export interface BaseParams { generationConfig?: GenerationConfig; } +export interface BaseImageParams{ + + guidanceScale?:number; + seed?:number; + safetyFilterLevel?:safetyFilterLevel; + personGeneration?:PersonGeneration; + includeSafetyAttributes?:boolean; + includeRaiReason?:boolean; + language?:ImagePromptLanguage; + outputMimeType?:string; + outputCompressionQuality?:number; + addWatermark?:boolean; + enhancePrompt?:boolean; +} +enum safetyFilterLevel{ + BLOCK_LOW_AND_ABOVE='BLOCK_LOW_AND_ABOVE', + BLOCK_NONE='BLOCK_NONE', + BLOCK_ONLY_HIGH='BLOCK_ONLY_HIGH', + BLOCK_MEDIUM_AND_ABOVE='BLOCK_MEDIUM_AND_ABOVE' + +} + +enum PersonGeneration{ + DONT_ALLOW='DONT_ALLOW', + ALLOW_ADULT='ALLOW_ADULT', + ALLOW_ALL='ALLOW_ALL' +} +enum ImagePromptLanguage{ + auto ='auto', + en='en', + ja='ja', + ko='ko', + hi='hi' +} + +export interface ImageModelParams extends BaseImageParams{ + model: string; +} /** * Params passed to {@link GoogleGenerativeAI.getGenerativeModel}. * @public @@ -61,6 +99,16 @@ export interface GenerateContentRequest extends BaseParams { cachedContent?: string; } + +/** + * Request sent to `generateImage` endpoint. + * @public + */ +export interface GenerateImageRequest{ + instances:Array<{prompt:string}>; + paramaters:BaseImageParams; +} + /** * Request sent to `generateContent` endpoint. * @internal @@ -153,6 +201,12 @@ export interface CountTokensRequest { contents?: Content[]; } +export interface ImageModelParams { + model: string; + aspectRatio?: string; +} + + /** * Params for calling {@link GenerativeModel.countTokens} * @internal diff --git a/types/responses.ts b/types/responses.ts index 6648fbff6..1fbaa255f 100644 --- a/types/responses.ts +++ b/types/responses.ts @@ -33,6 +33,24 @@ export interface GenerateContentResult { response: EnhancedGenerateContentResponse; } +interface Image{ + gcsUri?:string; + bytesBase64Encoded:string; + mimeType:string; +} + + +interface GeneratedImage{ + image?:Image[]; + raiFilterReason?:string[]; + enhancedPrompt?:string[]; + +} + +export interface GenerateImagesResult{ + predictions?:GeneratedImage[]; + } + /** * Result object returned from generateContentStream() call. * Iterate over `stream` to get chunks as they come in and/or