Skip to content

Commit a9930b8

Browse files
Add tool hosting object parsing
1 parent d7f9b87 commit a9930b8

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ packages = [
3434
Repository = 'https://github.com/axiomatic-ai/axiomatic-python-sdk'
3535

3636
[tool.poetry.dependencies]
37+
dill = ">=0.3.9"
3738
python = "^3.8"
3839
httpx = ">=0.21.2"
3940
pydantic = ">= 1.9.2"

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
dill==0.3.9
12
httpx>=0.21.2
23
pydantic>= 1.9.2
34
pydantic-core==^2.18.2

src/axiomatic/client.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import base64
2+
import dill
3+
import json
24
import requests
35
import os
46
import time
@@ -129,11 +131,41 @@ def tool_exec(self, tool: str, code: str, poll_interval: int = 3, debug: bool =
129131
if debug:
130132
print(f"status: {result.status}")
131133
if result.status == "SUCCEEDED":
132-
return result.output
134+
output = json.loads(result.output)
135+
if not output['objects']:
136+
return result.output
137+
else:
138+
return {
139+
"messages": output['messages'],
140+
"objects": self._load_objects_from_base64(output['objects'])
141+
}
133142
else:
134143
return result.error_trace
135144
else:
136145
return output.error_trace
137146

147+
def load(self, job_id: str, obj_key: str):
148+
result = self._ax_client.tools.status(job_id=job_id)
149+
if result.status == "SUCCEEDED":
150+
output = json.loads(result.output)
151+
if not output['objects']:
152+
return result.output
153+
else:
154+
return self._load_objects_from_base64(output['objects'])[obj_key]
155+
else:
156+
return result.error_trace
157+
158+
def _load_objects_from_base64(self, encoded_dict):
159+
loaded_objects = {}
160+
for key, encoded_str in encoded_dict.items():
161+
try:
162+
decoded_bytes = base64.b64decode(encoded_str)
163+
loaded_obj = dill.loads(decoded_bytes)
164+
loaded_objects[key] = loaded_obj
165+
except Exception as e:
166+
print(f"Error loading object for key '{key}': {e}")
167+
loaded_objects[key] = None
168+
return loaded_objects
169+
138170

139171
class AsyncAxiomatic(AsyncBaseClient): ...

0 commit comments

Comments
 (0)