Skip to content

fix(firebaseai): Added token details parsing for Dev API #17609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
42 changes: 16 additions & 26 deletions packages/firebase_ai/firebase_ai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -172,25 +172,6 @@ final class UsageMetadata {
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Constructe a UsageMetadata with all it's fields.
///
/// Expose access to the private constructor for use within the package..
UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required int? thoughtsTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

/// Response candidate generated from a [GenerativeModel].
final class Candidate {
// TODO: token count?
Expand Down Expand Up @@ -1128,7 +1109,7 @@ final class VertexSerialization implements SerializationStrategy {
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
{'totalTokens': final int totalTokens} =>
UsageMetadata._(totalTokenCount: totalTokens),
_ => null,
Expand Down Expand Up @@ -1258,7 +1239,10 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
/// Parses a UsageMetadata from a JSON object.
///
/// Expose access to the private helper for use within the package.
UsageMetadata parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
Expand All @@ -1275,6 +1259,10 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
Expand All @@ -1286,11 +1274,13 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
_ => null,
};
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails,
);
}

ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) {
Expand Down
36 changes: 2 additions & 34 deletions packages/firebase_ai/firebase_ai/lib/src/developer/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import '../api.dart'
SafetyRating,
SafetySetting,
SerializationStrategy,
UsageMetadata,
createUsageMetadata;
parseUsageMetadata;
import '../content.dart'
show Content, FunctionCall, InlineDataPart, Part, TextPart;
import '../error.dart';
Expand Down Expand Up @@ -116,7 +115,7 @@ final class DeveloperSerialization implements SerializationStrategy {
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
Expand Down Expand Up @@ -230,37 +229,6 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
final promptTokenCount = switch (jsonObject) {
{'promptTokenCount': final int promptTokenCount} => promptTokenCount,
_ => null,
};
final candidatesTokenCount = switch (jsonObject) {
{'candidatesTokenCount': final int candidatesTokenCount} =>
candidatesTokenCount,
_ => null,
};
final totalTokenCount = switch (jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
return createUsageMetadata(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: null,
candidatesTokensDetails: null,
);
}

SafetyRating _parseSafetyRating(Object? jsonObject) {
return switch (jsonObject) {
{
Expand Down
34 changes: 34 additions & 0 deletions packages/firebase_ai/firebase_ai/test/api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,40 @@ void main() {
expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1));
});

group('usageMetadata parsing', () {
test('parses usageMetadata when thoughtsTokenCount is set', () {
final json = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 20,
'totalTokenCount': 30,
'thoughtsTokenCount': 5,
}
};
final response =
VertexSerialization().parseGenerateContentResponse(json);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 20);
expect(response.usageMetadata!.totalTokenCount, 30);
expect(response.usageMetadata!.thoughtsTokenCount, 5);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
final json = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 20,
'totalTokenCount': 30,
}
};
final response =
VertexSerialization().parseGenerateContentResponse(json);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.thoughtsTokenCount, isNull);
});
});

group('groundingMetadata parsing', () {
test('parses valid response with full grounding metadata', () {
final jsonResponse = {
Expand Down
41 changes: 41 additions & 0 deletions packages/firebase_ai/firebase_ai/test/developer_api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
'thoughtsTokenCount': 3,
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand All @@ -48,6 +54,15 @@ void main() {
expect(response.usageMetadata!.candidatesTokenCount, 5);
expect(response.usageMetadata!.totalTokenCount, 15);
expect(response.usageMetadata!.thoughtsTokenCount, 3);
expect(response.usageMetadata!.promptTokensDetails, isNotNull);
expect(response.usageMetadata!.promptTokensDetails, hasLength(1));
expect(
response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10);
expect(response.usageMetadata!.candidatesTokensDetails, isNotNull);
expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1));
expect(
response.usageMetadata!.candidatesTokensDetails!.first.tokenCount,
25);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
Expand All @@ -68,6 +83,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
// thoughtsTokenCount is missing
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand Down Expand Up @@ -126,6 +147,26 @@ void main() {
expect(response.usageMetadata, isNull);
});

test('parses usageMetadata when token details are missing', () {
final jsonResponse = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 25,
'totalTokenCount': 35,
}
};

final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);

expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 25);
expect(response.usageMetadata!.totalTokenCount, 35);
expect(response.usageMetadata!.promptTokensDetails, isNull);
expect(response.usageMetadata!.candidatesTokensDetails, isNull);
});

test('parses inlineData part correctly', () {
final inlineData = Uint8List.fromList([1, 2, 3, 4]);
final jsonResponse = {
Expand Down
Loading