|
101 | 101 | SingletonContext, |
102 | 102 | add_done_callback, |
103 | 103 | deserialize, |
| 104 | + ensure_context, |
104 | 105 | ) |
105 | 106 | from .pdl_dumper import as_json, block_to_dict # noqa: E402 |
106 | 107 | from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402 |
@@ -1413,23 +1414,28 @@ def process_call_model( |
1413 | 1414 | scope, |
1414 | 1415 | loc, |
1415 | 1416 | ) |
1416 | | - model_input_result = model_input_future.result() |
1417 | | - if isinstance(model_input_result, str): |
1418 | | - model_input = [{"role": state.role, "content": model_input_result}] |
1419 | | - else: |
1420 | | - if isinstance(block, LitellmModelBlock): |
1421 | | - model_input = model_input_result.serialize(SerializeMode.LITELLM) |
1422 | | - else: |
1423 | | - model_input = model_input_result.serialize(SerializeMode.GRANITEIO) |
1424 | | - concrete_block = concrete_block.model_copy( |
1425 | | - update={ |
1426 | | - "pdl__model_input": model_input, |
1427 | | - } |
1428 | | - ) |
1429 | | - |
1430 | | - model_input = [{k: v for k, v in m.items() if k != "defsite"} for m in model_input] |
1431 | | - # Execute model call |
1432 | 1417 | try: |
| 1418 | + model_input_result = model_input_future.result() |
| 1419 | + if isinstance(model_input_result, str): |
| 1420 | + model_input_result = [{"role": state.role, "content": model_input_result}] |
| 1421 | + model_input_context = ensure_context(model_input_result) |
| 1422 | + match block: |
| 1423 | + case LitellmModelBlock(): |
| 1424 | + model_input = model_input_context.serialize(SerializeMode.LITELLM) |
| 1425 | + case GraniteioModelBlock(): |
| 1426 | + model_input = model_input_context.serialize(SerializeMode.GRANITEIO) |
| 1427 | + case _: |
| 1428 | + assert False |
| 1429 | + concrete_block = concrete_block.model_copy( |
| 1430 | + update={ |
| 1431 | + "pdl__model_input": model_input, |
| 1432 | + } |
| 1433 | + ) |
| 1434 | + model_input = [ |
| 1435 | + {k: v for k, v in m.items() if k != "defsite"} for m in model_input |
| 1436 | + ] |
| 1437 | + |
| 1438 | + # Execute model call |
1433 | 1439 | litellm_params = {} |
1434 | 1440 |
|
1435 | 1441 | def get_transformed_inputs(kwargs): |
|
0 commit comments