Commit 9a6b795

russellwheatley authored and mikehardy committed
stream-reader.ts
1 parent 6b4d577 commit 9a6b795

File tree

4 files changed (+330, -17 lines)


packages/ai/lib/googleai-mappers.ts

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
/**
 * @license
 * Copyright 2025 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { AIError } from './errors';
import { logger } from './logger';
import {
  CitationMetadata,
  CountTokensRequest,
  GenerateContentCandidate,
  GenerateContentRequest,
  GenerateContentResponse,
  HarmSeverity,
  InlineDataPart,
  PromptFeedback,
  SafetyRating,
  AIErrorCode,
} from './types';
import {
  GoogleAIGenerateContentResponse,
  GoogleAIGenerateContentCandidate,
  GoogleAICountTokensRequest,
} from './types/googleai';

/**
 * This SDK supports both the Vertex AI Gemini API and the Gemini Developer API (using Google AI).
 * The public API prioritizes the format used by the Vertex AI Gemini API.
 * We avoid having two sets of types by translating requests and responses between the two API formats.
 * This translation allows developers to switch between the Vertex AI Gemini API and the Gemini Developer API
 * with minimal code changes.
 *
 * This file contains the functions that map requests and responses between the two API formats.
 * Requests in the Vertex AI format are mapped to the Google AI format before being sent.
 * Responses from the Google AI backend are mapped back to the Vertex AI format before being returned to the user.
 */

/**
 * Maps a Vertex AI {@link GenerateContentRequest} to a format that can be sent to Google AI.
 *
 * @param generateContentRequest The {@link GenerateContentRequest} to map.
 * @returns A {@link GenerateContentRequest} that conforms to the Google AI format.
 *
 * @throws If the request contains properties that are unsupported by Google AI.
 *
 * @internal
 */
export function mapGenerateContentRequest(
  generateContentRequest: GenerateContentRequest,
): GenerateContentRequest {
  generateContentRequest.safetySettings?.forEach(safetySetting => {
    if (safetySetting.method) {
      throw new AIError(
        AIErrorCode.UNSUPPORTED,
        'SafetySetting.method is not supported in the Gemini Developer API. Please remove this property.',
      );
    }
  });

  if (generateContentRequest.generationConfig?.topK) {
    const roundedTopK = Math.round(generateContentRequest.generationConfig.topK);

    if (roundedTopK !== generateContentRequest.generationConfig.topK) {
      logger.warn(
        'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.',
      );
      generateContentRequest.generationConfig.topK = roundedTopK;
    }
  }

  return generateContentRequest;
}

/**
 * Maps a {@link GenerateContentResponse} from Google AI to the Vertex AI
 * {@link GenerateContentResponse} format that is exposed in the public API.
 *
 * @param googleAIResponse The {@link GenerateContentResponse} from Google AI.
 * @returns A {@link GenerateContentResponse} that conforms to the public API's format.
 *
 * @internal
 */
export function mapGenerateContentResponse(
  googleAIResponse: GoogleAIGenerateContentResponse,
): GenerateContentResponse {
  const generateContentResponse = {
    candidates: googleAIResponse.candidates
      ? mapGenerateContentCandidates(googleAIResponse.candidates)
      : undefined,
    prompt: googleAIResponse.promptFeedback
      ? mapPromptFeedback(googleAIResponse.promptFeedback)
      : undefined,
    usageMetadata: googleAIResponse.usageMetadata,
  };

  return generateContentResponse;
}

/**
 * Maps a Vertex AI {@link CountTokensRequest} to a format that can be sent to Google AI.
 *
 * @param countTokensRequest The {@link CountTokensRequest} to map.
 * @param model The model to count tokens with.
 * @returns A {@link CountTokensRequest} that conforms to the Google AI format.
 *
 * @internal
 */
export function mapCountTokensRequest(
  countTokensRequest: CountTokensRequest,
  model: string,
): GoogleAICountTokensRequest {
  const mappedCountTokensRequest: GoogleAICountTokensRequest = {
    generateContentRequest: {
      model,
      ...countTokensRequest,
    },
  };

  return mappedCountTokensRequest;
}

/**
 * Maps a Google AI {@link GoogleAIGenerateContentCandidate} to a format that conforms
 * to the Vertex AI API format.
 *
 * @param candidates The {@link GoogleAIGenerateContentCandidate} to map.
 * @returns A {@link GenerateContentCandidate} that conforms to the Vertex AI format.
 *
 * @throws If any {@link Part} in the candidates has a `videoMetadata` property.
 *
 * @internal
 */
export function mapGenerateContentCandidates(
  candidates: GoogleAIGenerateContentCandidate[],
): GenerateContentCandidate[] {
  const mappedCandidates: GenerateContentCandidate[] = [];
  let mappedSafetyRatings: SafetyRating[];
  if (mappedCandidates) {
    candidates.forEach(candidate => {
      // Map citationSources to citations.
      let citationMetadata: CitationMetadata | undefined;
      if (candidate.citationMetadata) {
        citationMetadata = {
          citations: candidate.citationMetadata.citationSources,
        };
      }

      // Assign missing candidate SafetyRatings properties to their defaults if undefined.
      if (candidate.safetyRatings) {
        mappedSafetyRatings = candidate.safetyRatings.map(safetyRating => {
          return {
            ...safetyRating,
            severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
            probabilityScore: safetyRating.probabilityScore ?? 0,
            severityScore: safetyRating.severityScore ?? 0,
          };
        });
      }

      // videoMetadata is not supported.
      // Throw early since developers may send a long video as input and only expect to pay
      // for inference on a small portion of the video.
      if (candidate.content?.parts.some(part => (part as InlineDataPart)?.videoMetadata)) {
        throw new AIError(
          AIErrorCode.UNSUPPORTED,
          'Part.videoMetadata is not supported in the Gemini Developer API. Please remove this property.',
        );
      }

      const mappedCandidate = {
        index: candidate.index,
        content: candidate.content,
        finishReason: candidate.finishReason,
        finishMessage: candidate.finishMessage,
        safetyRatings: mappedSafetyRatings,
        citationMetadata,
        groundingMetadata: candidate.groundingMetadata,
      };
      mappedCandidates.push(mappedCandidate);
    });
  }

  return mappedCandidates;
}

export function mapPromptFeedback(promptFeedback: PromptFeedback): PromptFeedback {
  // Assign missing SafetyRating properties to their defaults if undefined.
  const mappedSafetyRatings: SafetyRating[] = [];
  promptFeedback.safetyRatings.forEach(safetyRating => {
    mappedSafetyRatings.push({
      category: safetyRating.category,
      probability: safetyRating.probability,
      severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
      probabilityScore: safetyRating.probabilityScore ?? 0,
      severityScore: safetyRating.severityScore ?? 0,
      blocked: safetyRating.blocked,
    });
  });

  const mappedPromptFeedback: PromptFeedback = {
    blockReason: promptFeedback.blockReason,
    safetyRatings: mappedSafetyRatings,
    blockReasonMessage: promptFeedback.blockReasonMessage,
  };
  return mappedPromptFeedback;
}
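
For orientation, here is a rough sketch of how the request mappers above behave at a call site. The request literals and the model string are illustrative only; the functions, error code, and rounding/wrapping behavior come from the file shown above.

import { mapGenerateContentRequest, mapCountTokensRequest } from './googleai-mappers';

// topK is rounded to the nearest integer before the request goes to the
// Gemini Developer API; a warning is logged only when rounding changes the value.
const mappedRequest = mapGenerateContentRequest({
  contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  generationConfig: { topK: 15.7 }, // sent as 16
});

// countTokens requests are wrapped in a generateContentRequest envelope with
// the model name folded in, matching GoogleAICountTokensRequest.
const countRequest = mapCountTokensRequest(
  { contents: [{ role: 'user', parts: [{ text: 'Count me' }] }] },
  'models/gemini-1.5-flash', // illustrative model name
);

// A SafetySetting carrying a `method` property would cause an AIError with
// code AIErrorCode.UNSUPPORTED to be thrown before anything is sent.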

packages/ai/lib/requests/stream-reader.ts

Lines changed: 35 additions & 17 deletions
@@ -22,10 +22,14 @@ import {
   GenerateContentResponse,
   GenerateContentStreamResult,
   Part,
-  VertexAIErrorCode,
+  AIErrorCode,
 } from '../types';
-import { VertexAIError } from '../errors';
+import { AIError } from '../errors';
 import { createEnhancedContentResponse } from './response-helpers';
+import { ApiSettings } from '../types/internal';
+import { BackendType } from '../public-types';
+import * as GoogleAIMapper from '../googleai-mappers';
+import { GoogleAIGenerateContentResponse } from '../types/googleai';
 
 const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
 
@@ -37,7 +41,10 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
  *
  * @param response - Response from a fetch call
  */
-export function processStream(response: Response): GenerateContentStreamResult {
+export function processStream(
+  response: Response,
+  apiSettings: ApiSettings,
+): GenerateContentStreamResult {
   const inputStream = new ReadableStream<string>({
     async start(controller) {
       const reader = response.body!.getReader();
@@ -56,28 +63,36 @@ export function processStream(response: Response): GenerateContentStreamResult {
   const responseStream = getResponseStream<GenerateContentResponse>(inputStream);
   const [stream1, stream2] = responseStream.tee();
   return {
-    stream: generateResponseSequence(stream1),
-    response: getResponsePromise(stream2),
+    stream: generateResponseSequence(stream1, apiSettings),
+    response: getResponsePromise(stream2, apiSettings),
   };
 }
 
 async function getResponsePromise(
   stream: ReadableStream<GenerateContentResponse>,
+  apiSettings: ApiSettings,
 ): Promise<EnhancedGenerateContentResponse> {
   const allResponses: GenerateContentResponse[] = [];
   const reader = stream.getReader();
   while (true) {
     const { done, value } = await reader.read();
     if (done) {
-      const enhancedResponse = createEnhancedContentResponse(aggregateResponses(allResponses));
-      return enhancedResponse;
+      let generateContentResponse = aggregateResponses(allResponses);
+      if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+        generateContentResponse = GoogleAIMapper.mapGenerateContentResponse(
+          generateContentResponse as GoogleAIGenerateContentResponse,
+        );
+      }
+      return createEnhancedContentResponse(generateContentResponse);
     }
+
     allResponses.push(value);
   }
 }
 
 async function* generateResponseSequence(
   stream: ReadableStream<GenerateContentResponse>,
+  apiSettings: ApiSettings,
 ): AsyncGenerator<EnhancedGenerateContentResponse> {
   const reader = stream.getReader();
   while (true) {
@@ -86,7 +101,15 @@ async function* generateResponseSequence(
       break;
     }
 
-    const enhancedResponse = createEnhancedContentResponse(value);
+    let enhancedResponse: EnhancedGenerateContentResponse;
+    if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+      enhancedResponse = createEnhancedContentResponse(
+        GoogleAIMapper.mapGenerateContentResponse(value as GoogleAIGenerateContentResponse),
+      );
+    } else {
+      enhancedResponse = createEnhancedContentResponse(value);
+    }
+
     yield enhancedResponse;
   }
 }
@@ -106,9 +129,7 @@ export function getResponseStream<T>(inputStream: ReadableStream<string>): Reada
     return reader.read().then(({ value, done }) => {
       if (done) {
         if (currentText.trim()) {
-          controller.error(
-            new VertexAIError(VertexAIErrorCode.PARSE_FAILED, 'Failed to parse stream'),
-          );
+          controller.error(new AIError(AIErrorCode.PARSE_FAILED, 'Failed to parse stream'));
           return;
         }
         controller.close();
@@ -123,10 +144,7 @@ export function getResponseStream<T>(inputStream: ReadableStream<string>): Reada
           parsedResponse = JSON.parse(match[1]!);
         } catch (_) {
           controller.error(
-            new VertexAIError(
-              VertexAIErrorCode.PARSE_FAILED,
-              `Error parsing JSON response: "${match[1]}`,
-            ),
+            new AIError(AIErrorCode.PARSE_FAILED, `Error parsing JSON response: "${match[1]}`),
           );
           return;
         }
@@ -197,8 +215,8 @@ export function aggregateResponses(responses: GenerateContentResponse[]): Genera
           newPart.functionCall = part.functionCall;
         }
         if (Object.keys(newPart).length === 0) {
-          throw new VertexAIError(
-            VertexAIErrorCode.INVALID_CONTENT,
+          throw new AIError(
+            AIErrorCode.INVALID_CONTENT,
             'Part should have at least one property, but there are none. This is likely caused ' +
               'by a malformed response from the backend.',
           );
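
For context, a minimal sketch of how the reworked processStream might be driven downstream. The fetchResponse and apiSettings values are assumed to come from the surrounding request helpers, which are not part of this diff; only the processStream signature and the backend check are taken from it.

// Hypothetical call site. When apiSettings.backend.backendType is
// BackendType.GOOGLE_AI, each streamed chunk and the aggregated response are
// passed through GoogleAIMapper.mapGenerateContentResponse before surfacing.
const result = processStream(fetchResponse, apiSettings);

for await (const chunk of result.stream) {
  // Chunks are already in the Vertex AI shape here, so safetyRatings carry the
  // defaulted severity / probabilityScore / severityScore fields.
  console.log(chunk.text());
}

const aggregated = await result.response; // aggregated, then mapped the same way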

packages/ai/lib/types/enums.ts

Lines changed: 7 additions & 0 deletions
@@ -91,6 +91,13 @@ export enum HarmSeverity {
   HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM',
   // High level of harm severity.
   HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH',
+  /**
+   * Harm severity is not supported.
+   *
+   * @remarks
+   * The GoogleAI backend does not support `HarmSeverity`, so this value is used as a fallback.
+   */
+  HARM_SEVERITY_UNSUPPORTED = 'HARM_SEVERITY_UNSUPPORTED',
 }
 
 /**
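
A small sketch of where the new enum member surfaces: when the Gemini Developer API omits severity information, the mappers in googleai-mappers.ts fall back to it rather than leaving the field undefined. The rating literal and the HarmCategory/HarmProbability members below are illustrative.

// A Google AI safety rating arrives without severity or score fields...
const incoming = {
  category: HarmCategory.HARM_CATEGORY_HARASSMENT,
  probability: HarmProbability.NEGLIGIBLE,
} as SafetyRating;

// ...and after mapping, callers see explicit fallback values.
const mapped: SafetyRating = {
  ...incoming,
  severity: incoming.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
  probabilityScore: incoming.probabilityScore ?? 0,
  severityScore: incoming.severityScore ?? 0,
};
// mapped.severity === HarmSeverity.HARM_SEVERITY_UNSUPPORTED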
