Skip to content

Commit 95841ec

Browse files
authored
Merge pull request #612 from mkbehr/genai-migrate
Migrate to new Gemini client SDK.
2 parents 8ffd4e5 + d162f09 commit 95841ec

File tree

4 files changed

+97
-43
lines changed

4 files changed

+97
-43
lines changed

functions/package-lock.json

Lines changed: 44 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

functions/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"main": "lib/index.js",
1818
"dependencies": {
1919
"@deliberation-lab/utils": "file:../utils",
20-
"@google/generative-ai": "^0.17.2",
20+
"@google/genai": "^1.11.0",
2121
"@sinclair/typebox": "^0.32.30",
2222
"firebase-admin": "^12.1.0",
2323
"firebase-functions": "^5.0.0",

functions/src/api/gemini.api.test.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import nock = require('nock');
44
import {
55
ModelGenerationConfig,
66
ModelResponse,
7+
ModelResponseStatus,
78
StructuredOutputType,
89
StructuredOutputDataType,
910
} from '@deliberation-lab/utils';
@@ -152,5 +153,30 @@ describe('Gemini API', () => {
152153
expect(parsedResponse).toMatchObject(expectedResponse);
153154
});
154155

155-
// TODO(mkbehr): Add tests for error responses.
156+
it('handles a 503 error from the server', async () => {
157+
nock.cleanAll();
158+
nock('https://generativelanguage.googleapis.com')
159+
.post(`/v1beta/models/${MODEL_NAME}:generateContent`)
160+
.reply(503, 'Service Unavailable');
161+
162+
const generationConfig: ModelGenerationConfig = {
163+
maxTokens: 300,
164+
stopSequences: [],
165+
temperature: 0.4,
166+
topP: 0.9,
167+
frequencyPenalty: 0,
168+
presencePenalty: 0,
169+
customRequestBodyFields: [{ name: 'seed', value: 123 }],
170+
};
171+
172+
const response: ModelResponse = await getGeminiAPIResponse(
173+
'testapikey',
174+
MODEL_NAME,
175+
'This is a test prompt.',
176+
generationConfig,
177+
);
178+
179+
expect(response.status).toBe(ModelResponseStatus.PROVIDER_UNAVAILABLE_ERROR);
180+
expect(response.errorMessage).toContain('Service Unavailable');
181+
});
156182
});

functions/src/api/gemini.api.ts

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import {
2-
GoogleGenerativeAI,
2+
ApiError,
3+
GoogleGenAI,
34
GenerationConfig,
45
HarmCategory,
56
HarmBlockThreshold,
6-
} from '@google/generative-ai';
7+
} from '@google/genai';
78
import {
89
ModelGenerationConfig,
910
StructuredOutputType,
@@ -117,15 +118,18 @@ export async function callGemini(
117118
modelName = GEMINI_DEFAULT_MODEL,
118119
parseResponse = false, // parse if structured output
119120
) {
120-
const genAI = new GoogleGenerativeAI(apiKey);
121-
const model = genAI.getGenerativeModel({
121+
const genAI = new GoogleGenAI({apiKey});
122+
123+
const response = await genAI.models.generateContent({
122124
model: modelName,
123-
generationConfig,
124-
safetySettings: SAFETY_SETTINGS,
125+
contents: [{role: 'user', parts: [{text: prompt}]}],
126+
config: {
127+
...generationConfig,
128+
safetySettings: SAFETY_SETTINGS,
129+
},
125130
});
126131

127-
const result = await model.generateContent(prompt);
128-
const response = await result.response;
132+
console.log(`DEBUG: ${JSON.stringify(response)}`);
129133

130134
if (response.promptFeedback) {
131135
return {
@@ -236,20 +240,20 @@ export async function getGeminiAPIResponse(
236240
);
237241
// eslint-disable-next-line @typescript-eslint/no-explicit-any
238242
} catch (error: any) {
239-
// The GenerativeAI client doesn't return responses in a parseable format,
240-
// so try to parse the output string looking for the HTTP status code.
243+
if (!(error instanceof ApiError)) {
244+
return {
245+
status: ModelResponseStatus.UNKNOWN_ERROR,
246+
generationConfig: geminiConfig,
247+
errorMessage: JSON.stringify(error),
248+
};
249+
}
241250
let returnStatus = ModelResponseStatus.UNKNOWN_ERROR;
242-
// Match a status code and message between brackets, e.g. "[403 Forbidden]".
243-
const statusMatch = error.message.match(/\[(\d{3})[\s\w]*\]/);
244-
if (statusMatch) {
245-
const statusCode = parseInt(statusMatch[1]);
246-
if (statusCode == AUTHENTICATION_FAILURE_ERROR_CODE) {
247-
returnStatus = ModelResponseStatus.AUTHENTICATION_ERROR;
248-
} else if (statusCode == QUOTA_ERROR_CODE) {
249-
returnStatus = ModelResponseStatus.QUOTA_ERROR;
250-
} else if (statusCode >= 500 && statusCode < 600) {
251-
returnStatus = ModelResponseStatus.PROVIDER_UNAVAILABLE_ERROR;
252-
}
251+
if (error.status == AUTHENTICATION_FAILURE_ERROR_CODE) {
252+
returnStatus = ModelResponseStatus.AUTHENTICATION_ERROR;
253+
} else if (error.status == QUOTA_ERROR_CODE) {
254+
returnStatus = ModelResponseStatus.QUOTA_ERROR;
255+
} else if (error.status >= 500 && error.status < 600) {
256+
returnStatus = ModelResponseStatus.PROVIDER_UNAVAILABLE_ERROR;
253257
}
254258
return {
255259
status: returnStatus,

0 commit comments

Comments
 (0)