Skip to content

Commit c1d40a5

Browse files
authored
Add support of PDL code blocks (#185)
Close issue #88
1 parent 0a3b1d1 commit c1d40a5

File tree

9 files changed

+87
-6
lines changed

9 files changed

+87
-6
lines changed

examples/hello/hello-code-jinja.pdl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
description: Hello world showing call out to shell command
1+
description: Hello world showing call out to Jinja
22
defs:
33
world: "World"
44
lang: jinja

examples/hello/hello-code-pdl.pdl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
description: Hello world showing call out to PDL
2+
lang: pdl
3+
code: |
4+
description: Hello world
5+
text:
6+
- "Hello\n"
7+
- model: replicate/ibm-granite/granite-3.0-8b-instruct
8+

examples/hello/hello-code.pdl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@ text:
77
import string
88
result = random.choice(string.ascii_lowercase)
99
- '!'
10-

pdl-live/src/pdl_ast.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2322,7 +2322,7 @@ export type Kind15 = "code";
23222322
* Programming language of the code.
23232323
*
23242324
*/
2325-
export type Lang = "python" | "command" | "jinja";
2325+
export type Lang = "python" | "command" | "jinja" | "pdl";
23262326
/**
23272327
* Code to execute.
23282328
*

src/pdl/pdl-schema.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2692,7 +2692,8 @@
26922692
"enum": [
26932693
"python",
26942694
"command",
2695-
"jinja"
2695+
"jinja",
2696+
"pdl"
26962697
],
26972698
"title": "Lang",
26982699
"type": "string"

src/pdl/pdl_ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class CodeBlock(Block):
281281
"""Execute a piece of code."""
282282

283283
kind: Literal[BlockKind.CODE] = BlockKind.CODE
284-
lang: Literal["python", "command", "jinja"]
284+
lang: Literal["python", "command", "jinja", "pdl"]
285285
"""Programming language of the code.
286286
"""
287287
code: "BlocksType"

src/pdl/pdl_interpreter.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from .pdl_dumper import blocks_to_dict
6969
from .pdl_llms import BamModel, LitellmModel
7070
from .pdl_location_utils import append, get_loc_string
71-
from .pdl_parser import PDLParseError, parse_file
71+
from .pdl_parser import PDLParseError, parse_file, parse_str
7272
from .pdl_scheduler import (
7373
CodeYieldResultMessage,
7474
GeneratorWrapper,
@@ -1271,6 +1271,16 @@ def step_call_code(
12711271
loc=loc,
12721272
trace=block.model_copy(update={"code": code_s}),
12731273
) from exc
1274+
case "pdl":
1275+
try:
1276+
result = call_pdl(code_s, scope)
1277+
background = [{"role": state.role, "content": result}]
1278+
except Exception as exc:
1279+
raise PDLRuntimeError(
1280+
f"Code error: {repr(exc)}",
1281+
loc=loc,
1282+
trace=block.model_copy(update={"code": code_s}),
1283+
) from exc
12741284
case _:
12751285
message = f"Unsupported language: {block.lang}"
12761286
raise PDLRuntimeError(
@@ -1318,6 +1328,13 @@ def call_jinja(code: str, scope: dict) -> Any:
13181328
return result
13191329

13201330

1331+
def call_pdl(code: str, scope: dict) -> Any:
1332+
program, loc = parse_str(code)
1333+
state = InterpreterState()
1334+
result, _, _, _ = process_prog(state, scope, program, loc)
1335+
return result
1336+
1337+
13211338
def step_call(
13221339
state: InterpreterState, scope: ScopeType, block: CallBlock, loc: LocationType
13231340
) -> Generator[YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock]]:
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Hello
2+
Hello! How can I assist you today?

tests/test_code.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,57 @@ def test_jinja3():
123123
bad
124124
good"""
125125
)
126+
127+
128+
def test_pdl1():
129+
prog_str = """
130+
lang: pdl
131+
code: |
132+
description: Hello world
133+
text:
134+
- "Hello World!"
135+
"""
136+
result = exec_str(prog_str)
137+
assert result == "Hello World!"
138+
139+
140+
def test_pdl2():
141+
prog_str = """
142+
defs:
143+
w: World
144+
lang: pdl
145+
code: |
146+
description: Hello world
147+
text:
148+
- "Hello ${w}!"
149+
"""
150+
result = exec_str(prog_str)
151+
assert result == "Hello World!"
152+
153+
154+
def test_pdl3():
155+
prog_str = """
156+
defs:
157+
x:
158+
code: "result = print"
159+
lang: python
160+
lang: pdl
161+
code: |
162+
data: ${x}
163+
"""
164+
result = exec_str(prog_str)
165+
assert result == "<built-in function print>"
166+
167+
168+
def test_pdl4():
169+
prog_str = """
170+
defs:
171+
x:
172+
code: "result = print"
173+
lang: python
174+
lang: pdl
175+
code: |
176+
data: ${ "${" }x ${ "}" }
177+
"""
178+
result = exec_str(prog_str)
179+
assert result == print # pylint: disable=comparison-with-callable

0 commit comments

Comments
 (0)