|
| 1 | +/* |
| 2 | + * Copyright 2025 Google LLC |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package com.google.adk.codeexecutors; |
| 18 | + |
| 19 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 20 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 21 | +import com.google.adk.agents.InvocationContext; |
| 22 | +import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; |
| 23 | +import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult; |
| 24 | +import com.google.adk.codeexecutors.CodeExecutionUtils.File; |
| 25 | +import com.google.cloud.aiplatform.v1beta1.ExecuteExtensionRequest; |
| 26 | +import com.google.cloud.aiplatform.v1beta1.ExecuteExtensionResponse; |
| 27 | +import com.google.cloud.aiplatform.v1beta1.ExtensionExecutionServiceClient; |
| 28 | +import com.google.cloud.aiplatform.v1beta1.ExtensionExecutionServiceSettings; |
| 29 | +import com.google.common.collect.ImmutableList; |
| 30 | +import com.google.common.collect.ImmutableMap; |
| 31 | +import com.google.protobuf.ListValue; |
| 32 | +import com.google.protobuf.Struct; |
| 33 | +import com.google.protobuf.Value; |
| 34 | +import java.io.IOException; |
| 35 | +import java.net.URLConnection; |
| 36 | +import java.util.ArrayList; |
| 37 | +import java.util.List; |
| 38 | +import java.util.Map; |
| 39 | +import java.util.Optional; |
| 40 | +import java.util.logging.Level; |
| 41 | +import java.util.logging.Logger; |
| 42 | + |
| 43 | +/** |
| 44 | + * A code executor that uses Vertex Code Interpreter Extension to execute code. |
| 45 | + * |
| 46 | + * <p>Attributes: resourceName: If set, load the existing resource name of the code interpreter |
| 47 | + * extension instead of creating a new one. Format: |
| 48 | + * projects/123/locations/us-central1/extensions/456 |
| 49 | + * |
| 50 | + * <p>Follow https://cloud.google.com/vertex-ai/generative-ai/docs/extensions/code-interpreter for |
| 51 | + * setup. |
| 52 | + */ |
| 53 | +public final class VertexAiCodeExecutor extends BaseCodeExecutor { |
| 54 | + private static final Logger logger = Logger.getLogger(VertexAiCodeExecutor.class.getName()); |
| 55 | + |
| 56 | + private static final ImmutableList<String> SUPPORTED_IMAGE_TYPES = |
| 57 | + ImmutableList.of("png", "jpg", "jpeg"); |
| 58 | + private static final ImmutableList<String> SUPPORTED_DATA_FILE_TYPES = ImmutableList.of("csv"); |
| 59 | + |
| 60 | + private static final String IMPORTED_LIBRARIES = |
| 61 | + "import io\n" |
| 62 | + + "import math\n" |
| 63 | + + "import re\n" |
| 64 | + + "\n" |
| 65 | + + "import matplotlib.pyplot as plt\n" |
| 66 | + + "import numpy as np\n" |
| 67 | + + "import pandas as pd\n" |
| 68 | + + "import scipy\n" |
| 69 | + + "\n" |
| 70 | + + "def crop(s: str, max_chars: int = 64) -> str:\n" |
| 71 | + + " \"\"\"Crops a string to max_chars characters.\"\"\"\n" |
| 72 | + + " return s[: max_chars - 3] + '...' if len(s) > max_chars else s\n" |
| 73 | + + "\n" |
| 74 | + + "\n" |
| 75 | + + "def explore_df(df: pd.DataFrame) -> None:\n" |
| 76 | + + " \"\"\"Prints some information about a pandas DataFrame.\"\"\"\n" |
| 77 | + + "\n" |
| 78 | + + " with pd.option_context(\n" |
| 79 | + + " 'display.max_columns', None, 'display.expand_frame_repr', False\n" |
| 80 | + + " ):\n" |
| 81 | + + " # Print the column names to never encounter KeyError when selecting one.\n" |
| 82 | + + " df_dtypes = df.dtypes\n" |
| 83 | + + "\n" |
| 84 | + + " # Obtain information about data types and missing values.\n" |
| 85 | + + " df_nulls = (len(df) - df.isnull().sum()).apply(\n" |
| 86 | + + " lambda x: f'{x} / {df.shape[0]} non-null'\n" |
| 87 | + + " )\n" |
| 88 | + + "\n" |
| 89 | + + " # Explore unique total values in columns using `.unique()`.\n" |
| 90 | + + " df_unique_count = df.apply(lambda x: len(x.unique()))\n" |
| 91 | + + "\n" |
| 92 | + + " # Explore unique values in columns using `.unique()`.\n" |
| 93 | + + " df_unique = df.apply(lambda x: crop(str(list(x.unique()))))\n" |
| 94 | + + "\n" |
| 95 | + + " df_info = pd.concat(\n" |
| 96 | + + " (\n" |
| 97 | + + " df_dtypes.rename('Dtype'),\n" |
| 98 | + + " df_nulls.rename('Non-Null Count'),\n" |
| 99 | + + " df_unique_count.rename('Unique Values Count'),\n" |
| 100 | + + " df_unique.rename('Unique Values'),\n" |
| 101 | + + " ),\n" |
| 102 | + + " axis=1,\n" |
| 103 | + + " )\n" |
| 104 | + + " df_info.index.name = 'Columns'\n" |
| 105 | + + " print(f\"\"\"Total rows: {df.shape[0]}\n" |
| 106 | + + "Total columns: {df.shape[1]}\n" |
| 107 | + + "\n" |
| 108 | + + "{df_info}\"\"\")"; |
| 109 | + |
| 110 | + private final String resourceName; |
| 111 | + private final ExtensionExecutionServiceClient codeInterpreterExtension; |
| 112 | + |
| 113 | + /** |
| 114 | + * Initializes the VertexAiCodeExecutor. |
| 115 | + * |
| 116 | + * @param resourceName If set, load the existing resource name of the code interpreter extension |
| 117 | + * instead of creating a new one. Format: projects/123/locations/us-central1/extensions/456 |
| 118 | + */ |
| 119 | + public VertexAiCodeExecutor(String resourceName) { |
| 120 | + String resolvedResourceName = resourceName; |
| 121 | + if (resolvedResourceName == null || resolvedResourceName.isEmpty()) { |
| 122 | + resolvedResourceName = System.getenv("CODE_INTERPRETER_EXTENSION_NAME"); |
| 123 | + } |
| 124 | + |
| 125 | + if (resolvedResourceName == null || resolvedResourceName.isEmpty()) { |
| 126 | + logger.warning( |
| 127 | + "No resource name found for Vertex AI Code Interpreter. It will not be available."); |
| 128 | + this.resourceName = null; |
| 129 | + this.codeInterpreterExtension = null; |
| 130 | + } else { |
| 131 | + this.resourceName = resolvedResourceName; |
| 132 | + try { |
| 133 | + String[] parts = this.resourceName.split("/"); |
| 134 | + if (parts.length < 4 || !parts[2].equals("locations")) { |
| 135 | + throw new IllegalArgumentException("Invalid resource name format: " + this.resourceName); |
| 136 | + } |
| 137 | + String location = parts[3]; |
| 138 | + String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); |
| 139 | + ExtensionExecutionServiceSettings settings = |
| 140 | + ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build(); |
| 141 | + this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings); |
| 142 | + } catch (IOException e) { |
| 143 | + logger.log(Level.SEVERE, "Failed to create ExtensionExecutionServiceClient", e); |
| 144 | + throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e); |
| 145 | + } |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + @Override |
| 150 | + public CodeExecutionResult executeCode( |
| 151 | + InvocationContext invocationContext, CodeExecutionInput codeExecutionInput) { |
| 152 | + // Execute the code. |
| 153 | + Map<String, Object> codeExecutionResult = |
| 154 | + executeCodeInterpreter( |
| 155 | + getCodeWithImports(codeExecutionInput.code()), |
| 156 | + codeExecutionInput.inputFiles(), |
| 157 | + codeExecutionInput.executionId()); |
| 158 | + |
| 159 | + // Save output file as artifacts. |
| 160 | + List<File> savedFiles = new ArrayList<>(); |
| 161 | + if (codeExecutionResult.containsKey("output_files")) { |
| 162 | + @SuppressWarnings("unchecked") |
| 163 | + List<Map<String, String>> outputFiles = |
| 164 | + (List<Map<String, String>>) codeExecutionResult.get("output_files"); |
| 165 | + for (Map<String, String> outputFile : outputFiles) { |
| 166 | + String fileName = outputFile.get("name"); |
| 167 | + String content = outputFile.get("contents"); // This is a base64 string. |
| 168 | + String fileType = fileName.substring(fileName.lastIndexOf('.') + 1); |
| 169 | + String mimeType; |
| 170 | + if (SUPPORTED_IMAGE_TYPES.contains(fileType)) { |
| 171 | + mimeType = "image/" + fileType; |
| 172 | + } else if (SUPPORTED_DATA_FILE_TYPES.contains(fileType)) { |
| 173 | + mimeType = "text/" + fileType; |
| 174 | + } else { |
| 175 | + mimeType = URLConnection.guessContentTypeFromName(fileName); |
| 176 | + } |
| 177 | + savedFiles.add(File.builder().name(fileName).content(content).mimeType(mimeType).build()); |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + // Collect the final result. |
| 182 | + return CodeExecutionResult.builder() |
| 183 | + .stdout((String) codeExecutionResult.getOrDefault("execution_result", "")) |
| 184 | + .stderr((String) codeExecutionResult.getOrDefault("execution_error", "")) |
| 185 | + .outputFiles(savedFiles) |
| 186 | + .build(); |
| 187 | + } |
| 188 | + |
| 189 | + private Map<String, Object> executeCodeInterpreter( |
| 190 | + String code, List<File> inputFiles, Optional<String> sessionId) { |
| 191 | + if (codeInterpreterExtension == null) { |
| 192 | + logger.warning( |
| 193 | + "Vertex AI Code Interpreter execution is not available. Returning empty result."); |
| 194 | + return ImmutableMap.of( |
| 195 | + "execution_result", "", "execution_error", "", "output_files", new ArrayList<>()); |
| 196 | + } |
| 197 | + |
| 198 | + // Build operationParams |
| 199 | + Struct.Builder paramsBuilder = Struct.newBuilder(); |
| 200 | + paramsBuilder.putFields("query", Value.newBuilder().setStringValue(code).build()); |
| 201 | + if (inputFiles != null && !inputFiles.isEmpty()) { |
| 202 | + ListValue.Builder listBuilder = ListValue.newBuilder(); |
| 203 | + for (File f : inputFiles) { |
| 204 | + Struct.Builder fileStructBuilder = Struct.newBuilder(); |
| 205 | + fileStructBuilder.putFields("name", Value.newBuilder().setStringValue(f.name()).build()); |
| 206 | + fileStructBuilder.putFields( |
| 207 | + "contents", Value.newBuilder().setStringValue(f.content()).build()); |
| 208 | + listBuilder.addValues(Value.newBuilder().setStructValue(fileStructBuilder.build())); |
| 209 | + } |
| 210 | + paramsBuilder.putFields( |
| 211 | + "files", Value.newBuilder().setListValue(listBuilder.build()).build()); |
| 212 | + } |
| 213 | + sessionId.ifPresent( |
| 214 | + s -> paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(s).build())); |
| 215 | + |
| 216 | + ExecuteExtensionRequest request = |
| 217 | + ExecuteExtensionRequest.newBuilder() |
| 218 | + .setName(this.resourceName) |
| 219 | + .setOperationId("generate_and_execute") |
| 220 | + .setOperationParams(paramsBuilder.build()) |
| 221 | + .build(); |
| 222 | + |
| 223 | + ExecuteExtensionResponse response = codeInterpreterExtension.executeExtension(request); |
| 224 | + String jsonOutput = response.getContent(); |
| 225 | + if (jsonOutput == null || jsonOutput.isEmpty()) { |
| 226 | + return ImmutableMap.of( |
| 227 | + "execution_result", "", "execution_error", "", "output_files", new ArrayList<>()); |
| 228 | + } |
| 229 | + |
| 230 | + try { |
| 231 | + ObjectMapper mapper = new ObjectMapper(); |
| 232 | + return mapper.readValue(jsonOutput, new TypeReference<Map<String, Object>>() {}); |
| 233 | + } catch (IOException e) { |
| 234 | + logger.log(Level.SEVERE, "Failed to parse JSON from code interpreter: " + jsonOutput, e); |
| 235 | + return ImmutableMap.of( |
| 236 | + "execution_result", |
| 237 | + "", |
| 238 | + "execution_error", |
| 239 | + "Failed to parse extension response: " + e.getMessage(), |
| 240 | + "output_files", |
| 241 | + new ArrayList<>()); |
| 242 | + } |
| 243 | + } |
| 244 | + |
| 245 | + private String getCodeWithImports(String code) { |
| 246 | + return String.format("%s\n\n%s", IMPORTED_LIBRARIES, code); |
| 247 | + } |
| 248 | +} |
0 commit comments