Skip to content

Commit 180f091

Browse files
committed
Introduce VertexAIModel base class, add documentation, and respond to other comments
1 parent ec35231 commit 180f091

18 files changed

+552
-463
lines changed

common/api-review/vertexai.api.md

Lines changed: 35 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -323,16 +323,14 @@ export interface GenerativeContentBlob {
323323
}
324324

325325
// @public
326-
export class GenerativeModel {
326+
export class GenerativeModel extends VertexAIModel {
327327
constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions);
328328
countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
329329
generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
330330
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
331331
// (undocumented)
332332
generationConfig: GenerationConfig;
333333
// (undocumented)
334-
model: string;
335-
// (undocumented)
336334
requestOptions?: RequestOptions;
337335
// (undocumented)
338336
safetySettings: SafetySetting[];
@@ -432,77 +430,53 @@ export enum HarmSeverity {
432430
HARM_SEVERITY_NEGLIGIBLE = "HARM_SEVERITY_NEGLIGIBLE"
433431
}
434432

435-
// @public (undocumented)
433+
// @public
436434
export enum ImagenAspectRatio {
437435
// (undocumented)
438-
CLASSIC_LANDSCAPE = "4:3",
436+
LANDSCAPE_16x9 = "16:9",
439437
// (undocumented)
440-
CLASSIC_PORTRAIT = "3:4",
438+
LANDSCAPE_3x4 = "3:4",
441439
// (undocumented)
442-
PORTRAIT = "9:16",
440+
PORTRAIT_4x3 = "4:3",
443441
// (undocumented)
444-
SQUARE = "1:1",
442+
PORTRAIT_9x16 = "9:16",
445443
// (undocumented)
446-
WIDESCREEN = "16:9"
444+
SQUARE = "1:1"
447445
}
448446

449-
// Warning: (ae-incompatible-release-tags) The symbol "ImagenGCSImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
450-
//
451447
// @public
452-
export interface ImagenGCSImage extends ImagenImage {
448+
export interface ImagenGCSImage {
453449
gcsURI: string;
450+
mimeType: string;
454451
}
455452

456-
// @public (undocumented)
453+
// @public
457454
export interface ImagenGCSImageResponse {
458-
// (undocumented)
459455
filteredReason?: string;
460-
// (undocumented)
461456
images: ImagenGCSImage[];
462457
}
463458

464-
// @public (undocumented)
459+
// @public
465460
export interface ImagenGenerationConfig {
466-
// (undocumented)
467461
aspectRatio?: ImagenAspectRatio;
468-
// (undocumented)
469462
negativePrompt?: string;
470-
// (undocumented)
471463
numberOfImages?: number;
472464
}
473465

474-
// Warning: (ae-internal-missing-underscore) The name "ImagenImage" should be prefixed with an underscore because the declaration is marked as @internal
475-
//
476-
// @internal
477-
export interface ImagenImage {
478-
// (undocumented)
479-
mimeType: string;
480-
}
481-
482-
// @public (undocumented)
483-
export interface ImagenImageFormat {
466+
// @public
467+
export class ImagenImageFormat {
484468
// (undocumented)
485469
compressionQuality?: number;
470+
static jpeg(compressionQuality: number): ImagenImageFormat;
486471
// (undocumented)
487472
mimeType: string;
473+
static png(): ImagenImageFormat;
488474
}
489475

490-
// @public (undocumented)
491-
export interface ImagenImageReponse {
492-
// (undocumented)
493-
filteredReason?: string;
494-
// Warning: (ae-incompatible-release-tags) The symbol "images" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
495-
//
496-
// (undocumented)
497-
images: ImagenImage[];
498-
}
499-
500-
// Warning: (ae-incompatible-release-tags) The symbol "ImagenInlineImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
501-
//
502476
// @public
503-
export interface ImagenInlineImage extends ImagenImage {
504-
// (undocumented)
477+
export interface ImagenInlineImage {
505478
bytesBase64Encoded: string;
479+
mimeType: string;
506480
}
507481

508482
// @public
@@ -512,63 +486,43 @@ export interface ImagenInlineImageResponse {
512486
}
513487

514488
// @public
515-
export class ImagenModel {
489+
export class ImagenModel extends VertexAIModel {
516490
constructor(vertexAI: VertexAI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
517491
generateImages(prompt: string, imagenRequestOptions?: ImagenGenerationConfig): Promise<ImagenInlineImageResponse>;
518492
generateImagesGCS(prompt: string, gcsURI: string, imagenRequestOptions?: ImagenGenerationConfig): Promise<ImagenGCSImageResponse>;
493+
readonly modelConfig: ImagenModelConfig;
519494
// (undocumented)
520-
model: string;
521-
}
495+
readonly requestOptions?: RequestOptions | undefined;
496+
}
522497

523-
// @public (undocumented)
498+
// @public
524499
export interface ImagenModelConfig {
525-
// (undocumented)
526500
addWatermark?: boolean;
527-
// (undocumented)
528501
imageFormat?: ImagenImageFormat;
529-
// (undocumented)
530502
safetySettings?: ImagenSafetySettings;
531503
}
532504

533505
// @public
534506
export interface ImagenModelParams extends ImagenModelConfig {
535-
// (undocumented)
536507
model: string;
537508
}
538509

539-
// @public (undocumented)
510+
// @public
540511
export enum ImagenPersonFilterLevel {
541-
// (undocumented)
542512
ALLOW_ADULT = "allow_adult",
543-
// (undocumented)
544513
ALLOW_ALL = "allow_all",
545-
// (undocumented)
546514
BLOCK_ALL = "dont_allow"
547515
}
548516

549-
// Warning: (ae-internal-missing-underscore) The name "ImagenRequestConfig" should be prefixed with an underscore because the declaration is marked as @internal
550-
//
551-
// @internal
552-
export interface ImagenRequestConfig extends ImagenModelConfig, ImagenGenerationConfig {
553-
// (undocumented)
554-
gcsURI?: string;
555-
// (undocumented)
556-
prompt: string;
557-
}
558-
559-
// @public (undocumented)
517+
// @public
560518
export enum ImagenSafetyFilterLevel {
561-
// (undocumented)
562519
BLOCK_LOW_AND_ABOVE = "block_low_and_above",
563-
// (undocumented)
564520
BLOCK_MEDIUM_AND_ABOVE = "block_medium_and_above",
565-
// (undocumented)
566521
BLOCK_NONE = "block_none",
567-
// (undocumented)
568522
BLOCK_ONLY_HIGH = "block_only_high"
569523
}
570524

571-
// @public (undocumented)
525+
// @public
572526
export interface ImagenSafetySettings {
573527
personFilterLevel?: ImagenPersonFilterLevel;
574528
safetyFilterLevel?: ImagenSafetyFilterLevel;
@@ -592,9 +546,6 @@ export class IntegerSchema extends Schema {
592546
constructor(schemaParams?: SchemaParams);
593547
}
594548

595-
// @public
596-
export function jpeg(compressionQuality: number): ImagenImageFormat;
597-
598549
// @public
599550
export interface ModelParams extends BaseParams {
600551
// (undocumented)
@@ -638,37 +589,9 @@ export interface ObjectSchemaInterface extends SchemaInterface {
638589
// @public
639590
export type Part = TextPart | InlineDataPart | FunctionCallPart | FunctionResponsePart | FileDataPart;
640591

641-
// @public
642-
export function png(): ImagenImageFormat;
643-
644592
// @public
645593
export const POSSIBLE_ROLES: readonly ["user", "model", "function", "system"];
646594

647-
// Warning: (ae-internal-missing-underscore) The name "PredictRequestBody" should be prefixed with an underscore because the declaration is marked as @internal
648-
//
649-
// @internal
650-
export interface PredictRequestBody {
651-
// (undocumented)
652-
instances: [
653-
{
654-
prompt: string;
655-
}
656-
];
657-
// (undocumented)
658-
parameters: {
659-
sampleCount: number;
660-
aspectRatio: string;
661-
mimeType: string;
662-
compressionQuality?: number;
663-
negativePrompt?: string;
664-
storageUri?: string;
665-
addWatermark?: boolean;
666-
safetyFilterLevel?: string;
667-
personGeneration?: string;
668-
includeRaiReason: boolean;
669-
};
670-
}
671-
672595
// @public
673596
export interface PromptFeedback {
674597
// (undocumented)
@@ -696,14 +619,6 @@ export interface RetrievedContextAttribution {
696619
// @public
697620
export type Role = (typeof POSSIBLE_ROLES)[number];
698621

699-
// @public (undocumented)
700-
export interface SafetyAttributes {
701-
// (undocumented)
702-
categories: string[];
703-
// (undocumented)
704-
scores: number[];
705-
}
706-
707622
// @public
708623
export interface SafetyRating {
709624
// (undocumented)
@@ -902,6 +817,16 @@ export const enum VertexAIErrorCode {
902817
RESPONSE_ERROR = "response-error"
903818
}
904819

820+
// @public
821+
export class VertexAIModel {
822+
// @internal
823+
protected constructor(vertexAI: VertexAI, modelName: string);
824+
// (undocumented)
825+
protected _apiSettings: ApiSettings;
826+
readonly model: string;
827+
static normalizeModelName(modelName: string): string;
828+
}
829+
905830
// @public
906831
export interface VertexAIOptions {
907832
// (undocumented)

packages/vertexai/src/api.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ describe('Top level API', () => {
129129
);
130130
}
131131
});
132-
it('getGenerativeModel gets an ImagenModel', () => {
132+
it('getImagenModel gets an ImagenModel', () => {
133133
const genModel = getImagenModel(fakeVertexAI, { model: 'my-model' });
134134
expect(genModel).to.be.an.instanceOf(ImagenModel);
135135
expect(genModel.model).to.equal('publishers/google/models/my-model');

packages/vertexai/src/api.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,12 @@ import {
2828
VertexAIErrorCode
2929
} from './types';
3030
import { VertexAIError } from './errors';
31-
import { GenerativeModel } from './models/generative-model';
32-
import { ImagenModel, jpeg, png } from './models/imagen-model';
31+
import { VertexAIModel, GenerativeModel, ImagenModel } from './models';
3332

3433
export { ChatSession } from './methods/chat-session';
3534
export * from './requests/schema-builder';
36-
37-
export { jpeg, png };
38-
39-
export { GenerativeModel, ImagenModel };
40-
35+
export { ImagenImageFormat } from './requests/imagen-image-format';
36+
export { VertexAIModel, GenerativeModel, ImagenModel };
4137
export { VertexAIError };
4238

4339
declare module '@firebase/component' {

packages/vertexai/src/models/generative-model.test.ts

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,6 @@ const fakeVertexAI: VertexAI = {
3737
};
3838

3939
describe('GenerativeModel', () => {
40-
it('handles plain model name', () => {
41-
const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
42-
expect(genModel.model).to.equal('publishers/google/models/my-model');
43-
});
44-
it('handles models/ prefixed model name', () => {
45-
const genModel = new GenerativeModel(fakeVertexAI, {
46-
model: 'models/my-model'
47-
});
48-
expect(genModel.model).to.equal('publishers/google/models/my-model');
49-
});
50-
it('handles full model name', () => {
51-
const genModel = new GenerativeModel(fakeVertexAI, {
52-
model: 'publishers/google/models/my-model'
53-
});
54-
expect(genModel.model).to.equal('publishers/google/models/my-model');
55-
});
56-
it('handles prefixed tuned model name', () => {
57-
const genModel = new GenerativeModel(fakeVertexAI, {
58-
model: 'tunedModels/my-model'
59-
});
60-
expect(genModel.model).to.equal('tunedModels/my-model');
61-
});
6240
it('passes params through to generateContent', async () => {
6341
const genModel = new GenerativeModel(fakeVertexAI, {
6442
model: 'my-model',

packages/vertexai/src/models/generative-model.ts

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,22 @@ import {
3333
SafetySetting,
3434
StartChatParams,
3535
Tool,
36-
ToolConfig,
37-
VertexAIErrorCode
36+
ToolConfig
3837
} from '../types';
39-
import { VertexAIError } from '../errors';
4038
import { ChatSession } from '../methods/chat-session';
4139
import { countTokens } from '../methods/count-tokens';
4240
import {
4341
formatGenerateContentInput,
4442
formatSystemInstruction
4543
} from '../requests/request-helpers';
4644
import { VertexAI } from '../public-types';
47-
import { ApiSettings } from '../types/internal';
48-
import { VertexAIService } from '../service';
45+
import { VertexAIModel } from './vertexai-model';
4946

5047
/**
5148
* Class for generative model APIs.
5249
* @public
5350
*/
54-
export class GenerativeModel {
55-
private _apiSettings: ApiSettings;
56-
model: string;
51+
export class GenerativeModel extends VertexAIModel {
5752
generationConfig: GenerationConfig;
5853
safetySettings: SafetySetting[];
5954
requestOptions?: RequestOptions;
@@ -66,44 +61,7 @@ export class GenerativeModel {
6661
modelParams: ModelParams,
6762
requestOptions?: RequestOptions
6863
) {
69-
if (!vertexAI.app?.options?.apiKey) {
70-
throw new VertexAIError(
71-
VertexAIErrorCode.NO_API_KEY,
72-
`The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`
73-
);
74-
} else if (!vertexAI.app?.options?.projectId) {
75-
throw new VertexAIError(
76-
VertexAIErrorCode.NO_PROJECT_ID,
77-
`The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`
78-
);
79-
} else {
80-
this._apiSettings = {
81-
apiKey: vertexAI.app.options.apiKey,
82-
project: vertexAI.app.options.projectId,
83-
location: vertexAI.location
84-
};
85-
if ((vertexAI as VertexAIService).appCheck) {
86-
this._apiSettings.getAppCheckToken = () =>
87-
(vertexAI as VertexAIService).appCheck!.getToken();
88-
}
89-
90-
if ((vertexAI as VertexAIService).auth) {
91-
this._apiSettings.getAuthToken = () =>
92-
(vertexAI as VertexAIService).auth!.getToken();
93-
}
94-
}
95-
if (modelParams.model.includes('/')) {
96-
if (modelParams.model.startsWith('models/')) {
97-
// Add "publishers/google" if the user is only passing in 'models/model-name'.
98-
this.model = `publishers/google/${modelParams.model}`;
99-
} else {
100-
// Any other custom format (e.g. tuned models) must be passed in correctly.
101-
this.model = modelParams.model;
102-
}
103-
} else {
104-
// If path is not included, assume it's a non-tuned model.
105-
this.model = `publishers/google/models/${modelParams.model}`;
106-
}
64+
super(vertexAI, modelParams.model);
10765
this.generationConfig = modelParams.generationConfig || {};
10866
this.safetySettings = modelParams.safetySettings || [];
10967
this.tools = modelParams.tools;

0 commit comments

Comments
 (0)