|
68 | 68 | from .pdl_dumper import blocks_to_dict |
69 | 69 | from .pdl_llms import BamModel, LitellmModel |
70 | 70 | from .pdl_location_utils import append, get_loc_string |
71 | | -from .pdl_parser import PDLParseError, parse_file |
| 71 | +from .pdl_parser import PDLParseError, parse_file, parse_str |
72 | 72 | from .pdl_scheduler import ( |
73 | 73 | CodeYieldResultMessage, |
74 | 74 | GeneratorWrapper, |
@@ -1271,6 +1271,16 @@ def step_call_code( |
1271 | 1271 | loc=loc, |
1272 | 1272 | trace=block.model_copy(update={"code": code_s}), |
1273 | 1273 | ) from exc |
| 1274 | + case "pdl": |
| 1275 | + try: |
| 1276 | + result = call_pdl(code_s, scope) |
| 1277 | + background = [{"role": state.role, "content": result}] |
| 1278 | + except Exception as exc: |
| 1279 | + raise PDLRuntimeError( |
| 1280 | + f"Code error: {repr(exc)}", |
| 1281 | + loc=loc, |
| 1282 | + trace=block.model_copy(update={"code": code_s}), |
| 1283 | + ) from exc |
1274 | 1284 | case _: |
1275 | 1285 | message = f"Unsupported language: {block.lang}" |
1276 | 1286 | raise PDLRuntimeError( |
@@ -1318,6 +1328,13 @@ def call_jinja(code: str, scope: dict) -> Any: |
1318 | 1328 | return result |
1319 | 1329 |
|
1320 | 1330 |
|
| 1331 | +def call_pdl(code: str, scope: dict) -> Any: |
| 1332 | + program, loc = parse_str(code) |
| 1333 | + state = InterpreterState() |
| 1334 | + result, _, _, _ = process_prog(state, scope, program, loc) |
| 1335 | + return result |
| 1336 | + |
| 1337 | + |
1321 | 1338 | def step_call( |
1322 | 1339 | state: InterpreterState, scope: ScopeType, block: CallBlock, loc: LocationType |
1323 | 1340 | ) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock]]: |
|
0 commit comments