Skip to content

Commit 73845ad

Browse files
authored
refactor: make the event loop thread part of the interpreter state (#990)
Signed-off-by: Louis Mandel <[email protected]>
1 parent db6394c commit 73845ad

File tree

4 files changed

+42
-37
lines changed

4 files changed

+42
-37
lines changed

src/pdl/pdl_granite_io.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=import-outside-toplevel
2-
from asyncio import run_coroutine_threadsafe
2+
from asyncio import AbstractEventLoop, run_coroutine_threadsafe
33
from typing import Any, Optional
44

55
from granite_io.types import ChatCompletionInputs
@@ -13,7 +13,6 @@
1313
PDLRuntimeError,
1414
)
1515
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
16-
from .pdl_llms import _LOOP
1716
from .pdl_utils import value_of_expr
1817

1918

@@ -113,15 +112,14 @@ async def async_generate_text(
113112

114113
@staticmethod
115114
def generate_text(
116-
block: GraniteioModelBlock,
117-
messages: ModelInput,
115+
block: GraniteioModelBlock, messages: ModelInput, event_loop: AbstractEventLoop
118116
) -> tuple[LazyMessage, PdlLazy[Any]]:
119117
future = run_coroutine_threadsafe(
120118
GraniteioModel.async_generate_text(
121119
block,
122120
messages,
123121
),
124-
_LOOP,
122+
event_loop,
125123
)
126124
pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future)
127125
message = lazy_apply((lambda x: x[0]), pdl_future)

src/pdl/pdl_interpreter.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
1212
import warnings
13+
from asyncio import AbstractEventLoop
1314
from functools import partial
1415
from os import getenv
1516

@@ -31,7 +32,7 @@
3132
)
3233
from jinja2.nodes import TemplateData # noqa: E402
3334
from jinja2.runtime import Undefined # noqa: E402
34-
from pydantic import BaseModel # noqa: E402
35+
from pydantic import BaseModel, ConfigDict, Field # noqa: E402
3536

