diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index 074decd3c..e59b0e080 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -2,7 +2,7 @@ from enum import StrEnum from typing import Any, Callable -from .pdl_lazy import PdlApply, PdlDict, PdlLazy, PdlList +from .pdl_lazy import PdlApply, PdlConst, PdlDict, PdlLazy, PdlList # def _default(self, obj): # return getattr(obj.__class__, "to_json", _default.default)(obj) # pyright: ignore @@ -136,6 +136,20 @@ def __repr__(self): # pyright: ignore return ret + "]" +def ensure_context(context: dict | list | PDLContext) -> PDLContext: + ctx: PDLContext + match context: + case dict(): + ctx = SingletonContext(PdlConst(context)) + case list(): + ctx = DependentContext([ensure_context(c) for c in context]) + case PDLContext(): + ctx = context + case _: + raise TypeError(f"'{type(context)}' object is not a valid context") + return ctx + + def deserialize( context: list[dict[str, Any]], ) -> DependentContext: # Only support dependent for now diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 7dd55cb4c..9c9718202 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -101,6 +101,7 @@ SingletonContext, add_done_callback, deserialize, + ensure_context, ) from .pdl_dumper import as_json, block_to_dict # noqa: E402 from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402 @@ -1413,23 +1414,28 @@ def process_call_model( scope, loc, ) - model_input_result = model_input_future.result() - if isinstance(model_input_result, str): - model_input = [{"role": state.role, "content": model_input_result}] - else: - if isinstance(block, LitellmModelBlock): - model_input = model_input_result.serialize(SerializeMode.LITELLM) - else: - model_input = model_input_result.serialize(SerializeMode.GRANITEIO) - concrete_block = concrete_block.model_copy( - update={ - "pdl__model_input": model_input, - } - ) - - model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input] - # Execute model call try: + model_input_result = model_input_future.result() + if isinstance(model_input_result, str): + model_input_result = [{"role": state.role, "content": model_input_result}] + model_input_context = ensure_context(model_input_result) + match block: + case LitellmModelBlock(): + model_input = model_input_context.serialize(SerializeMode.LITELLM) + case GraniteioModelBlock(): + model_input = model_input_context.serialize(SerializeMode.GRANITEIO) + case _: + assert False + concrete_block = concrete_block.model_copy( + update={ + "pdl__model_input": model_input, + } + ) + model_input = [ + {k: v for k, v in m.items() if k != "defsite"} for m in model_input + ] + + # Execute model call litellm_params = {} def get_transformed_inputs(kwargs):