33https://platform.openai.com/docs/guides/vision
44"""
55
6+ from copy import deepcopy
67from pprint import pformat
8+ from typing import Any
79
810from loguru import logger
11+ from PIL import Image
912import requests
1013
11- from openadapt import cache
14+ from openadapt import cache , utils
1215from openadapt .config import config
1316
1417MODEL_NAME = [
1518 "gpt-4-vision-preview" ,
1619 "gpt-4-turbo-2024-04-09" ,
20+ "gpt-4o" ,
1721][- 1 ]
22+ # TODO XXX: per model
1823MAX_TOKENS = 4096
1924# TODO XXX undocumented
2025MAX_IMAGES = None
2328def create_payload (
2429 prompt : str ,
2530 system_prompt : str | None = None ,
26- base64_images : list [str ] | None = None ,
31+ images : list [Image . Image ] | None = None ,
2732 model : str = MODEL_NAME ,
2833 detail : str = "high" , # "low" or "high"
2934 max_tokens : int | None = None ,
@@ -33,7 +38,7 @@ def create_payload(
3338 Args:
3439 prompt: the prompt
3540 system_prompt: the system prompt
36- base64_images : list of base64 encoded images
41+ images : list of images
3742 model: name of OpenAI model
3843 detail: detail level of images, "low" or "high"
3944 max_tokens: maximum number of tokens
@@ -59,8 +64,9 @@ def create_payload(
5964 },
6065 ]
6166
62- base64_images = base64_images or []
63- for base64_image in base64_images :
67+ images = images or []
68+ for image in images :
69+ base64_image = utils .image2utf8 (image )
6470 messages [0 ]["content" ].append (
6571 {
6672 "type" : "image_url" ,
@@ -94,18 +100,22 @@ def create_payload(
94100
95101
96102@cache .cache ()
97- def get_response (payload : dict ) -> requests .Response :
103+ def get_response (
104+ payload : dict ,
105+ api_key : str = config .OPENAI_API_KEY ,
106+ ) -> requests .Response :
98107 """Sends a request to the OpenAI API and returns the response.
99108
100109 Args:
101110 payload: dictionary returned by create_payload
111+ api_key (str): api key
102112
103113 Returns:
104114 response from OpenAI API
105115 """
106116 headers = {
107117 "Content-Type" : "application/json" ,
108- "Authorization" : f"Bearer { config . OPENAI_API_KEY } " ,
118+ "Authorization" : f"Bearer { api_key } " ,
109119 }
110120 response = requests .post (
111121 "https://api.openai.com/v1/chat/completions" ,
@@ -115,14 +125,15 @@ def get_response(payload: dict) -> requests.Response:
115125 return response
116126
117127
118- def get_completion (payload : dict ) -> str :
128+ def get_completion (payload : dict , dev_mode : bool = False ) -> str :
119129 """Sends a request to the OpenAI API and returns the first message.
120130
121131 Args:
122- pyalod: dictionary returned by create_payload
132+ payload (dict): dictionary returned by create_payload
133+ dev_mode (bool): whether to launch a debugger on error
123134
124135 Returns:
125- string containing the first message from the response
136+ (str) first message from the response
126137 """
127138 response = get_response (payload )
128139 result = response .json ()
@@ -133,22 +144,37 @@ def get_completion(payload: dict) -> str:
133144 # TODO: fail after maximum number of attempts
134145 if "retry your request" in message :
135146 return get_completion (payload )
136- else :
147+ elif dev_mode :
137148 import ipdb
138149
139150 ipdb .set_trace ()
140151 # TODO: handle more errors
152+ else :
153+ raise ValueError (result ["error" ]["message" ])
141154 choices = result ["choices" ]
142155 choice = choices [0 ]
143156 message = choice ["message" ]
144157 content = message ["content" ]
145158 return content
146159
147160
161+ def log_payload (payload : dict [Any , Any ]) -> None :
162+ """Logs a payload after removing base-64 encoded values recursively."""
163+ # TODO: detect base64 encoded strings dynamically
164+ # messages["content"][{"image_url": ...
165+ # payload["messages"][1]["content"][9]["image_url"]
166+ payload_copy = deepcopy (payload )
167+ for message in payload_copy ["messages" ]:
168+ for content in message ["content" ]:
169+ if "image_url" in content :
170+ content ["image_url" ]["url" ] = "[REDACTED]"
171+ logger .info (f"payload=\n { pformat (payload_copy )} " )
172+
173+
148174def prompt (
149175 prompt : str ,
150176 system_prompt : str | None = None ,
151- base64_images : list [str ] | None = None ,
177+ images : list [Image . Image ] | None = None ,
152178 max_tokens : int | None = None ,
153179 detail : str = "high" ,
154180) -> str :
@@ -157,7 +183,7 @@ def prompt(
157183 Args:
158184 prompt: the prompt
159185 system_prompt: the system prompt
160- base64_images : list of base64 encoded images
186+ images : list of images
161187 model: name of OpenAI model
162188 detail: detail level of images, "low" or "high"
163189 max_tokens: maximum number of tokens
@@ -168,11 +194,11 @@ def prompt(
168194 payload = create_payload (
169195 prompt ,
170196 system_prompt ,
171- base64_images ,
197+ images ,
172198 max_tokens = max_tokens ,
173199 detail = detail ,
174200 )
175- logger . info ( f" payload= \n { pformat ( payload ) } " )
201+ log_payload ( payload )
176202 result = get_completion (payload )
177203 logger .info (f"result=\n { pformat (result )} " )
178204 return result
0 commit comments