|
1 | 1 | import { |
2 | | - GoogleGenerativeAI, |
| 2 | + ApiError, |
| 3 | + GoogleGenAI, |
3 | 4 | GenerationConfig, |
4 | 5 | HarmCategory, |
5 | 6 | HarmBlockThreshold, |
6 | | -} from '@google/generative-ai'; |
| 7 | +} from '@google/genai'; |
7 | 8 | import { |
8 | 9 | ModelGenerationConfig, |
9 | 10 | StructuredOutputType, |
@@ -117,15 +118,18 @@ export async function callGemini( |
117 | 118 | modelName = GEMINI_DEFAULT_MODEL, |
118 | 119 | parseResponse = false, // parse if structured output |
119 | 120 | ) { |
120 | | - const genAI = new GoogleGenerativeAI(apiKey); |
121 | | - const model = genAI.getGenerativeModel({ |
| 121 | + const genAI = new GoogleGenAI({apiKey}); |
| 122 | + |
| 123 | + const response = await genAI.models.generateContent({ |
122 | 124 | model: modelName, |
123 | | - generationConfig, |
124 | | - safetySettings: SAFETY_SETTINGS, |
| 125 | + contents: [{role: 'user', parts: [{text: prompt}]}], |
| 126 | + config: { |
| 127 | + ...generationConfig, |
| 128 | + safetySettings: SAFETY_SETTINGS, |
| 129 | + }, |
125 | 130 | }); |
126 | 131 |
|
127 | | - const result = await model.generateContent(prompt); |
128 | | - const response = await result.response; |
| 132 | + console.log(`DEBUG: ${JSON.stringify(response)}`); |
129 | 133 |
|
130 | 134 | if (response.promptFeedback) { |
131 | 135 | return { |
@@ -236,20 +240,20 @@ export async function getGeminiAPIResponse( |
236 | 240 | ); |
237 | 241 | // eslint-disable-next-line @typescript-eslint/no-explicit-any |
238 | 242 | } catch (error: any) { |
239 | | - // The GenerativeAI client doesn't return responses in a parseable format, |
240 | | - // so try to parse the output string looking for the HTTP status code. |
| 243 | + if (!(error instanceof ApiError)) { |
| 244 | + return { |
| 245 | + status: ModelResponseStatus.UNKNOWN_ERROR, |
| 246 | + generationConfig: geminiConfig, |
| 247 | + errorMessage: JSON.stringify(error), |
| 248 | + }; |
| 249 | + } |
241 | 250 | let returnStatus = ModelResponseStatus.UNKNOWN_ERROR; |
242 | | - // Match a status code and message between brackets, e.g. "[403 Forbidden]". |
243 | | - const statusMatch = error.message.match(/\[(\d{3})[\s\w]*\]/); |
244 | | - if (statusMatch) { |
245 | | - const statusCode = parseInt(statusMatch[1]); |
246 | | - if (statusCode == AUTHENTICATION_FAILURE_ERROR_CODE) { |
247 | | - returnStatus = ModelResponseStatus.AUTHENTICATION_ERROR; |
248 | | - } else if (statusCode == QUOTA_ERROR_CODE) { |
249 | | - returnStatus = ModelResponseStatus.QUOTA_ERROR; |
250 | | - } else if (statusCode >= 500 && statusCode < 600) { |
251 | | - returnStatus = ModelResponseStatus.PROVIDER_UNAVAILABLE_ERROR; |
252 | | - } |
| 251 | + if (error.status == AUTHENTICATION_FAILURE_ERROR_CODE) { |
| 252 | + returnStatus = ModelResponseStatus.AUTHENTICATION_ERROR; |
| 253 | + } else if (error.status == QUOTA_ERROR_CODE) { |
| 254 | + returnStatus = ModelResponseStatus.QUOTA_ERROR; |
| 255 | + } else if (error.status >= 500 && error.status < 600) { |
| 256 | + returnStatus = ModelResponseStatus.PROVIDER_UNAVAILABLE_ERROR; |
253 | 257 | } |
254 | 258 | return { |
255 | 259 | status: returnStatus, |
|
0 commit comments