diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
index 4101364057a9..14052ab539ef 100644
--- a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
+++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
@@ -19,7 +19,8 @@ import 'schema.dart';
 /// Response for Count Tokens
 final class CountTokensResponse {
   /// Constructor
-  CountTokensResponse(this.totalTokens, {this.totalBillableCharacters});
+  CountTokensResponse(this.totalTokens,
+      {this.totalBillableCharacters, this.promptTokensDetails});
 
   /// The number of tokens that the `model` tokenizes the `prompt` into.
   ///
@@ -30,6 +31,9 @@ final class CountTokensResponse {
   ///
   /// Always non-negative.
   final int? totalBillableCharacters;
+
+  /// List of modalities that were processed in the request input.
+  final List<ModalityTokenCount>? promptTokensDetails;
 }
 
 /// Response from the model; supports multiple candidates.
@@ -128,11 +132,12 @@ final class PromptFeedback {
 /// Metadata on the generation request's token usage.
 final class UsageMetadata {
   /// Constructor
-  UsageMetadata._({
-    this.promptTokenCount,
-    this.candidatesTokenCount,
-    this.totalTokenCount,
-  });
+  UsageMetadata._(
+      {this.promptTokenCount,
+      this.candidatesTokenCount,
+      this.totalTokenCount,
+      this.promptTokensDetails,
+      this.candidatesTokensDetails});
 
   /// Number of tokens in the prompt.
   final int? promptTokenCount;
@@ -142,6 +147,12 @@ final class UsageMetadata {
 
   /// Total token count for the generation request (prompt + candidates).
   final int? totalTokenCount;
+
+  /// List of modalities that were processed in the request input.
+  final List<ModalityTokenCount>? promptTokensDetails;
+
+  /// List of modalities that were returned in the response.
+  final List<ModalityTokenCount>? candidatesTokensDetails;
 }
 
 /// Response candidate generated from a [GenerativeModel].
@@ -481,6 +492,66 @@ enum FinishReason {
   String toString() => name;
 }
 
+/// Represents token counting info for a single modality.
+final class ModalityTokenCount {
+  /// Constructor
+  ModalityTokenCount(this.modality, this.tokenCount);
+
+  /// The modality associated with this token count.
+  final ContentModality modality;
+
+  /// The number of tokens counted.
+  final int tokenCount;
+}
+
+/// Content part modality.
+enum ContentModality {
+  /// Unspecified modality.
+  unspecified('MODALITY_UNSPECIFIED'),
+
+  /// Plain text.
+  text('TEXT'),
+
+  /// Image.
+  image('IMAGE'),
+
+  /// Video.
+  video('VIDEO'),
+
+  /// Audio.
+  audio('AUDIO'),
+
+  /// Document, e.g. PDF.
+  document('DOCUMENT');
+
+  const ContentModality(this._jsonString);
+
+  /// Returns the [ContentModality] whose JSON name equals [jsonObject].
+  ///
+  /// Throws a [FormatException] for any other value.
+  static ContentModality _parseValue(Object jsonObject) {
+    return switch (jsonObject) {
+      'MODALITY_UNSPECIFIED' => ContentModality.unspecified,
+      'TEXT' => ContentModality.text,
+      'IMAGE' => ContentModality.image,
+      'VIDEO' => ContentModality.video,
+      'AUDIO' => ContentModality.audio,
+      'DOCUMENT' => ContentModality.document,
+      _ =>
+        throw FormatException('Unhandled ContentModality format', jsonObject),
+    };
+  }
+
+  /// The name used for this modality in JSON; returned by [toJson].
+  final String _jsonString;
+
+  @override
+  String toString() => name;
+
+  /// Convert to json format.
+  Object toJson() => _jsonString;
+}
+
 /// Safety setting, affecting the safety-blocking behavior.
 ///
 /// Passing a safety setting for a category changes the allowed probability that
@@ -696,16 +767,33 @@ GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
 /// Parse the json to [CountTokensResponse]
 CountTokensResponse parseCountTokensResponse(Object jsonObject) {
   if (jsonObject case {'error': final Object error}) throw parseError(error);
-  if (jsonObject case {'totalTokens': final int totalTokens}) {
-    if (jsonObject
-        case {'totalBillableCharacters': final int totalBillableCharacters}) {
-      return CountTokensResponse(totalTokens,
-          totalBillableCharacters: totalBillableCharacters);
-    } else {
-      return CountTokensResponse(totalTokens);
-    }
+
+  if (jsonObject is! Map) {
+    throw unhandledFormat('CountTokensResponse', jsonObject);
   }
-  throw unhandledFormat('CountTokensResponse', jsonObject);
+
+  // A missing or non-integer `totalTokens` is reported through
+  // `unhandledFormat`, like the previous implementation, not as a cast error.
+  final totalTokens = switch (jsonObject) {
+    {'totalTokens': final int totalTokens} => totalTokens,
+    _ => throw unhandledFormat('CountTokensResponse', jsonObject),
+  };
+  final totalBillableCharacters = switch (jsonObject) {
+    {'totalBillableCharacters': final int totalBillableCharacters} =>
+      totalBillableCharacters,
+    _ => null,
+  };
+  final promptTokensDetails = switch (jsonObject) {
+    {'promptTokensDetails': final List promptTokensDetails} =>
+      promptTokensDetails.map(_parseModalityTokenCount).toList(),
+    _ => null,
+  };
+
+  return CountTokensResponse(
+    totalTokens,
+    totalBillableCharacters: totalBillableCharacters,
+    promptTokensDetails: promptTokensDetails,
+  );
 }
 
 Candidate _parseCandidate(Object? jsonObject) {
@@ -777,10 +865,31 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
     {'totalTokenCount': final int totalTokenCount} => totalTokenCount,
     _ => null,
   };
+  final promptTokensDetails = switch (jsonObject) {
+    {'promptTokensDetails': final List promptTokensDetails} =>
+      promptTokensDetails.map(_parseModalityTokenCount).toList(),
+    _ => null,
+  };
+  final candidatesTokensDetails = switch (jsonObject) {
+    {'candidatesTokensDetails': final List candidatesTokensDetails} =>
+      candidatesTokensDetails.map(_parseModalityTokenCount).toList(),
+    _ => null,
+  };
   return UsageMetadata._(
       promptTokenCount: promptTokenCount,
       candidatesTokenCount: candidatesTokenCount,
-      totalTokenCount: totalTokenCount);
+      totalTokenCount: totalTokenCount,
+      promptTokensDetails: promptTokensDetails,
+      candidatesTokensDetails: candidatesTokensDetails);
+}
+
+/// Parses a JSON map with `modality` and `tokenCount` keys.
+ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) {
+  if (jsonObject is! Map) {
+    throw unhandledFormat('ModalityTokenCount', jsonObject);
+  }
+  return ModalityTokenCount(ContentModality._parseValue(jsonObject['modality']),
+      jsonObject['tokenCount'] as int);
 }
 
 SafetyRating _parseSafetyRating(Object? jsonObject) {
diff --git a/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart
index 2d3802a96eac..d5c8d64f16d5 100644
--- a/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart
+++ b/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart
@@ -654,6 +654,81 @@ void main() {
       );
     });
 
