Skip to content

Commit bfd23a8

Browse files
shukladivyanshcopybara-github
authored andcommitted
feat: Add VertexAiCodeExecutor
PiperOrigin-RevId: 794709944
1 parent 55fffb7 commit bfd23a8

File tree

6 files changed

+270
-5
lines changed

6 files changed

+270
-5
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.codeexecutors;
18+
19+
import com.fasterxml.jackson.core.type.TypeReference;
20+
import com.fasterxml.jackson.databind.ObjectMapper;
21+
import com.google.adk.agents.InvocationContext;
22+
import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput;
23+
import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult;
24+
import com.google.adk.codeexecutors.CodeExecutionUtils.File;
25+
import com.google.cloud.aiplatform.v1beta1.ExecuteExtensionRequest;
26+
import com.google.cloud.aiplatform.v1beta1.ExecuteExtensionResponse;
27+
import com.google.cloud.aiplatform.v1beta1.ExtensionExecutionServiceClient;
28+
import com.google.cloud.aiplatform.v1beta1.ExtensionExecutionServiceSettings;
29+
import com.google.common.collect.ImmutableList;
30+
import com.google.common.collect.ImmutableMap;
31+
import com.google.protobuf.ListValue;
32+
import com.google.protobuf.Struct;
33+
import com.google.protobuf.Value;
34+
import java.io.IOException;
35+
import java.net.URLConnection;
36+
import java.util.ArrayList;
37+
import java.util.List;
38+
import java.util.Map;
39+
import java.util.Optional;
40+
import java.util.logging.Level;
41+
import java.util.logging.Logger;
42+
43+
/**
44+
* A code executor that uses Vertex Code Interpreter Extension to execute code.
45+
*
46+
* <p>Attributes: resourceName: If set, load the existing resource name of the code interpreter
47+
* extension instead of creating a new one. Format:
48+
* projects/123/locations/us-central1/extensions/456
49+
*
50+
* <p>Follow https://cloud.google.com/vertex-ai/generative-ai/docs/extensions/code-interpreter for
51+
* setup.
52+
*/
53+
public final class VertexAiCodeExecutor extends BaseCodeExecutor {
54+
private static final Logger logger = Logger.getLogger(VertexAiCodeExecutor.class.getName());
55+
56+
private static final ImmutableList<String> SUPPORTED_IMAGE_TYPES =
57+
ImmutableList.of("png", "jpg", "jpeg");
58+
private static final ImmutableList<String> SUPPORTED_DATA_FILE_TYPES = ImmutableList.of("csv");
59+
60+
private static final String IMPORTED_LIBRARIES =
61+
"import io\n"
62+
+ "import math\n"
63+
+ "import re\n"
64+
+ "\n"
65+
+ "import matplotlib.pyplot as plt\n"
66+
+ "import numpy as np\n"
67+
+ "import pandas as pd\n"
68+
+ "import scipy\n"
69+
+ "\n"
70+
+ "def crop(s: str, max_chars: int = 64) -> str:\n"
71+
+ " \"\"\"Crops a string to max_chars characters.\"\"\"\n"
72+
+ " return s[: max_chars - 3] + '...' if len(s) > max_chars else s\n"
73+
+ "\n"
74+
+ "\n"
75+
+ "def explore_df(df: pd.DataFrame) -> None:\n"
76+
+ " \"\"\"Prints some information about a pandas DataFrame.\"\"\"\n"
77+
+ "\n"
78+
+ " with pd.option_context(\n"
79+
+ " 'display.max_columns', None, 'display.expand_frame_repr', False\n"
80+
+ " ):\n"
81+
+ " # Print the column names to never encounter KeyError when selecting one.\n"
82+
+ " df_dtypes = df.dtypes\n"
83+
+ "\n"
84+
+ " # Obtain information about data types and missing values.\n"
85+
+ " df_nulls = (len(df) - df.isnull().sum()).apply(\n"
86+
+ " lambda x: f'{x} / {df.shape[0]} non-null'\n"
87+
+ " )\n"
88+
+ "\n"
89+
+ " # Explore unique total values in columns using `.unique()`.\n"
90+
+ " df_unique_count = df.apply(lambda x: len(x.unique()))\n"
91+
+ "\n"
92+
+ " # Explore unique values in columns using `.unique()`.\n"
93+
+ " df_unique = df.apply(lambda x: crop(str(list(x.unique()))))\n"
94+
+ "\n"
95+
+ " df_info = pd.concat(\n"
96+
+ " (\n"
97+
+ " df_dtypes.rename('Dtype'),\n"
98+
+ " df_nulls.rename('Non-Null Count'),\n"
99+
+ " df_unique_count.rename('Unique Values Count'),\n"
100+
+ " df_unique.rename('Unique Values'),\n"
101+
+ " ),\n"
102+
+ " axis=1,\n"
103+
+ " )\n"
104+
+ " df_info.index.name = 'Columns'\n"
105+
+ " print(f\"\"\"Total rows: {df.shape[0]}\n"
106+
+ "Total columns: {df.shape[1]}\n"
107+
+ "\n"
108+
+ "{df_info}\"\"\")";
109+
110+
private final String resourceName;
111+
private final ExtensionExecutionServiceClient codeInterpreterExtension;
112+
113+
/**
114+
* Initializes the VertexAiCodeExecutor.
115+
*
116+
* @param resourceName If set, load the existing resource name of the code interpreter extension
117+
* instead of creating a new one. Format: projects/123/locations/us-central1/extensions/456
118+
*/
119+
public VertexAiCodeExecutor(String resourceName) {
120+
String resolvedResourceName = resourceName;
121+
if (resolvedResourceName == null || resolvedResourceName.isEmpty()) {
122+
resolvedResourceName = System.getenv("CODE_INTERPRETER_EXTENSION_NAME");
123+
}
124+
125+
if (resolvedResourceName == null || resolvedResourceName.isEmpty()) {
126+
logger.warning(
127+
"No resource name found for Vertex AI Code Interpreter. It will not be available.");
128+
this.resourceName = null;
129+
this.codeInterpreterExtension = null;
130+
} else {
131+
this.resourceName = resolvedResourceName;
132+
try {
133+
String[] parts = this.resourceName.split("/");
134+
if (parts.length < 4 || !parts[2].equals("locations")) {
135+
throw new IllegalArgumentException("Invalid resource name format: " + this.resourceName);
136+
}
137+
String location = parts[3];
138+
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
139+
ExtensionExecutionServiceSettings settings =
140+
ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build();
141+
this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings);
142+
} catch (IOException e) {
143+
logger.log(Level.SEVERE, "Failed to create ExtensionExecutionServiceClient", e);
144+
throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e);
145+
}
146+
}
147+
}
148+
149+
@Override
150+
public CodeExecutionResult executeCode(
151+
InvocationContext invocationContext, CodeExecutionInput codeExecutionInput) {
152+
// Execute the code.
153+
Map<String, Object> codeExecutionResult =
154+
executeCodeInterpreter(
155+
getCodeWithImports(codeExecutionInput.code()),
156+
codeExecutionInput.inputFiles(),
157+
codeExecutionInput.executionId());
158+
159+
// Save output file as artifacts.
160+
List<File> savedFiles = new ArrayList<>();
161+
if (codeExecutionResult.containsKey("output_files")) {
162+
@SuppressWarnings("unchecked")
163+
List<Map<String, String>> outputFiles =
164+
(List<Map<String, String>>) codeExecutionResult.get("output_files");
165+
for (Map<String, String> outputFile : outputFiles) {
166+
String fileName = outputFile.get("name");
167+
String content = outputFile.get("contents"); // This is a base64 string.
168+
String fileType = fileName.substring(fileName.lastIndexOf('.') + 1);
169+
String mimeType;
170+
if (SUPPORTED_IMAGE_TYPES.contains(fileType)) {
171+
mimeType = "image/" + fileType;
172+
} else if (SUPPORTED_DATA_FILE_TYPES.contains(fileType)) {
173+
mimeType = "text/" + fileType;
174+
} else {
175+
mimeType = URLConnection.guessContentTypeFromName(fileName);
176+
}
177+
savedFiles.add(File.builder().name(fileName).content(content).mimeType(mimeType).build());
178+
}
179+
}
180+
181+
// Collect the final result.
182+
return CodeExecutionResult.builder()
183+
.stdout((String) codeExecutionResult.getOrDefault("execution_result", ""))
184+
.stderr((String) codeExecutionResult.getOrDefault("execution_error", ""))
185+
.outputFiles(savedFiles)
186+
.build();
187+
}
188+
189+
private Map<String, Object> executeCodeInterpreter(
190+
String code, List<File> inputFiles, Optional<String> sessionId) {
191+
if (codeInterpreterExtension == null) {
192+
logger.warning(
193+
"Vertex AI Code Interpreter execution is not available. Returning empty result.");
194+
return ImmutableMap.of(
195+
"execution_result", "", "execution_error", "", "output_files", new ArrayList<>());
196+
}
197+
198+
// Build operationParams
199+
Struct.Builder paramsBuilder = Struct.newBuilder();
200+
paramsBuilder.putFields("query", Value.newBuilder().setStringValue(code).build());
201+
if (inputFiles != null && !inputFiles.isEmpty()) {
202+
ListValue.Builder listBuilder = ListValue.newBuilder();
203+
for (File f : inputFiles) {
204+
Struct.Builder fileStructBuilder = Struct.newBuilder();
205+
fileStructBuilder.putFields("name", Value.newBuilder().setStringValue(f.name()).build());
206+
fileStructBuilder.putFields(
207+
"contents", Value.newBuilder().setStringValue(f.content()).build());
208+
listBuilder.addValues(Value.newBuilder().setStructValue(fileStructBuilder.build()));
209+
}
210+
paramsBuilder.putFields(
211+
"files", Value.newBuilder().setListValue(listBuilder.build()).build());
212+
}
213+
sessionId.ifPresent(
214+
s -> paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(s).build()));
215+
216+
ExecuteExtensionRequest request =
217+
ExecuteExtensionRequest.newBuilder()
218+
.setName(this.resourceName)
219+
.setOperationId("generate_and_execute")
220+
.setOperationParams(paramsBuilder.build())
221+
.build();
222+
223+
ExecuteExtensionResponse response = codeInterpreterExtension.executeExtension(request);
224+
String jsonOutput = response.getContent();
225+
if (jsonOutput == null || jsonOutput.isEmpty()) {
226+
return ImmutableMap.of(
227+
"execution_result", "", "execution_error", "", "output_files", new ArrayList<>());
228+
}
229+
230+
try {
231+
ObjectMapper mapper = new ObjectMapper();
232+
return mapper.readValue(jsonOutput, new TypeReference<Map<String, Object>>() {});
233+
} catch (IOException e) {
234+
logger.log(Level.SEVERE, "Failed to parse JSON from code interpreter: " + jsonOutput, e);
235+
return ImmutableMap.of(
236+
"execution_result",
237+
"",
238+
"execution_error",
239+
"Failed to parse extension response: " + e.getMessage(),
240+
"output_files",
241+
new ArrayList<>());
242+
}
243+
}
244+
245+
private String getCodeWithImports(String code) {
246+
return String.format("%s\n\n%s", IMPORTED_LIBRARIES, code);
247+
}
248+
}

