Skip to content

Commit 25424f1

Browse files
committed
Await callbacks
1 parent 8e84abf commit 25424f1

File tree

3 files changed

+76
-27
lines changed

3 files changed

+76
-27
lines changed

python/e2b_code_interpreter/code_interpreter_async.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
Context,
2121
Result,
2222
aextract_exception,
23-
parse_output,
24-
OutputHandler,
23+
OutputHandlerWithAsync,
24+
async_parse_output,
2525
OutputMessage,
2626
)
2727
from e2b_code_interpreter.exceptions import (
@@ -69,10 +69,10 @@ async def run_code(
6969
self,
7070
code: str,
7171
language: Union[Literal["python"], None] = None,
72-
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
73-
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
74-
on_result: Optional[OutputHandler[Result]] = None,
75-
on_error: Optional[OutputHandler[ExecutionError]] = None,
72+
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
73+
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
74+
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
75+
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
7676
envs: Optional[Dict[str, str]] = None,
7777
timeout: Optional[float] = None,
7878
request_timeout: Optional[float] = None,
@@ -103,10 +103,10 @@ async def run_code(
103103
self,
104104
code: str,
105105
language: Optional[str] = None,
106-
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
107-
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
108-
on_result: Optional[OutputHandler[Result]] = None,
109-
on_error: Optional[OutputHandler[ExecutionError]] = None,
106+
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
107+
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
108+
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
109+
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
110110
envs: Optional[Dict[str, str]] = None,
111111
timeout: Optional[float] = None,
112112
request_timeout: Optional[float] = None,
@@ -138,10 +138,10 @@ async def run_code(
138138
self,
139139
code: str,
140140
context: Optional[Context] = None,
141-
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
142-
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
143-
on_result: Optional[OutputHandler[Result]] = None,
144-
on_error: Optional[OutputHandler[ExecutionError]] = None,
141+
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
142+
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
143+
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
144+
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
145145
envs: Optional[Dict[str, str]] = None,
146146
timeout: Optional[float] = None,
147147
request_timeout: Optional[float] = None,
@@ -172,10 +172,10 @@ async def run_code(
172172
code: str,
173173
language: Optional[str] = None,
174174
context: Optional[Context] = None,
175-
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
176-
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
177-
on_result: Optional[OutputHandler[Result]] = None,
178-
on_error: Optional[OutputHandler[ExecutionError]] = None,
175+
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
176+
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
177+
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
178+
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
179179
envs: Optional[Dict[str, str]] = None,
180180
timeout: Optional[float] = None,
181181
request_timeout: Optional[float] = None,
@@ -215,7 +215,7 @@ async def run_code(
215215
execution = Execution()
216216

217217
async for line in response.aiter_lines():
218-
parse_output(
218+
await async_parse_output(
219219
execution,
220220
line,
221221
on_stdout=on_stdout,

python/e2b_code_interpreter/models.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import json
23
import logging
34

@@ -20,8 +21,10 @@
2021
from .charts import Chart, _deserialize_chart
2122

2223
T = TypeVar("T")
23-
OutputHandler = Union[
24-
Callable[[T], Any],
24+
OutputHandler = Union[Callable[[T], Any],]
25+
26+
OutputHandlerWithAsync = Union[
27+
OutputHandler[T],
2528
Callable[[T], Awaitable[Any]],
2629
]
2730

@@ -446,6 +449,46 @@ def parse_output(
446449
execution.execution_count = data["execution_count"]
447450

448451

452+
async def async_parse_output(
453+
execution: Execution,
454+
output: str,
455+
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
456+
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
457+
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
458+
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
459+
):
460+
data = json.loads(output)
461+
data_type = data.pop("type")
462+
463+
if data_type == "result":
464+
result = Result(**data)
465+
execution.results.append(result)
466+
if on_result:
467+
cb = on_result(result)
468+
if inspect.isawaitable(cb):
469+
await cb
470+
elif data_type == "stdout":
471+
execution.logs.stdout.append(data["text"])
472+
if on_stdout:
473+
cb = on_stdout(OutputMessage(data["text"], data["timestamp"], False))
474+
if inspect.isawaitable(cb):
475+
await cb
476+
elif data_type == "stderr":
477+
execution.logs.stderr.append(data["text"])
478+
if on_stderr:
479+
cb = on_stderr(OutputMessage(data["text"], data["timestamp"], True))
480+
if inspect.isawaitable(cb):
481+
await cb
482+
elif data_type == "error":
483+
execution.error = ExecutionError(data["name"], data["value"], data["traceback"])
484+
if on_error:
485+
cb = on_error(execution.error)
486+
if inspect.isawaitable(cb):
487+
await cb
488+
elif data_type == "number_of_executions":
489+
execution.execution_count = data["execution_count"]
490+
491+
449492
@dataclass
450493
class Context:
451494
"""
Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
from e2b_code_interpreter.code_interpreter_async import AsyncSandbox
22

33

4+
def async_append_fn(items):
5+
async def async_append(item):
6+
items.append(item)
7+
8+
return async_append
9+
10+
411
async def test_resuls(async_sandbox: AsyncSandbox):
512
results = []
13+
614
execution = await async_sandbox.run_code(
7-
"x = 1;x", on_result=lambda result: results.append(result)
15+
"x = 1;x", on_result=async_append_fn(results)
816
)
917
assert len(results) == 1
1018
assert execution.results[0].text == "1"
1119

1220

1321
async def test_error(async_sandbox: AsyncSandbox):
1422
errors = []
15-
execution = await async_sandbox.run_code(
16-
"xyz", on_error=lambda error: errors.append(error)
17-
)
23+
execution = await async_sandbox.run_code("xyz", on_error=async_append_fn(errors))
1824
assert len(errors) == 1
1925
assert execution.error.name == "NameError"
2026

2127

2228
async def test_stdout(async_sandbox: AsyncSandbox):
2329
stdout = []
2430
execution = await async_sandbox.run_code(
25-
"print('Hello from e2b')", on_stdout=lambda out: stdout.append(out)
31+
"print('Hello from e2b')", on_stdout=async_append_fn(stdout)
2632
)
2733
assert len(stdout) == 1
2834
assert execution.logs.stdout == ["Hello from e2b\n"]
@@ -32,7 +38,7 @@ async def test_stderr(async_sandbox: AsyncSandbox):
3238
stderr = []
3339
execution = await async_sandbox.run_code(
3440
'import sys;print("This is an error message", file=sys.stderr)',
35-
on_stderr=lambda err: stderr.append(err),
41+
on_stderr=async_append_fn(stderr),
3642
)
3743
assert len(stderr) == 1
3844
assert execution.logs.stderr == ["This is an error message\n"]

0 commit comments

Comments
 (0)