Skip to content

Commit 8bc69de

Browse files
authored
fix: ensure that model inputs are always contexts (#958)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 81dfb1a commit 8bc69de

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

src/pdl/pdl_context.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import StrEnum
33
from typing import Any, Callable
44

5-
from .pdl_lazy import PdlApply, PdlDict, PdlLazy, PdlList
5+
from .pdl_lazy import PdlApply, PdlConst, PdlDict, PdlLazy, PdlList
66

77
# def _default(self, obj):
88
# return getattr(obj.__class__, "to_json", _default.default)(obj) # pyright: ignore
@@ -136,6 +136,20 @@ def __repr__(self): # pyright: ignore
136136
return ret + "]"
137137

138138

139+
def ensure_context(context: dict | list | PDLContext) -> PDLContext:
140+
ctx: PDLContext
141+
match context:
142+
case dict():
143+
ctx = SingletonContext(PdlConst(context))
144+
case list():
145+
ctx = DependentContext([ensure_context(c) for c in context])
146+
case PDLContext():
147+
ctx = context
148+
case _:
149+
raise TypeError(f"'{type(context)}' object is not a valid context")
150+
return ctx
151+
152+
139153
def deserialize(
140154
context: list[dict[str, Any]],
141155
) -> DependentContext: # Only support dependent for now

src/pdl/pdl_interpreter.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
SingletonContext,
102102
add_done_callback,
103103
deserialize,
104+
ensure_context,
104105
)
105106
from .pdl_dumper import as_json, block_to_dict # noqa: E402
106107
from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402
@@ -1413,23 +1414,28 @@ def process_call_model(
14131414
scope,
14141415
loc,
14151416
)
1416-
model_input_result = model_input_future.result()
1417-
if isinstance(model_input_result, str):
1418-
model_input = [{"role": state.role, "content": model_input_result}]
1419-
else:
1420-
if isinstance(block, LitellmModelBlock):
1421-
model_input = model_input_result.serialize(SerializeMode.LITELLM)
1422-
else:
1423-
model_input = model_input_result.serialize(SerializeMode.GRANITEIO)
1424-
concrete_block = concrete_block.model_copy(
1425-
update={
1426-
"pdl__model_input": model_input,
1427-
}
1428-
)
1429-
1430-
model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input]
1431-
# Execute model call
14321417
try:
1418+
model_input_result = model_input_future.result()
1419+
if isinstance(model_input_result, str):
1420+
model_input_result = [{"role": state.role, "content": model_input_result}]
1421+
model_input_context = ensure_context(model_input_result)
1422+
match block:
1423+
case LitellmModelBlock():
1424+
model_input = model_input_context.serialize(SerializeMode.LITELLM)
1425+
case GraniteioModelBlock():
1426+
model_input = model_input_context.serialize(SerializeMode.GRANITEIO)
1427+
case _:
1428+
assert False
1429+
concrete_block = concrete_block.model_copy(
1430+
update={
1431+
"pdl__model_input": model_input,
1432+
}
1433+
)
1434+
model_input = [
1435+
{k: v for k, v in m.items() if k != "defsite"} for m in model_input
1436+
]
1437+
1438+
# Execute model call
14331439
litellm_params = {}
14341440

14351441
def get_transformed_inputs(kwargs):

0 commit comments

Comments
 (0)