core/src/main/java/com/google/adk/events/Event.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,28 @@ public final ImmutableList<FunctionResponse> functionResponses() {
237237
.collect(toImmutableList());
238238
}
239239

240+
/** Returns whether the event has a trailing code execution result. */
241+
@JsonIgnore
242+
public final boolean hasTrailingCodeExecutionResult() {
243+
return content()
244+
.flatMap(Content::parts)
245+
.filter(parts -> !parts.isEmpty())
246+
.map(parts -> parts.get(parts.size() - 1))
247+
.flatMap(part -> part.codeExecutionResult())
248+
.isPresent();
249+
}
250+
240251
/** Returns true if this is a final response. */
241252
@JsonIgnore
242253
public final boolean finalResponse() {
243254
if (actions().skipSummarization().orElse(false)
244255
|| (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) {
245256
return true;
246257
}
247-
return functionCalls().isEmpty() && functionResponses().isEmpty() && !partial().orElse(false);
258+
return functionCalls().isEmpty()
259+
&& functionResponses().isEmpty()
260+
&& !partial().orElse(false)
261+
&& !hasTrailingCodeExecutionResult();
248262
}
249263

250264
/**

core/src/main/java/com/google/adk/flows/llmflows/AutoFlow.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ public class AutoFlow extends SingleFlow {
2929
.add(new AgentTransfer())
3030
.build();
3131

32-
/** No additional response processors. */
33-
private static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS = ImmutableList.of();
32+
/** Only base response processors. */
33+
private static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =
34+
ImmutableList.<ResponseProcessor>builder().addAll(SingleFlow.RESPONSE_PROCESSORS).build();
3435

3536
public AutoFlow() {
3637
this(/* maxSteps= */ Optional.empty());

core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ private static Flowable<Event> runPostProcessor(
294294
}
295295
String codeStr = codeStrOptional.get();
296296
responseContent = responseContentBuilder.build();
297-
llmResponseBuilder.content(Content.builder().build());
297+
llmResponseBuilder.content(Optional.empty());
298298

299299
Event codeEvent =
300300
Event.builder()

core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import java.util.List;
2121
import java.util.Optional;
2222

23-
/** Basic LLM flow with fixed request processors and no response post-processing. */
23+
/** Basic LLM flow with fixed request and response processors. */
2424
public class SingleFlow extends BaseLlmFlow {
2525
// TODO: We should eventually remove this class since it complicates things.
2626

core/src/main/java/com/google/adk/models/LlmResponse.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ static LlmResponse.Builder jacksonBuilder() {
109109
@JsonProperty("content")
110110
public abstract Builder content(Content content);
111111

112+
public abstract Builder content(Optional<Content> content);
113+
112114
@JsonProperty("interrupted")
113115
public abstract Builder interrupted(@Nullable Boolean interrupted);
114116

0 commit comments

Comments
 (0)