Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 115 additions & 16 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import 'schema.dart';
/// Response for Count Tokens
final class CountTokensResponse {
/// Constructor
CountTokensResponse(this.totalTokens, {this.totalBillableCharacters});
CountTokensResponse(this.totalTokens,
{this.totalBillableCharacters, this.promptTokensDetails});

/// The number of tokens that the `model` tokenizes the `prompt` into.
///
Expand All @@ -30,6 +31,9 @@ final class CountTokensResponse {
///
/// Always non-negative.
final int? totalBillableCharacters;

/// List of modalities that were processed in the request input.
final List<ModalityTokenCount>? promptTokensDetails;
}

/// Response from the model; supports multiple candidates.
Expand Down Expand Up @@ -128,11 +132,12 @@ final class PromptFeedback {
/// Metadata on the generation request's token usage.
final class UsageMetadata {
/// Constructor
UsageMetadata._({
this.promptTokenCount,
this.candidatesTokenCount,
this.totalTokenCount,
});
UsageMetadata._(
{this.promptTokenCount,
this.candidatesTokenCount,
this.totalTokenCount,
this.promptTokensDetails,
this.candidatesTokensDetails});

/// Number of tokens in the prompt.
final int? promptTokenCount;
Expand All @@ -142,6 +147,12 @@ final class UsageMetadata {

/// Total token count for the generation request (prompt + candidates).
final int? totalTokenCount;

/// List of modalities that were processed in the request input.
final List<ModalityTokenCount>? promptTokensDetails;

/// List of modalities that were returned in the response.
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Response candidate generated from a [GenerativeModel].
Expand Down Expand Up @@ -481,6 +492,62 @@ enum FinishReason {
String toString() => name;
}

/// Token usage attributed to a single content modality.
final class ModalityTokenCount {
  /// Creates a record of [tokenCount] tokens counted for [modality].
  ModalityTokenCount(
    this.modality,
    this.tokenCount,
  );

  /// The kind of content that [tokenCount] applies to.
  final ContentModality modality;

  /// How many tokens were counted for [modality].
  final int tokenCount;
}

/// Content part modality.
///
/// Every value carries the upper-case string used for it in JSON payloads;
/// [toJson] emits that string and `_parseValue` accepts exactly the same set.
enum ContentModality {
  /// Unspecified modality.
  unspecified('MODALITY_UNSPECIFIED'),

  /// Plain text.
  text('TEXT'),

  /// Image.
  image('IMAGE'),

  /// Video.
  video('VIDEO'),

  /// Audio.
  audio('AUDIO'),

  /// Document, e.g. PDF.
  document('DOCUMENT');

  const ContentModality(this._jsonString);

  /// Maps a decoded JSON value onto the matching [ContentModality].
  ///
  /// Throws a [FormatException] when [jsonObject] is not one of the wire
  /// strings declared on the enum values.
  static ContentModality _parseValue(Object jsonObject) {
    // The case labels must stay identical to the `_jsonString` of each value;
    // otherwise a value written by [toJson] cannot be read back.
    return switch (jsonObject) {
      'MODALITY_UNSPECIFIED' => ContentModality.unspecified,
      'TEXT' => ContentModality.text,
      'IMAGE' => ContentModality.image,
      'VIDEO' => ContentModality.video,
      'AUDIO' => ContentModality.audio,
      'DOCUMENT' => ContentModality.document,
      _ =>
        throw FormatException('Unhandled ContentModality format', jsonObject),
    };
  }

  final String _jsonString;

  @override
  String toString() => name;

  /// Convert to json format.
  Object toJson() => _jsonString;
}

/// Safety setting, affecting the safety-blocking behavior.
///
/// Passing a safety setting for a category changes the allowed probability that
Expand Down Expand Up @@ -696,16 +763,28 @@ GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
/// Parse the json to [CountTokensResponse]
///
/// Throws the result of `parseError` when [jsonObject] carries an `error`
/// entry, and the result of `unhandledFormat` when it has no integer
/// `totalTokens` entry. `totalBillableCharacters` and `promptTokensDetails`
/// are optional and are left `null` when absent or of an unexpected type.
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
  if (jsonObject case {'error': final Object error}) throw parseError(error);

  // `totalTokens` is the only mandatory field. Matching it with a pattern
  // (rather than an unchecked `as int`) keeps malformed payloads on the
  // `unhandledFormat` path instead of surfacing a cast error.
  if (jsonObject case {'totalTokens': final int totalTokens}) {
    final totalBillableCharacters = switch (jsonObject) {
      {'totalBillableCharacters': final int totalBillableCharacters} =>
        totalBillableCharacters,
      _ => null,
    };
    final promptTokensDetails = switch (jsonObject) {
      {'promptTokensDetails': final List<Object?> promptTokensDetails} =>
        promptTokensDetails.map(_parseModalityTokenCount).toList(),
      _ => null,
    };

    return CountTokensResponse(
      totalTokens,
      totalBillableCharacters: totalBillableCharacters,
      promptTokensDetails: promptTokensDetails,
    );
  }

  throw unhandledFormat('CountTokensResponse', jsonObject);
}

Candidate _parseCandidate(Object? jsonObject) {
Expand Down Expand Up @@ -777,10 +856,30 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};
final candidatesTokensDetails = switch (jsonObject) {
{'candidatesTokensDetails': final List<Object?> candidatesTokensDetails} =>
candidatesTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount);
totalTokenCount: totalTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);
}

/// Parses one entry of a token-details list into a [ModalityTokenCount].
///
/// Throws the result of `unhandledFormat` when [jsonObject] is not a map with
/// a non-null `modality`, or when `tokenCount` is present but not an integer.
/// `ContentModality._parseValue` throws a [FormatException] for a modality
/// string it does not recognise.
ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) {
  if (jsonObject case {'modality': final Object modality}) {
    final tokenCount = switch (jsonObject) {
      {'tokenCount': final int tokenCount} => tokenCount,
      // Present but not an integer: report it rather than guessing.
      {'tokenCount': Object()} =>
        throw unhandledFormat('ModalityTokenCount', jsonObject),
      // NOTE(review): assumes the backend follows the protobuf JSON mapping
      // and omits zero-valued fields, so a missing count is read as 0 —
      // confirm against the service contract.
      _ => 0,
    };
    return ModalityTokenCount(
      ContentModality._parseValue(modality),
      tokenCount,
    );
  }
  throw unhandledFormat('ModalityTokenCount', jsonObject);
}

