Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/pdl/pdl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import StrEnum
from typing import Any, Callable

from .pdl_lazy import PdlApply, PdlDict, PdlLazy, PdlList
from .pdl_lazy import PdlApply, PdlConst, PdlDict, PdlLazy, PdlList

# def _default(self, obj):
# return getattr(obj.__class__, "to_json", _default.default)(obj) # pyright: ignore
Expand Down Expand Up @@ -136,6 +136,20 @@ def __repr__(self): # pyright: ignore
return ret + "]"


def ensure_context(context: dict | list | PDLContext) -> PDLContext:
ctx: PDLContext
match context:
case dict():
ctx = SingletonContext(PdlConst(context))
case list():
ctx = DependentContext([ensure_context(c) for c in context])
case PDLContext():
ctx = context
case _:
raise TypeError(f"'{type(context)}' object is not a valid context")
return ctx


def deserialize(
context: list[dict[str, Any]],
) -> DependentContext: # Only support dependent for now
Expand Down
38 changes: 22 additions & 16 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
SingletonContext,
add_done_callback,
deserialize,
ensure_context,
)
from .pdl_dumper import as_json, block_to_dict # noqa: E402
from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402
Expand Down Expand Up @@ -1413,23 +1414,28 @@ def process_call_model(
scope,
loc,
)
model_input_result = model_input_future.result()
if isinstance(model_input_result, str):
model_input = [{"role": state.role, "content": model_input_result}]
else:
if isinstance(block, LitellmModelBlock):
model_input = model_input_result.serialize(SerializeMode.LITELLM)
else:
model_input = model_input_result.serialize(SerializeMode.GRANITEIO)
concrete_block = concrete_block.model_copy(
update={
"pdl__model_input": model_input,
}
)

model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input]
# Execute model call
try:
model_input_result = model_input_future.result()
if isinstance(model_input_result, str):
model_input_result = [{"role": state.role, "content": model_input_result}]
model_input_context = ensure_context(model_input_result)
match block:
case LitellmModelBlock():
model_input = model_input_context.serialize(SerializeMode.LITELLM)
case GraniteioModelBlock():
model_input = model_input_context.serialize(SerializeMode.GRANITEIO)
case _:
assert False
concrete_block = concrete_block.model_copy(
update={
"pdl__model_input": model_input,
}
)
model_input = [
{k: v for k, v in m.items() if k != "defsite"} for m in model_input
]

# Execute model call
litellm_params = {}

def get_transformed_inputs(kwargs):
Expand Down