+    test('response including usage metadata', () async {
+      const response = '''
+{
+  "candidates": [{
+    "content": {
+      "role": "model",
+      "parts": [{
+        "text": "Here is a description of the image:"
+      }]
+    },
+    "finishReason": "STOP"
+  }],
+  "usageMetadata": {
+    "promptTokenCount": 1837,
+    "candidatesTokenCount": 76,
+    "totalTokenCount": 1913,
+    "promptTokensDetails": [{
+      "modality": "TEXT",
+      "tokenCount": 76
+    }, {
+      "modality": "IMAGE",
+      "tokenCount": 1806
+    }],
+    "candidatesTokensDetails": [{
+      "modality": "TEXT",
+      "tokenCount": 76
+    }]
+  }
+}
+        ''';
+      final decoded = jsonDecode(response) as Object;
+      final generateContentResponse = parseGenerateContentResponse(decoded);
+      expect(
+          generateContentResponse.text, 'Here is a description of the image:');
+      expect(generateContentResponse.usageMetadata?.totalTokenCount, 1913);
+      expect(
+          generateContentResponse
+              .usageMetadata?.promptTokensDetails?[1].modality,
+          ContentModality.image);
+      expect(
+          generateContentResponse
+              .usageMetadata?.promptTokensDetails?[1].tokenCount,
+          1806);
+      expect(
+          generateContentResponse
+              .usageMetadata?.candidatesTokensDetails?.first.modality,
+          ContentModality.text);
+      expect(
+          generateContentResponse
+              .usageMetadata?.candidatesTokensDetails?.first.tokenCount,
+          76);
+    });
+
+    test('countTokens with modality fields returned', () async {
+      const response = '''
+{
+  "totalTokens": 1837,
+  "totalBillableCharacters": 117,
+  "promptTokensDetails": [{
+    "modality": "IMAGE",
+    "tokenCount": 1806
+  }, {
+    "modality": "TEXT",
+    "tokenCount": 31
+  }]
+}
+        ''';
+      final decoded = jsonDecode(response) as Object;
+      final countTokensResponse = parseCountTokensResponse(decoded);
+      expect(countTokensResponse.totalTokens, 1837);
+      expect(countTokensResponse.promptTokensDetails?.first.modality,
+          ContentModality.image);
+      expect(countTokensResponse.promptTokensDetails?.first.tokenCount, 1806);
+    });
+
     test('text getter joins content', () async {
       const response = '''
 {