Skip to content

Commit d210d06

Browse files
julian-computescopybara-github
authored andcommitted
Copybara import of the project:
-- f433eeb by Julian Fernandez <[email protected]>: feat: Allow labels in GenerateContentRequest * Add optional `labels` record to GenerateContentRequest -- c08232a by Julian Fernandez <[email protected]>: include labels in sent `generateContentRequest` -- 974e320 by Julian Fernandez <[email protected]>: Add unit tests * Test that `generateContent` provides labels to the Vertex endpoint * Test that `generateContentStream` provides labels to the Vertex endpoint -- 5b5d22d by Julian Fernandez <[email protected]>: cleanup log and spacing -- 2fe3d32 by Julian Fernandez <[email protected]>: fix style COPYBARA_INTEGRATE_REVIEW=#501 from julian-computes:feature/allow-labels 2fe3d32 PiperOrigin-RevId: 746132973
1 parent 5a53266 commit d210d06

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

src/functions/generate_content.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ export async function generateContent(
7777
safetySettings: request.safetySettings ?? safetySettings,
7878
tools: request.tools ?? tools,
7979
toolConfig: request.toolConfig ?? toolConfig,
80+
labels: request.labels,
8081
};
8182
const response: Response | undefined = await postRequest({
8283
region: location,
@@ -132,6 +133,7 @@ export async function generateContentStream(
132133
safetySettings: request.safetySettings ?? safetySettings,
133134
tools: request.tools ?? tools,
134135
toolConfig: request.toolConfig ?? toolConfig,
136+
labels: request.labels,
135137
};
136138
const response = await postRequest({
137139
region: location,

src/functions/test/functions_test.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ const TEST_TOOLS_WITH_RAG: Tool[] = [
226226
},
227227
];
228228

229+
const TEST_LABELS: Record<string, string> = {
230+
test_key: 'test_value',
231+
};
232+
229233
const fetchResponseObj = {
230234
status: 200,
231235
statusText: 'OK',
@@ -648,6 +652,27 @@ describe('generateContent', () => {
648652
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0];
649653
expect(vertexEndpoint).toContain('/v1beta1/');
650654
});
655+
656+
it('provides labels to the vertex endpoint', async () => {
657+
const request: GenerateContentRequest = {
658+
contents: CONTENTS,
659+
labels: TEST_LABELS,
660+
};
661+
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
662+
663+
await generateContent(
664+
TEST_LOCATION,
665+
TEST_RESOURCE_PATH,
666+
TEST_TOKEN_PROMISE,
667+
request,
668+
TEST_API_ENDPOINT
669+
);
670+
671+
const httpRequest = fetchSpy.calls.allArgs()[0][1];
672+
const body = JSON.parse(httpRequest.body);
673+
// @ts-ignore
674+
expect(body.labels).toEqual(TEST_LABELS);
675+
});
651676
});
652677

653678
describe('generateContentStream', () => {
@@ -830,4 +855,30 @@ describe('generateContentStream', () => {
830855
)
831856
).toHaveSize(0);
832857
});
858+
859+
it('provides labels to the vertex endpoint', async () => {
860+
const request: GenerateContentRequest = {
861+
contents: CONTENTS,
862+
labels: TEST_LABELS,
863+
};
864+
const expectedStreamResult: StreamGenerateContentResult = {
865+
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_INVALID_DATA),
866+
stream: testGenerator(),
867+
};
868+
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult);
869+
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult);
870+
871+
await generateContentStream(
872+
TEST_LOCATION,
873+
TEST_RESOURCE_PATH,
874+
TEST_TOKEN_PROMISE,
875+
request,
876+
TEST_API_ENDPOINT
877+
);
878+
879+
const httpRequest = fetchSpy.calls.allArgs()[0][1];
880+
const body = JSON.parse(httpRequest.body);
881+
// @ts-ignore
882+
expect(body.labels).toEqual(TEST_LABELS);
883+
});
833884
});

src/types/content.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ export declare interface GenerateContentRequest extends BaseModelParams {
6868
* This is the name of a `CachedContent` and not the cache object itself.
6969
*/
7070
cachedContent?: string;
71+
72+
/**
73+
* Optional. Custom metadata labels for organizing API calls and managing costs at scale. See
74+
* https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls
75+
*/
76+
labels?: Record<string, string>;
7177
}
7278

7379
/**

0 commit comments

Comments
 (0)