Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/pdl/pdl_granite_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=import-outside-toplevel
from asyncio import run_coroutine_threadsafe
from asyncio import AbstractEventLoop, run_coroutine_threadsafe
from typing import Any, Optional

from granite_io.types import ChatCompletionInputs
Expand All @@ -13,7 +13,6 @@
PDLRuntimeError,
)
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
from .pdl_llms import _LOOP
from .pdl_utils import value_of_expr


Expand Down Expand Up @@ -113,15 +112,14 @@ async def async_generate_text(

@staticmethod
def generate_text(
block: GraniteioModelBlock,
messages: ModelInput,
block: GraniteioModelBlock, messages: ModelInput, event_loop: AbstractEventLoop
) -> tuple[LazyMessage, PdlLazy[Any]]:
future = run_coroutine_threadsafe(
GraniteioModel.async_generate_text(
block,
messages,
),
_LOOP,
event_loop,
)
pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future)
message = lazy_apply((lambda x: x[0]), pdl_future)
Expand Down
14 changes: 12 additions & 2 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
import warnings
from asyncio import AbstractEventLoop
from functools import partial
from os import getenv

Expand All @@ -31,7 +32,7 @@
)
from jinja2.nodes import TemplateData # noqa: E402
from jinja2.runtime import Undefined # noqa: E402
from pydantic import BaseModel # noqa: E402
from pydantic import BaseModel, ConfigDict, Field # noqa: E402

from .pdl_ast import ( # noqa: E402
AdvancedBlockType,
Expand Down Expand Up @@ -110,7 +111,11 @@
from .pdl_location_utils import append, get_loc_string # noqa: E402
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
from .pdl_python_repl import PythonREPL # noqa: E402
from .pdl_scheduler import yield_background, yield_result # noqa: E402
from .pdl_scheduler import ( # noqa: E402
create_event_loop_thread,
yield_background,
yield_result,
)
from .pdl_schema_utils import get_json_schema # noqa: E402
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
from .pdl_utils import ( # noqa: E402
Expand All @@ -127,6 +132,8 @@


class InterpreterState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

yield_result: bool = False
yield_background: bool = False
batch: int = 1
Expand All @@ -136,6 +143,7 @@ class InterpreterState(BaseModel):
cwd: Path = Path.cwd()
# background_tasks = {}
id_stack: list[str] = []
event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread)

def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
return self.model_copy(update={"yield_result": b})
Expand Down Expand Up @@ -1638,13 +1646,15 @@ def generate_client_response_single(
model_id=value_of_expr(block.model),
messages=model_input,
parameters=litellm_parameters_to_dict(parameters),
event_loop=state.event_loop,
)
case GraniteioModelBlock():
from .pdl_granite_io import GraniteioModel

message, response = GraniteioModel.generate_text(
block=block,
messages=model_input,
event_loop=state.event_loop,
)
case _:
assert False
Expand Down
43 changes: 13 additions & 30 deletions src/pdl/pdl_llms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# pylint: disable=import-outside-toplevel
import asyncio
import threading
from concurrent.futures import Future
from asyncio import AbstractEventLoop, run_coroutine_threadsafe
from os import environ
from sys import stderr
from typing import Any, Callable, Generator, Optional, TypeVar
from typing import Any, Generator, Optional, TypeVar

import httpx
from dotenv import load_dotenv
Expand All @@ -25,19 +23,6 @@
load_dotenv()


def _start_background_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()


_LOOP = asyncio.new_event_loop()
_LOOP_THREAD = threading.Thread(
target=_start_background_loop, args=(_LOOP,), daemon=True
)
_LOOP_THREAD.start()
# _BACKGROUND_TASKS = set()


class LitellmModel:
@staticmethod
async def async_generate_text(
Expand Down Expand Up @@ -88,21 +73,19 @@ def generate_text(
model_id: str,
messages: ModelInput,
parameters: dict[str, Any],
event_loop: AbstractEventLoop,
) -> tuple[LazyMessage, PdlLazy[Any]]:
if "PDL_VERBOSE_ASYNC" in environ:
print(f"Asynchronous model call started to {model_id}", file=stderr)
# global _BACKGROUND_TASKS
future = asyncio.run_coroutine_threadsafe(
future = run_coroutine_threadsafe(
LitellmModel.async_generate_text(
block,
model_id,
messages,
parameters,
),
_LOOP,
event_loop,
)
# _BACKGROUND_TASKS.add(future)
# future.add_done_callback(_BACKGROUND_TASKS.discard)
pdl_future: PdlLazy[tuple[dict[str, Any], Any]] = PdlConst(future)
message = lazy_apply((lambda x: x[0]), pdl_future)
response = lazy_apply((lambda x: x[1]), pdl_future)
Expand Down Expand Up @@ -213,13 +196,13 @@ def set_structured_decoding_parameters(
MapOutputT = TypeVar("MapOutputT")


def map_future(
f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT]
) -> Future[MapOutputT]:
future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP)
return future
# def map_future(
# f: Callable[[MapInputT], MapOutputT], x: Future[MapInputT]
# ) -> Future[MapOutputT]:
# future = asyncio.run_coroutine_threadsafe(_async_call(f, x), _LOOP)
# return future


async def _async_call(f, x):
v = x.result()
return f(v)
# async def _async_call(f, x):
# v = x.result()
# return f(v)
14 changes: 14 additions & 0 deletions src/pdl/pdl_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from asyncio import AbstractEventLoop, new_event_loop, set_event_loop
from threading import Thread
from typing import Any, Optional

from termcolor import colored
Expand All @@ -6,6 +8,18 @@
from .pdl_utils import stringify


def _start_background_loop(loop: AbstractEventLoop) -> None:
    """Run ``loop`` in the calling thread until it is stopped.

    Intended as the ``target`` of the thread started by
    ``create_event_loop_thread``.

    Args:
        loop: The event loop to install and run in the calling thread.
    """
    # Register the loop as the current one for this thread before running it.
    set_event_loop(loop)
    # Blocks here; returns only once ``loop.stop()`` has been called.
    loop.run_forever()


def create_event_loop_thread() -> AbstractEventLoop:
    """Create a new asyncio event loop served by a dedicated daemon thread.

    A fresh loop is created and handed to ``_start_background_loop`` as the
    target of a newly started daemon thread, so the thread does not keep the
    interpreter alive at exit.

    Returns:
        The newly created event loop whose serving thread has been started.
    """
    background_loop = new_event_loop()
    # The thread object is not kept: only the loop is handed back to callers.
    Thread(
        target=_start_background_loop,
        args=(background_loop,),
        daemon=True,
    ).start()
    return background_loop


def color_of(kind: BlockKind):
color: Optional[str]
match kind:
Expand Down