diff --git a/src/pdl/pdl_granite_io.py b/src/pdl/pdl_granite_io.py index 6b31dbcbe..4f14dca6e 100644 --- a/src/pdl/pdl_granite_io.py +++ b/src/pdl/pdl_granite_io.py @@ -1,5 +1,5 @@ # pylint: disable=import-outside-toplevel -from asyncio import run_coroutine_threadsafe +from asyncio import AbstractEventLoop, run_coroutine_threadsafe from typing import Any, Optional from granite_io.types import ChatCompletionInputs @@ -13,7 +13,6 @@ PDLRuntimeError, ) from .pdl_lazy import PdlConst, PdlLazy, lazy_apply -from .pdl_llms import _LOOP from .pdl_utils import value_of_expr @@ -113,15 +112,14 @@ async def async_generate_text( @staticmethod def generate_text( - block: GraniteioModelBlock, - messages: ModelInput, + block: GraniteioModelBlock, messages: ModelInput, event_loop: AbstractEventLoop ) -> tuple[LazyMessage, PdlLazy[Any]]: future = run_coroutine_threadsafe( GraniteioModel.async_generate_text( block, messages, ), - _LOOP, + event_loop, ) pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future) message = lazy_apply((lambda x: x[0]), pdl_future) diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 9b1fee1b3..85fe87041 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -10,6 +10,7 @@ # TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM import warnings +from asyncio import AbstractEventLoop from functools import partial from os import getenv @@ -31,7 +32,7 @@ ) from jinja2.nodes import TemplateData # noqa: E402 from jinja2.runtime import Undefined # noqa: E402 -from pydantic import BaseModel # noqa: E402 +from pydantic import BaseModel, ConfigDict, Field # noqa: E402 from .pdl_ast import ( # noqa: E402 AdvancedBlockType, @@ -110,7 +111,11 @@ from .pdl_location_utils import append, get_loc_string # noqa: E402 from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402 from .pdl_python_repl import PythonREPL # noqa: E402 -from .pdl_scheduler import yield_background, yield_result # noqa: E402 +from .pdl_scheduler import ( # noqa: E402 + create_event_loop_thread, + yield_background, + yield_result, +) from .pdl_schema_utils import get_json_schema # noqa: E402 from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402 from .pdl_utils import ( # noqa: E402 @@ -127,6 +132,8 @@ class InterpreterState(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + yield_result: bool = False yield_background: bool = False batch: int = 1 @@ -136,6 +143,7 @@ class InterpreterState(BaseModel): cwd: Path = Path.cwd() # background_tasks = {} id_stack: list[str] = [] + event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread) def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState": return self.model_copy(update={"yield_result": b}) @@ -1638,6 +1646,7 @@ def generate_client_response_single( model_id=value_of_expr(block.model), messages=model_input, parameters=litellm_parameters_to_dict(parameters), + event_loop=state.event_loop, ) case GraniteioModelBlock(): from .pdl_granite_io import GraniteioModel @@ -1645,6 +1654,7 @@ def generate_client_response_single( message, response = GraniteioModel.generate_text( block=block, messages=model_input, + event_loop=state.event_loop, ) case _: assert False diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py index f9d569654..ec78c1272 100644 --- a/src/pdl/pdl_llms.py +++ b/src/pdl/pdl_llms.py @@ -1,10 +1,8 @@ # pylint: disable=import-outside-toplevel -import asyncio -import threading -from concurrent.futures import Future +from asyncio import AbstractEventLoop, run_coroutine_threadsafe from os import environ from sys import stderr -from typing import Any, Callable, Generator, Optional, TypeVar +from typing import Any, Generator, Optional, TypeVar import httpx from dotenv import load_dotenv @@ -25,19 +23,6 @@ load_dotenv() -def _start_background_loop(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - -_LOOP = asyncio.new_event_loop() -_LOOP_THREAD = threading.Thread( - target=_start_background_loop, args=(_LOOP,), daemon=True -) -_LOOP_THREAD.start() -# _BACKGROUND_TASKS = set() - - class LitellmModel: @staticmethod async def async_generate_text( @@ -88,21 +73,19 @@ def generate_text( model_id: str, messages: ModelInput, parameters: dict[str, Any], + event_loop: AbstractEventLoop, ) -> tuple[LazyMessage, PdlLazy[Any]]: if "PDL_VERBOSE_ASYNC" in environ: print(f"Asynchronous model call started to {model_id}", file=stderr) - # global _BACKGROUND_TASKS - future = asyncio.run_coroutine_threadsafe( + future = run_coroutine_threadsafe( LitellmModel.async_generate_text( block, model_id, messages, parameters, ), - _LOOP, + event_loop, ) - # _BACKGROUND_TASKS.add(future) - # future.add_done_callback(_BACKGROUND_TASKS.discard) pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future) message = lazy_apply((lambda x: x[0]), pdl_future) response = lazy_apply((lambda x: x[1]), pdl_future) @@ -213,13 +196,13 @@ def set_structured_decoding_parameters( MapOutputT = TypeVar("MapOutputT") -def map_future( - f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT] -) -> Future[MapOutputT]: - future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP) - return future +# def map_future( +# f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT] +# ) -> Future[MapOutputT]: +# future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP) +# return future -async def _async_call(f, x): - v = x.result() - return f(v) +# async def _async_call(f, x): +# v = x.result() +# return f(v) diff --git a/src/pdl/pdl_scheduler.py b/src/pdl/pdl_scheduler.py index 2ce0d0d56..5e126ae7f 100644 --- a/src/pdl/pdl_scheduler.py +++ b/src/pdl/pdl_scheduler.py @@ -1,3 +1,5 @@ +from asyncio import AbstractEventLoop, new_event_loop, set_event_loop +from threading import Thread from typing import Any, Optional from termcolor import colored @@ -6,6 +8,18 @@ from .pdl_utils import stringify +def _start_background_loop(loop): + set_event_loop(loop) + loop.run_forever() + + +def create_event_loop_thread() -> AbstractEventLoop: + loop = new_event_loop() + loop_thread = Thread(target=_start_background_loop, args=(loop,), daemon=True) + loop_thread.start() + return loop + + def color_of(kind: BlockKind): color: Optional[str] match kind: