Skip to content

Commit f5b8fda

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add usage metadata to LLM Response model
It contains the number of tokens used by the model. This is in line with Python ADK and fixes #212. PiperOrigin-RevId: 788107292
1 parent d5ca1ec commit f5b8fda

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

core/src/main/java/com/google/adk/models/LlmResponse.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.google.genai.types.FinishReason;
2929
import com.google.genai.types.GenerateContentResponse;
3030
import com.google.genai.types.GenerateContentResponsePromptFeedback;
31+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
3132
import com.google.genai.types.GroundingMetadata;
3233
import java.util.List;
3334
import java.util.Optional;
@@ -89,6 +90,10 @@ public abstract class LlmResponse extends JsonBaseModel {
8990
@JsonProperty("interrupted")
9091
public abstract Optional<Boolean> interrupted();
9192

93+
/** Usage metadata about the response(s). */
94+
@JsonProperty("usageMetadata")
95+
public abstract Optional<GenerateContentResponseUsageMetadata> usageMetadata();
96+
9297
public abstract Builder toBuilder();
9398

9499
/** Builder for constructing {@link LlmResponse} instances. */
@@ -134,6 +139,13 @@ static LlmResponse.Builder jacksonBuilder() {
134139

135140
public abstract Builder errorMessage(Optional<String> errorMessage);
136141

142+
@JsonProperty("usageMetadata")
143+
public abstract Builder usageMetadata(
144+
@Nullable GenerateContentResponseUsageMetadata usageMetadata);
145+
146+
public abstract Builder usageMetadata(
147+
Optional<GenerateContentResponseUsageMetadata> usageMetadata);
148+
137149
@CanIgnoreReturnValue
138150
public final Builder response(GenerateContentResponse response) {
139151
Optional<List<Candidate>> candidatesOpt = response.candidates();
@@ -160,6 +172,7 @@ public final Builder response(GenerateContentResponse response) {
160172
this.errorMessage("Unknown error.");
161173
}
162174
}
175+
this.usageMetadata(response.usageMetadata());
163176
return this;
164177
}
165178

core/src/test/java/com/google/adk/models/LlmResponseTest.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.genai.types.Content;
2727
import com.google.genai.types.FinishReason;
2828
import com.google.genai.types.FunctionCall;
29+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
2930
import com.google.genai.types.Part;
3031
import java.util.Optional;
3132
import org.junit.Before;
@@ -61,6 +62,12 @@ private Content createSampleFunctionCallContent(String functionName) {
6162
public void testSerializationAndDeserialization_allFieldsPresent()
6263
throws JsonProcessingException {
6364
Content sampleContent = createSampleContent("Hello, world!");
65+
GenerateContentResponseUsageMetadata usageMetadata =
66+
GenerateContentResponseUsageMetadata.builder()
67+
.promptTokenCount(10)
68+
.candidatesTokenCount(20)
69+
.totalTokenCount(30)
70+
.build();
6471
LlmResponse originalResponse =
6572
LlmResponse.builder()
6673
.content(sampleContent)
@@ -69,6 +76,7 @@ public void testSerializationAndDeserialization_allFieldsPresent()
6976
.errorCode(new FinishReason("ERR_123"))
7077
.errorMessage(Optional.of("An error occurred."))
7178
.interrupted(Optional.of(true))
79+
.usageMetadata(usageMetadata)
7280
.build();
7381

7482
String json = originalResponse.toJson();
@@ -83,6 +91,10 @@ public void testSerializationAndDeserialization_allFieldsPresent()
8391
assertThat(jsonNode.get("errorCode").asText()).isEqualTo("ERR_123");
8492
assertThat(jsonNode.get("errorMessage").asText()).isEqualTo("An error occurred.");
8593
assertThat(jsonNode.get("interrupted").asBoolean()).isTrue();
94+
assertThat(jsonNode.has("usageMetadata")).isTrue();
95+
assertThat(jsonNode.get("usageMetadata").get("promptTokenCount").asInt()).isEqualTo(10);
96+
assertThat(jsonNode.get("usageMetadata").get("candidatesTokenCount").asInt()).isEqualTo(20);
97+
assertThat(jsonNode.get("usageMetadata").get("totalTokenCount").asInt()).isEqualTo(30);
8698

8799
LlmResponse deserializedResponse = LlmResponse.fromJsonString(json, LlmResponse.class);
88100

@@ -93,6 +105,7 @@ public void testSerializationAndDeserialization_allFieldsPresent()
93105
assertThat(deserializedResponse.errorCode()).hasValue(new FinishReason("ERR_123"));
94106
assertThat(deserializedResponse.errorMessage()).hasValue("An error occurred.");
95107
assertThat(deserializedResponse.interrupted()).hasValue(true);
108+
assertThat(deserializedResponse.usageMetadata()).hasValue(usageMetadata);
96109
}
97110

98111
@Test
@@ -108,6 +121,7 @@ public void testSerializationAndDeserialization_optionalFieldsEmpty()
108121
.errorCode(Optional.empty())
109122
.errorMessage(Optional.empty())
110123
.interrupted(Optional.empty())
124+
.usageMetadata(Optional.empty())
111125
.build();
112126

113127
String json = originalResponse.toJson();
@@ -122,6 +136,7 @@ public void testSerializationAndDeserialization_optionalFieldsEmpty()
122136
assertThat(jsonNode.has("errorCode")).isFalse();
123137
assertThat(jsonNode.has("errorMessage")).isFalse();
124138
assertThat(jsonNode.has("interrupted")).isFalse();
139+
assertThat(jsonNode.has("usageMetadata")).isFalse();
125140

126141
LlmResponse deserializedResponse = LlmResponse.fromJsonString(json, LlmResponse.class);
127142

@@ -133,6 +148,7 @@ public void testSerializationAndDeserialization_optionalFieldsEmpty()
133148
assertThat(deserializedResponse.errorCode()).isEmpty();
134149
assertThat(deserializedResponse.errorMessage()).isEmpty();
135150
assertThat(deserializedResponse.interrupted()).isEmpty();
151+
assertThat(deserializedResponse.usageMetadata()).isEmpty();
136152
}
137153

138154
@Test
@@ -146,7 +162,8 @@ public void testDeserialization_optionalFieldsNullInJson() throws JsonProcessing
146162
+ "\"turnComplete\": true,"
147163
+ "\"errorCode\": null,"
148164
+ "\"errorMessage\": null,"
149-
+ "\"interrupted\": null"
165+
+ "\"interrupted\": null,"
166+
+ "\"usageMetadata\": null"
150167
+ "}";
151168

152169
LlmResponse deserializedResponse = LlmResponse.fromJsonString(jsonWithNulls, LlmResponse.class);
@@ -160,6 +177,7 @@ public void testDeserialization_optionalFieldsNullInJson() throws JsonProcessing
160177
assertThat(deserializedResponse.errorCode()).isEmpty();
161178
assertThat(deserializedResponse.errorMessage()).isEmpty();
162179
assertThat(deserializedResponse.interrupted()).isEmpty();
180+
assertThat(deserializedResponse.usageMetadata()).isEmpty();
163181
}
164182

165183
@Test
@@ -185,6 +203,7 @@ public void testDeserialization_someOptionalFieldsMissingSomePresent()
185203
assertThat(jsonNode.get("errorCode").asText()).isEqualTo("FATAL_ERROR");
186204
assertThat(jsonNode.has("errorMessage")).isFalse();
187205
assertThat(jsonNode.has("interrupted")).isFalse();
206+
assertThat(jsonNode.has("usageMetadata")).isFalse();
188207

189208
LlmResponse deserializedResponse = LlmResponse.fromJsonString(json, LlmResponse.class);
190209
assertThat(deserializedResponse).isEqualTo(originalResponse);
@@ -197,5 +216,6 @@ public void testDeserialization_someOptionalFieldsMissingSomePresent()
197216
assertThat(deserializedResponse.errorCode()).hasValue(new FinishReason("FATAL_ERROR"));
198217
assertThat(deserializedResponse.errorMessage()).isEmpty();
199218
assertThat(deserializedResponse.interrupted()).isEmpty();
219+
assertThat(deserializedResponse.usageMetadata()).isEmpty();
200220
}
201221
}

0 commit comments

Comments
 (0)