3637
from .pdl_ast import ( # noqa: E402
3738
AdvancedBlockType,
@@ -110,7 +111,11 @@
110111
from .pdl_location_utils import append, get_loc_string # noqa: E402
111112
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
112113
from .pdl_python_repl import PythonREPL # noqa: E402
113-
from .pdl_scheduler import yield_background, yield_result # noqa: E402
114+
from .pdl_scheduler import ( # noqa: E402
115+
create_event_loop_thread,
116+
yield_background,
117+
yield_result,
118+
)
114119
from .pdl_schema_utils import get_json_schema # noqa: E402
115120
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
116121
from .pdl_utils import ( # noqa: E402
@@ -127,6 +132,8 @@
127132

128133

129134
class InterpreterState(BaseModel):
135+
model_config = ConfigDict(arbitrary_types_allowed=True)
136+
130137
yield_result: bool = False
131138
yield_background: bool = False
132139
batch: int = 1
@@ -136,6 +143,7 @@ class InterpreterState(BaseModel):
136143
cwd: Path = Path.cwd()
137144
# background_tasks = {}
138145
id_stack: list[str] = []
146+
event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread)
139147

140148
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
141149
return self.model_copy(update={"yield_result": b})
@@ -1638,13 +1646,15 @@ def generate_client_response_single(
16381646
model_id=value_of_expr(block.model),
16391647
messages=model_input,
16401648
parameters=litellm_parameters_to_dict(parameters),
1649+
event_loop=state.event_loop,
16411650
)
16421651
case GraniteioModelBlock():
16431652
from .pdl_granite_io import GraniteioModel
16441653

16451654
message, response = GraniteioModel.generate_text(
16461655
block=block,
16471656
messages=model_input,
1657+
event_loop=state.event_loop,
16481658
)
16491659
case _:
16501660
assert False

src/pdl/pdl_llms.py

Lines changed: 13 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
11
# pylint: disable=import-outside-toplevel
2-
import asyncio
3-
import threading
4-
from concurrent.futures import Future
2+
from asyncio import AbstractEventLoop, run_coroutine_threadsafe
53
from os import environ
64
from sys import stderr
7-
from typing import Any, Callable, Generator, Optional, TypeVar
5+
from typing import Any, Generator, Optional, TypeVar
86

97
import httpx
108
from dotenv import load_dotenv
@@ -25,19 +23,6 @@
2523
load_dotenv()
2624

2725

28-
def _start_background_loop(loop):
29-
asyncio.set_event_loop(loop)
30-
loop.run_forever()
31-
32-
33-
_LOOP = asyncio.new_event_loop()
34-
_LOOP_THREAD = threading.Thread(
35-
target=_start_background_loop, args=(_LOOP,), daemon=True
36-
)
37-
_LOOP_THREAD.start()
38-
# _BACKGROUND_TASKS = set()
39-
40-
4126
class LitellmModel:
4227
@staticmethod
4328
async def async_generate_text(
@@ -88,21 +73,19 @@ def generate_text(
8873
model_id: str,
8974
messages: ModelInput,
9075
parameters: dict[str, Any],
76+
event_loop: AbstractEventLoop,
9177
) -> tuple[LazyMessage, PdlLazy[Any]]:
9278
if "PDL_VERBOSE_ASYNC" in environ:
9379
print(f"Asynchronous model call started to {model_id}", file=stderr)
94-
# global _BACKGROUND_TASKS
95-
future = asyncio.run_coroutine_threadsafe(
80+
future = run_coroutine_threadsafe(
9681
LitellmModel.async_generate_text(
9782
block,
9883
model_id,
9984
messages,
10085
parameters,
10186
),
102-
_LOOP,
87+
event_loop,
10388
)
104-
# _BACKGROUND_TASKS.add(future)
105-
# future.add_done_callback(_BACKGROUND_TASKS.discard)
10689
pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future)
10790
message = lazy_apply((lambda x: x[0]), pdl_future)
10891
response = lazy_apply((lambda x: x[1]), pdl_future)
@@ -213,13 +196,13 @@ def set_structured_decoding_parameters(
213196
MapOutputT = TypeVar("MapOutputT")
214197

215198

216-
def map_future(
217-
f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT]
218-
) -> Future[MapOutputT]:
219-
future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP)
220-
return future
199+
# def map_future(
200+
# f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT]
201+
# ) -> Future[MapOutputT]:
202+
# future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP)
203+
# return future
221204

222205

223-
async def _async_call(f, x):
224-
v = x.result()
225-
return f(v)
206+
# async def _async_call(f, x):
207+
# v = x.result()
208+
# return f(v)

src/pdl/pdl_scheduler.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from asyncio import AbstractEventLoop, new_event_loop, set_event_loop
2+
from threading import Thread
13
from typing import Any, Optional
24

35
from termcolor import colored
@@ -6,6 +8,18 @@
68
from .pdl_utils import stringify
79

810

11+
# Thread entry point: install `loop` as the current asyncio event loop of the
# calling thread, then run it. run_forever() blocks until the loop is stopped,
# so this function does not return while the loop is running.
def _start_background_loop(loop):
12+
set_event_loop(loop)
13+
loop.run_forever()
14+
15+
16+
# Create a new asyncio event loop, start a thread that runs it through
# _start_background_loop, and return the loop. The Thread object itself is
# neither returned nor stored by this function.
def create_event_loop_thread() -> AbstractEventLoop:
17+
loop = new_event_loop()
18+
# daemon=True: a daemon thread does not prevent the process from exiting.
loop_thread = Thread(target=_start_background_loop, args=(loop,), daemon=True)
19+
loop_thread.start()
20+
return loop
21+
22+
923
def color_of(kind: BlockKind):
1024
color: Optional[str]
1125
match kind:

0 commit comments

Comments
 (0)