SafetyRating _parseSafetyRating(Object? jsonObject) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,81 @@ void main() {
);
});

// Checks that the per-modality token details nested in `usageMetadata` are
// surfaced on the parsed generate-content response.
test('response including usage metadata', () async {
  const response = '''
{
"candidates": [{
"content": {
"role": "model",
"parts": [{
"text": "Here is a description of the image:"
}]
},
"finishReason": "STOP"
}],
"usageMetadata": {
"promptTokenCount": 1837,
"candidatesTokenCount": 76,
"totalTokenCount": 1913,
"promptTokensDetails": [{
"modality": "TEXT",
"tokenCount": 76
}, {
"modality": "IMAGE",
"tokenCount": 1806
}],
"candidatesTokensDetails": [{
"modality": "TEXT",
"tokenCount": 76
}]
}
}
''';
  final parsed = parseGenerateContentResponse(jsonDecode(response) as Object);
  final usage = parsed.usageMetadata;

  expect(parsed.text, 'Here is a description of the image:');
  expect(usage?.totalTokenCount, 1913);

  // Second prompt entry is the image part.
  final promptDetails = usage?.promptTokensDetails;
  expect(promptDetails?[1].modality, ContentModality.image);
  expect(promptDetails?[1].tokenCount, 1806);

  // The single candidate entry is text.
  final candidateDetails = usage?.candidatesTokensDetails;
  expect(candidateDetails?.first.modality, ContentModality.text);
  expect(candidateDetails?.first.tokenCount, 76);
});

// Checks that `promptTokensDetails` in a count-tokens payload is parsed
// alongside the total token count.
test('countTokens with modality fields returned', () async {
  const response = '''
{
"totalTokens": 1837,
"totalBillableCharacters": 117,
"promptTokensDetails": [{
"modality": "IMAGE",
"tokenCount": 1806
}, {
"modality": "TEXT",
"tokenCount": 31
}]
}
''';
  final parsed = parseCountTokensResponse(jsonDecode(response) as Object);
  expect(parsed.totalTokens, 1837);

  // The first listed entry is the image part.
  final details = parsed.promptTokensDetails;
  expect(details?.first.modality, ContentModality.image);
  expect(details?.first.tokenCount, 1806);
});

test('text getter joins content', () async {
const response = '''
{
Expand Down
Loading