Skip to content

Commit 3f39f5b

Browse files
committed
Initial prototype
1 parent 5e7df2f commit 3f39f5b

File tree

12 files changed

+103
-33
lines changed

12 files changed

+103
-33
lines changed

common/api-review/vertexai.api.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ export interface GenerativeContentBlob {
324324

325325
// @public
326326
export class GenerativeModel extends VertexAIModel {
327-
constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions);
327+
constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions, developerAPIEnabled?: boolean);
328328
countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
329329
generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
330330
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
@@ -344,7 +344,7 @@ export class GenerativeModel extends VertexAIModel {
344344
}
345345

346346
// @public
347-
export function getGenerativeModel(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel;
347+
export function getGenerativeModel(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions, enableDeveloperAPI?: boolean): GenerativeModel;
348348

349349
// @beta
350350
export function getImagenModel(vertexAI: VertexAI, modelParams: ImagenModelParams, requestOptions?: RequestOptions): ImagenModel;
@@ -776,6 +776,8 @@ export interface UsageMetadata {
776776
export interface VertexAI {
777777
app: FirebaseApp;
778778
// (undocumented)
779+
developerAPIEnabled: boolean;
780+
// (undocumented)
779781
location: string;
780782
}
781783

@@ -806,15 +808,17 @@ export const enum VertexAIErrorCode {
806808
// @public
807809
export abstract class VertexAIModel {
808810
// @internal
809-
protected constructor(vertexAI: VertexAI, modelName: string);
811+
protected constructor(vertexAI: VertexAI, modelName: string, developerAPIEnabled?: boolean);
810812
// @internal (undocumented)
811813
protected _apiSettings: ApiSettings;
812814
readonly model: string;
813-
static normalizeModelName(modelName: string): string;
815+
static normalizeModelName(modelName: string, developerAPIEnabled?: boolean): string;
814816
}
815817

816818
// @public
817819
export interface VertexAIOptions {
820+
// (undocumented)
821+
developerAPIEnabled: boolean;
818822
// (undocumented)
819823
location?: string;
820824
}

packages/vertexai/src/api.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import {
2929
} from './types';
3030
import { VertexAIError } from './errors';
3131
import { VertexAIModel, GenerativeModel, ImagenModel } from './models';
32+
import { createInstanceIdentifier } from './helpers';
3233

3334
export { ChatSession } from './methods/chat-session';
3435
export * from './requests/schema-builder';
@@ -57,8 +58,9 @@ export function getVertexAI(
5758
// Dependencies
5859
const vertexProvider: Provider<'vertexAI'> = _getProvider(app, VERTEX_TYPE);
5960

61+
const identifier = createInstanceIdentifier(options?.developerAPIEnabled, options?.location);
6062
return vertexProvider.getImmediate({
61-
identifier: options?.location || DEFAULT_LOCATION
63+
identifier
6264
});
6365
}
6466

@@ -71,15 +73,16 @@ export function getVertexAI(
7173
export function getGenerativeModel(
7274
vertexAI: VertexAI,
7375
modelParams: ModelParams,
74-
requestOptions?: RequestOptions
76+
requestOptions?: RequestOptions,
77+
enableDeveloperAPI?: boolean
7578
): GenerativeModel {
7679
if (!modelParams.model) {
7780
throw new VertexAIError(
7881
VertexAIErrorCode.NO_MODEL,
7982
`Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`
8083
);
8184
}
82-
return new GenerativeModel(vertexAI, modelParams, requestOptions);
85+
return new GenerativeModel(vertexAI, modelParams, requestOptions, enableDeveloperAPI);
8386
}
8487

8588
/**

packages/vertexai/src/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ export const DEFAULT_LOCATION = 'us-central1';
2323

2424
export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com';
2525

26+
export const DEVELOPER_API_BASE_URL = "https://generativelanguage.googleapis.com";
27+
2628
export const DEFAULT_API_VERSION = 'v1beta';
2729

2830
export const PACKAGE_VERSION = version;

packages/vertexai/src/helpers.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { DEFAULT_LOCATION } from "./constants";
2+
3+
/**
4+
* @internal
5+
*/
6+
export function createInstanceIdentifier(developerAPIEnabled?: boolean, location?: string): string {
7+
if (developerAPIEnabled) {
8+
return 'developerAPI';
9+
} else {
10+
return `vertexAI/${location || DEFAULT_LOCATION}`;
11+
}
12+
}
13+
14+
/**
15+
* @internal
16+
*/
17+
export function parseInstanceIdentifier(instanceIdentifier: string): { developerAPIEnabled: boolean, location?: string } {
18+
const identifierParts = instanceIdentifier.split("/");
19+
if (identifierParts[0] === 'developerAPI') {
20+
return {
21+
developerAPIEnabled: true,
22+
location: undefined
23+
}
24+
} else {
25+
const location = identifierParts[1];
26+
return {
27+
developerAPIEnabled: false,
28+
location
29+
}
30+
}
31+
}

packages/vertexai/src/index.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import { registerVersion, _registerComponent } from '@firebase/app';
2525
import { VertexAIService } from './service';
2626
import { VERTEX_TYPE } from './constants';
27-
import { Component, ComponentType } from '@firebase/component';
27+
import { Component, ComponentType, InstanceFactoryOptions } from '@firebase/component';
2828
import { name, version } from '../package.json';
29+
import { parseInstanceIdentifier } from './helpers';
30+
import { VertexAIOptions } from './public-types';
2931

3032
declare global {
3133
interface Window {
@@ -37,12 +39,23 @@ function registerVertex(): void {
3739
_registerComponent(
3840
new Component(
3941
VERTEX_TYPE,
40-
(container, { instanceIdentifier: location }) => {
42+
(container, options) => {
4143
// getImmediate for FirebaseApp will always succeed
4244
const app = container.getProvider('app').getImmediate();
4345
const auth = container.getProvider('auth-internal');
4446
const appCheckProvider = container.getProvider('app-check-internal');
45-
return new VertexAIService(app, auth, appCheckProvider, { location });
47+
48+
let vertexAIOptions: VertexAIOptions;
49+
if (options.instanceIdentifier) {
50+
vertexAIOptions = parseInstanceIdentifier(options.instanceIdentifier);
51+
} else {
52+
vertexAIOptions = {
53+
developerAPIEnabled: false,
54+
location: undefined
55+
}
56+
}
57+
58+
return new VertexAIService(app, auth, appCheckProvider, vertexAIOptions);
4659
},
4760
ComponentType.PUBLIC
4861
).setMultipleInstances(true)

packages/vertexai/src/methods/generate-content.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export async function generateContent(
5656
apiSettings,
5757
/* stream */ false,
5858
JSON.stringify(params),
59-
requestOptions
59+
requestOptions,
6060
);
6161
const responseJson: GenerateContentResponse = await response.json();
6262
const enhancedResponse = createEnhancedContentResponse(responseJson);

packages/vertexai/src/models/generative-model.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export class GenerativeModel extends VertexAIModel {
5959
constructor(
6060
vertexAI: VertexAI,
6161
modelParams: ModelParams,
62-
requestOptions?: RequestOptions
62+
requestOptions?: RequestOptions,
6363
) {
6464
super(vertexAI, modelParams.model);
6565
this.generationConfig = modelParams.generationConfig || {};

packages/vertexai/src/models/vertexai-model.ts

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ export abstract class VertexAIModel {
5656
* @internal
5757
*/
5858
protected constructor(vertexAI: VertexAI, modelName: string) {
59-
this.model = VertexAIModel.normalizeModelName(modelName);
60-
6159
if (!vertexAI.app?.options?.apiKey) {
6260
throw new VertexAIError(
6361
VertexAIErrorCode.NO_API_KEY,
@@ -72,7 +70,8 @@ export abstract class VertexAIModel {
7270
this._apiSettings = {
7371
apiKey: vertexAI.app.options.apiKey,
7472
project: vertexAI.app.options.projectId,
75-
location: vertexAI.location
73+
developerAPIEnabled: vertexAI.developerAPIEnabled,
74+
location: vertexAI.location,
7675
};
7776

7877
if (
@@ -92,6 +91,8 @@ export abstract class VertexAIModel {
9291
this._apiSettings.getAuthToken = () =>
9392
(vertexAI as VertexAIService).auth!.getToken();
9493
}
94+
95+
this.model = this.normalizeModelName(modelName);
9596
}
9697
}
9798

@@ -101,19 +102,23 @@ export abstract class VertexAIModel {
101102
* @param modelName - The model name to normalize.
102103
* @returns The fully qualified model resource name.
103104
*/
104-
static normalizeModelName(modelName: string): string {
105+
normalizeModelName(modelName: string): string {
105106
let model: string;
106-
if (modelName.includes('/')) {
107-
if (modelName.startsWith('models/')) {
108-
// Add 'publishers/google' if the user is only passing in 'models/model-name'.
109-
model = `publishers/google/${modelName}`;
107+
if (this._apiSettings.developerAPIEnabled) {
108+
model = `models/${modelName}`;
109+
} else {
110+
if (modelName.includes('/')) {
111+
if (modelName.startsWith('models/')) {
112+
// Add 'publishers/google' if the user is only passing in 'models/model-name'.
113+
model = `publishers/google/${modelName}`;
114+
} else {
115+
// Any other custom format (e.g. tuned models) must be passed in correctly.
116+
model = modelName;
117+
}
110118
} else {
111-
// Any other custom format (e.g. tuned models) must be passed in correctly.
112-
model = modelName;
119+
// If path is not included, assume it's a non-tuned model.
120+
model = `publishers/google/models/${modelName}`;
113121
}
114-
} else {
115-
// If path is not included, assume it's a non-tuned model.
116-
model = `publishers/google/models/${modelName}`;
117122
}
118123

119124
return model;

packages/vertexai/src/public-types.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ export interface VertexAI {
2828
* The {@link @firebase/app#FirebaseApp} this <code>{@link VertexAI}</code> instance is associated with.
2929
*/
3030
app: FirebaseApp;
31-
location: string;
31+
developerAPIEnabled: boolean;
32+
location: string; // This is only applicable if we're using the VertexAI API.
3233
}
3334

3435
/**
3536
* Options when initializing the Vertex AI in Firebase SDK.
3637
* @public
3738
*/
3839
export interface VertexAIOptions {
40+
developerAPIEnabled: boolean;
3941
location?: string;
4042
}

packages/vertexai/src/requests/request.ts

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
DEFAULT_API_VERSION,
2323
DEFAULT_BASE_URL,
2424
DEFAULT_FETCH_TIMEOUT_MS,
25+
DEVELOPER_API_BASE_URL,
2526
LANGUAGE_TAG,
2627
PACKAGE_VERSION
2728
} from '../constants';
@@ -40,17 +41,23 @@ export class RequestUrl {
4041
public task: Task,
4142
public apiSettings: ApiSettings,
4243
public stream: boolean,
43-
public requestOptions?: RequestOptions
44+
public requestOptions?: RequestOptions,
4445
) {}
4546
toString(): string {
4647
// TODO: allow user-set option if that feature becomes available
4748
const apiVersion = DEFAULT_API_VERSION;
48-
const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
49-
let url = `${baseUrl}/${apiVersion}`;
50-
url += `/projects/${this.apiSettings.project}`;
51-
url += `/locations/${this.apiSettings.location}`;
52-
url += `/${this.model}`;
53-
url += `:${this.task}`;
49+
let url;
50+
if (this.apiSettings.developerAPIEnabled) {
51+
const baseUrl = this.requestOptions?.baseUrl || DEVELOPER_API_BASE_URL;
52+
url = `${baseUrl}/${apiVersion}/${this.model}:${this.task}`;
53+
} else {
54+
const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
55+
url = `${baseUrl}/${apiVersion}`;
56+
url += `/projects/${this.apiSettings.project}`;
57+
url += `/locations/${this.apiSettings.location}`;
58+
url += `/${this.model}`;
59+
url += `:${this.task}`;
60+
}
5461
if (this.stream) {
5562
url += '?alt=sse';
5663
}

0 commit comments

Comments
 (0)