From 622cea4de6005e6e4550ddd4e81e8d9442ef63c7 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Wed, 4 Jun 2025 15:17:49 -0400 Subject: [PATCH 1/4] fix: ensure that model inputs are always contexts Signed-off-by: Louis Mandel --- src/pdl/pdl_context.py | 13 ++++++++++++- src/pdl/pdl_interpreter.py | 34 ++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index 074decd3c..db83ca2c6 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -2,7 +2,7 @@ from enum import StrEnum from typing import Any, Callable -from .pdl_lazy import PdlApply, PdlDict, PdlLazy, PdlList +from .pdl_lazy import PdlApply, PdlConst, PdlDict, PdlLazy, PdlList # def _default(self, obj): # return getattr(obj.__class__, "to_json", _default.default)(obj) # pyright: ignore @@ -135,6 +135,17 @@ def __repr__(self): # pyright: ignore ret += ",".join([i.__repr__() for i in self.context.result()]) return ret + "]" +def ensure_context(context: dict | list | PDLContext) -> PDLContext: + match context: + case dict(): + ctx = SingletonContext(PdlConst(context)) + case list(): + ctx = DependentContext([ensure_context(c) for c in context]) + case PDLContext(): + ctx = context + case _: + raise TypeError(f"'{type(context)}' object is not a valid context") + return ctx def deserialize( context: list[dict[str, Any]], diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 7dd55cb4c..e724f41b6 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -101,6 +101,7 @@ SingletonContext, add_done_callback, deserialize, + ensure_context, ) from .pdl_dumper import as_json, block_to_dict # noqa: E402 from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402 @@ -1413,23 +1414,24 @@ def process_call_model( scope, loc, ) - model_input_result = model_input_future.result() - if isinstance(model_input_result, str): - model_input = [{"role": state.role, "content": model_input_result}] - else: - if isinstance(block, LitellmModelBlock): - model_input = model_input_result.serialize(SerializeMode.LITELLM) - else: - model_input = model_input_result.serialize(SerializeMode.GRANITEIO) - concrete_block = concrete_block.model_copy( - update={ - "pdl__model_input": model_input, - } - ) - - model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input] - # Execute model call try: + model_input_result = model_input_future.result() + model_input_context = ensure_context(model_input_result) + match block: + case LitellmModelBlock(): + model_input = model_input_context.serialize(SerializeMode.LITELLM) + case GraniteioModelBlock(): + model_input = model_input_context.serialize(SerializeMode.GRANITEIO) + case _: + assert False + concrete_block = concrete_block.model_copy( + update={ + "pdl__model_input": model_input, + } + ) + model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input] + + # Execute model call litellm_params = {} def get_transformed_inputs(kwargs): From e371836a94b2622014d8a1f6bd74c7803736e905 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Wed, 4 Jun 2025 15:25:12 -0400 Subject: [PATCH 2/4] Formatting and typing Signed-off-by: Louis Mandel --- src/pdl/pdl_context.py | 2 ++ src/pdl/pdl_interpreter.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index db83ca2c6..2dda7d003 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -135,6 +135,7 @@ def __repr__(self): # pyright: ignore ret += ",".join([i.__repr__() for i in self.context.result()]) return ret + "]" + def ensure_context(context: dict | list | PDLContext) -> PDLContext: match context: case dict(): @@ -147,6 +148,7 @@ def ensure_context(context: dict | list | PDLContext) -> PDLContext: raise TypeError(f"'{type(context)}' object is not a valid context") return ctx + def deserialize( context: list[dict[str, Any]], ) -> DependentContext: # Only support dependent for now diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index e724f41b6..d0f9b28aa 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -1429,7 +1429,9 @@ def process_call_model( "pdl__model_input": model_input, } ) - model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input] + model_input = [ + {k: v for k, v in m.items() if k != "defsite"} for m in model_input + ] # Execute model call litellm_params = {} From e3506dc17790ec1f227950197226deab4dbf0812 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Wed, 4 Jun 2025 15:34:32 -0400 Subject: [PATCH 3/4] Formatting and typing Signed-off-by: Louis Mandel --- src/pdl/pdl_context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index 2dda7d003..e59b0e080 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -137,6 +137,7 @@ def __repr__(self): # pyright: ignore def ensure_context(context: dict | list | PDLContext) -> PDLContext: + ctx: PDLContext match context: case dict(): ctx = SingletonContext(PdlConst(context)) From aad3ccf2e3ab77604587295b66a3430d98f124f4 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Wed, 4 Jun 2025 15:45:50 -0400 Subject: [PATCH 4/4] fix Signed-off-by: Louis Mandel --- src/pdl/pdl_interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index d0f9b28aa..9c9718202 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -1416,6 +1416,8 @@ def process_call_model( ) try: model_input_result = model_input_future.result() + if isinstance(model_input_result, str): + model_input_result = [{"role": state.role, "content": model_input_result}] model_input_context = ensure_context(model_input_result) match block: case LitellmModelBlock():