diff --git a/packages/firebase_ai/firebase_ai/lib/src/api.dart b/packages/firebase_ai/firebase_ai/lib/src/api.dart index 7a482c087f1d..1d50fe2f2330 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/api.dart @@ -172,25 +172,6 @@ final class UsageMetadata { final List? candidatesTokensDetails; } -/// Constructe a UsageMetadata with all it's fields. -/// -/// Expose access to the private constructor for use within the package.. -UsageMetadata createUsageMetadata({ - required int? promptTokenCount, - required int? candidatesTokenCount, - required int? totalTokenCount, - required int? thoughtsTokenCount, - required List? promptTokensDetails, - required List? candidatesTokensDetails, -}) => - UsageMetadata._( - promptTokenCount: promptTokenCount, - candidatesTokenCount: candidatesTokenCount, - totalTokenCount: totalTokenCount, - thoughtsTokenCount: thoughtsTokenCount, - promptTokensDetails: promptTokensDetails, - candidatesTokensDetails: candidatesTokensDetails); - /// Response candidate generated from a [GenerativeModel]. final class Candidate { // TODO: token count? @@ -1128,7 +1109,7 @@ final class VertexSerialization implements SerializationStrategy { }; final usageMedata = switch (jsonObject) { {'usageMetadata': final usageMetadata?} => - _parseUsageMetadata(usageMetadata), + parseUsageMetadata(usageMetadata), {'totalTokens': final int totalTokens} => UsageMetadata._(totalTokenCount: totalTokens), _ => null, @@ -1258,7 +1239,10 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) { }; } -UsageMetadata _parseUsageMetadata(Object jsonObject) { +/// Parses a UsageMetadata from a JSON object. +/// +/// Expose access to the private helper for use within the package. +UsageMetadata parseUsageMetadata(Object jsonObject) { if (jsonObject is! Map) { throw unhandledFormat('UsageMetadata', jsonObject); } @@ -1275,6 +1259,10 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) { {'totalTokenCount': final int totalTokenCount} => totalTokenCount, _ => null, }; + final thoughtsTokenCount = switch (jsonObject) { + {'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount, + _ => null, + }; final promptTokensDetails = switch (jsonObject) { {'promptTokensDetails': final List promptTokensDetails} => promptTokensDetails.map(_parseModalityTokenCount).toList(), @@ -1286,11 +1274,13 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) { _ => null, }; return UsageMetadata._( - promptTokenCount: promptTokenCount, - candidatesTokenCount: candidatesTokenCount, - totalTokenCount: totalTokenCount, - promptTokensDetails: promptTokensDetails, - candidatesTokensDetails: candidatesTokensDetails); + promptTokenCount: promptTokenCount, + candidatesTokenCount: candidatesTokenCount, + totalTokenCount: totalTokenCount, + thoughtsTokenCount: thoughtsTokenCount, + promptTokensDetails: promptTokensDetails, + candidatesTokensDetails: candidatesTokensDetails, + ); } ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) { diff --git a/packages/firebase_ai/firebase_ai/lib/src/developer/api.dart b/packages/firebase_ai/firebase_ai/lib/src/developer/api.dart index 6be1fd771c33..1cb731d310e4 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/developer/api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/developer/api.dart @@ -31,8 +31,7 @@ import '../api.dart' SafetyRating, SafetySetting, SerializationStrategy, - UsageMetadata, - createUsageMetadata; + parseUsageMetadata; import '../content.dart' show Content, FunctionCall, InlineDataPart, Part, TextPart; import '../error.dart'; @@ -116,7 +115,7 @@ final class DeveloperSerialization implements SerializationStrategy { }; final usageMedata = switch (jsonObject) { {'usageMetadata': final usageMetadata?} => - _parseUsageMetadata(usageMetadata), + parseUsageMetadata(usageMetadata), _ => null, }; return GenerateContentResponse(candidates, promptFeedback, @@ -230,37 +229,6 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) { }; } -UsageMetadata _parseUsageMetadata(Object jsonObject) { - if (jsonObject is! Map) { - throw unhandledFormat('UsageMetadata', jsonObject); - } - final promptTokenCount = switch (jsonObject) { - {'promptTokenCount': final int promptTokenCount} => promptTokenCount, - _ => null, - }; - final candidatesTokenCount = switch (jsonObject) { - {'candidatesTokenCount': final int candidatesTokenCount} => - candidatesTokenCount, - _ => null, - }; - final totalTokenCount = switch (jsonObject) { - {'totalTokenCount': final int totalTokenCount} => totalTokenCount, - _ => null, - }; - final thoughtsTokenCount = switch (jsonObject) { - {'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount, - _ => null, - }; - return createUsageMetadata( - promptTokenCount: promptTokenCount, - candidatesTokenCount: candidatesTokenCount, - totalTokenCount: totalTokenCount, - thoughtsTokenCount: thoughtsTokenCount, - promptTokensDetails: null, - candidatesTokensDetails: null, - ); -} - SafetyRating _parseSafetyRating(Object? jsonObject) { return switch (jsonObject) { { diff --git a/packages/firebase_ai/firebase_ai/test/api_test.dart b/packages/firebase_ai/firebase_ai/test/api_test.dart index a21b17a0ee56..ed5596868d62 100644 --- a/packages/firebase_ai/firebase_ai/test/api_test.dart +++ b/packages/firebase_ai/firebase_ai/test/api_test.dart @@ -615,6 +615,40 @@ void main() { expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1)); }); + group('usageMetadata parsing', () { + test('parses usageMetadata when thoughtsTokenCount is set', () { + final json = { + 'usageMetadata': { + 'promptTokenCount': 10, + 'candidatesTokenCount': 20, + 'totalTokenCount': 30, + 'thoughtsTokenCount': 5, + } + }; + final response = + VertexSerialization().parseGenerateContentResponse(json); + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.promptTokenCount, 10); + expect(response.usageMetadata!.candidatesTokenCount, 20); + expect(response.usageMetadata!.totalTokenCount, 30); + expect(response.usageMetadata!.thoughtsTokenCount, 5); + }); + + test('parses usageMetadata when thoughtsTokenCount is missing', () { + final json = { + 'usageMetadata': { + 'promptTokenCount': 10, + 'candidatesTokenCount': 20, + 'totalTokenCount': 30, + } + }; + final response = + VertexSerialization().parseGenerateContentResponse(json); + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.thoughtsTokenCount, isNull); + }); + }); + group('groundingMetadata parsing', () { test('parses valid response with full grounding metadata', () { final jsonResponse = { diff --git a/packages/firebase_ai/firebase_ai/test/developer_api_test.dart b/packages/firebase_ai/firebase_ai/test/developer_api_test.dart index 5a26eaf01f30..e838f42e629c 100644 --- a/packages/firebase_ai/firebase_ai/test/developer_api_test.dart +++ b/packages/firebase_ai/firebase_ai/test/developer_api_test.dart @@ -39,6 +39,12 @@ void main() { 'candidatesTokenCount': 5, 'totalTokenCount': 15, 'thoughtsTokenCount': 3, + 'promptTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 10} + ], + 'candidatesTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 25} + ], } }; final response = @@ -48,6 +54,15 @@ void main() { expect(response.usageMetadata!.candidatesTokenCount, 5); expect(response.usageMetadata!.totalTokenCount, 15); expect(response.usageMetadata!.thoughtsTokenCount, 3); + expect(response.usageMetadata!.promptTokensDetails, isNotNull); + expect(response.usageMetadata!.promptTokensDetails, hasLength(1)); + expect( + response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10); + expect(response.usageMetadata!.candidatesTokensDetails, isNotNull); + expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1)); + expect( + response.usageMetadata!.candidatesTokensDetails!.first.tokenCount, + 25); }); test('parses usageMetadata when thoughtsTokenCount is missing', () { @@ -68,6 +83,12 @@ void main() { 'candidatesTokenCount': 5, 'totalTokenCount': 15, // thoughtsTokenCount is missing + 'promptTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 10} + ], + 'candidatesTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 25} + ], } }; final response = @@ -126,6 +147,26 @@ void main() { expect(response.usageMetadata, isNull); }); + test('parses usageMetadata when token details are missing', () { + final jsonResponse = { + 'usageMetadata': { + 'promptTokenCount': 10, + 'candidatesTokenCount': 25, + 'totalTokenCount': 35, + } + }; + + final response = + DeveloperSerialization().parseGenerateContentResponse(jsonResponse); + + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.promptTokenCount, 10); + expect(response.usageMetadata!.candidatesTokenCount, 25); + expect(response.usageMetadata!.totalTokenCount, 35); + expect(response.usageMetadata!.promptTokensDetails, isNull); + expect(response.usageMetadata!.candidatesTokensDetails, isNull); + }); + test('parses inlineData part correctly', () { final inlineData = Uint8List.fromList([1, 2, 3, 4]); final jsonResponse = {