Skip to content

Commit 0075baa

Browse files
feat: Allow labels in GenerateContentRequest (#466) (#501)
* feat: Allow labels in GenerateContentRequest * Add optional `labels` record to GenerateContentRequest * include labels in sent `generateContentRequest` * Add unit tests * Test that `generateContent` provides labels to the Vertex endpoint * Test that `generateContentStream` provides labels to the Vertex endpoint * cleanup log and spacing * fix style --------- Co-authored-by: Yvonne Yu <[email protected]>
1 parent b62483a commit 0075baa

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

src/functions/generate_content.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ export async function generateContent(
7777
safetySettings: request.safetySettings ?? safetySettings,
7878
tools: request.tools ?? tools,
7979
toolConfig: request.toolConfig ?? toolConfig,
80+
labels: request.labels,
8081
};
8182
const response: Response | undefined = await postRequest({
8283
region: location,
@@ -132,6 +133,7 @@ export async function generateContentStream(
132133
safetySettings: request.safetySettings ?? safetySettings,
133134
tools: request.tools ?? tools,
134135
toolConfig: request.toolConfig ?? toolConfig,
136+
labels: request.labels,
135137
};
136138
const response = await postRequest({
137139
region: location,

src/functions/test/functions_test.ts

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ const TEST_TOOLS_WITH_RAG: Tool[] = [
226226
},
227227
];
228228

229+
const TEST_LABELS: Record<string, string> = {
230+
test_key: 'test_value',
231+
};
232+
229233
const fetchResponseObj = {
230234
status: 200,
231235
statusText: 'OK',
@@ -648,6 +652,26 @@ describe('generateContent', () => {
648652
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0];
649653
expect(vertexEndpoint).toContain('/v1beta1/');
650654
});
655+
656+
it('provides labels to the vertex endpoint', async () => {
657+
const request: GenerateContentRequest = {
658+
contents: CONTENTS,
659+
labels: TEST_LABELS,
660+
};
661+
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
662+
663+
await generateContent(
664+
TEST_LOCATION,
665+
TEST_RESOURCE_PATH,
666+
TEST_TOKEN_PROMISE,
667+
request,
668+
TEST_API_ENDPOINT
669+
);
670+
671+
const httpRequest = fetchSpy.calls.allArgs()[0][1];
672+
const body = JSON.parse(httpRequest.body);
673+
expect(body.labels).toEqual(TEST_LABELS);
674+
});
651675
});
652676

653677
describe('generateContentStream', () => {
@@ -830,4 +854,29 @@ describe('generateContentStream', () => {
830854
)
831855
).toHaveSize(0);
832856
});
857+
858+
it('provides labels to the vertex endpoint', async () => {
859+
const request: GenerateContentRequest = {
860+
contents: CONTENTS,
861+
labels: TEST_LABELS,
862+
};
863+
const expectedStreamResult: StreamGenerateContentResult = {
864+
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_INVALID_DATA),
865+
stream: testGenerator(),
866+
};
867+
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult);
868+
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult);
869+
870+
await generateContentStream(
871+
TEST_LOCATION,
872+
TEST_RESOURCE_PATH,
873+
TEST_TOKEN_PROMISE,
874+
request,
875+
TEST_API_ENDPOINT
876+
);
877+
878+
const httpRequest = fetchSpy.calls.allArgs()[0][1];
879+
const body = JSON.parse(httpRequest.body);
880+
expect(body.labels).toEqual(TEST_LABELS);
881+
});
833882
});

src/types/content.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ export declare interface GenerateContentRequest extends BaseModelParams {
6868
* This is the name of a `CachedContent` and not the cache object itself.
6969
*/
7070
cachedContent?: string;
71+
72+
/**
73+
* Optional. Custom metadata labels for organizing API calls and managing costs at scale. See
74+
* https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls
75+
*/
76+
labels?: Record<string, string>;
7177
}
7278

7379
/**

0 commit comments

Comments
 (0)