Skip to content

Commit 4500f2f

Browse files
authored
Add localized expressions in AST (#130)
1 parent 53004df commit 4500f2f

File tree

10 files changed

+257
-57
lines changed

10 files changed

+257
-57
lines changed

pdl-live/src/pdl_ast.d.ts

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3235,7 +3235,7 @@ export interface CallBlock {
32353235
result?: unknown;
32363236
location?: LocationType | null;
32373237
kind?: Kind18;
3238-
call: Call;
3238+
call: unknown;
32393239
args?: Args;
32403240
trace?: Trace6;
32413241
}
@@ -3654,7 +3654,7 @@ export interface DataBlock {
36543654
result?: unknown;
36553655
location?: LocationType | null;
36563656
kind?: Kind13;
3657-
data: Data;
3657+
data: unknown;
36583658
raw?: Raw;
36593659
}
36603660
/**
@@ -3736,7 +3736,7 @@ export interface IfBlock {
37363736
result?: unknown;
37373737
location?: LocationType | null;
37383738
kind?: Kind12;
3739-
if: If;
3739+
if: unknown;
37403740
then: Then;
37413741
else?: Else;
37423742
if_result?: IfResult;
@@ -3905,7 +3905,7 @@ export interface RepeatUntilBlock {
39053905
location?: LocationType | null;
39063906
kind?: Kind10;
39073907
repeat: Repeat1;
3908-
until: Until;
3908+
until: unknown;
39093909
join?: Join1;
39103910
trace?: Trace2;
39113911
}
@@ -4822,26 +4822,6 @@ export interface JoinArray {
48224822
export interface JoinLastOf {
48234823
as: As2;
48244824
}
4825-
/**
4826-
* Condition of the loop.
4827-
*
4828-
*/
4829-
export interface Until {
4830-
[k: string]: unknown;
4831-
}
4832-
/**
4833-
* Condition.
4834-
*
4835-
*/
4836-
export interface If {
4837-
[k: string]: unknown;
4838-
}
4839-
/**
4840-
* Value defined.
4841-
*/
4842-
export interface Data {
4843-
[k: string]: unknown;
4844-
}
48454825
export interface BamTextGenerationParameters {
48464826
beam_width?: BeamWidth;
48474827
decoding_method?: DecodingMethod | null;
@@ -4954,13 +4934,6 @@ export interface LitellmParameters {
49544934
max_retries?: MaxRetries;
49554935
[k: string]: unknown;
49564936
}
4957-
/**
4958-
* Function to call.
4959-
*
4960-
*/
4961-
export interface Call {
4962-
[k: string]: unknown;
4963-
}
49644937
/**
49654938
* Arguments of the function with their values.
49664939
*

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"jsonschema~=4.0",
1616
"litellm~=1.49",
1717
"termcolor~=2.0",
18-
"ipython~=8.0"
18+
"ipython~=8.0",
1919
]
2020
authors = [
2121
{ name="Mandana Vaziri", email="[email protected]" },

src/pdl/pdl-schema.json

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,10 @@
10771077
{
10781078
"type": "string"
10791079
},
1080-
{}
1080+
{},
1081+
{
1082+
"$ref": "#/$defs/LocalizedExpression"
1083+
}
10811084
],
10821085
"title": "Model"
10831086
},
@@ -1359,8 +1362,9 @@
13591362
{
13601363
"$ref": "#/$defs/BamTextGenerationParameters"
13611364
},
1365+
{},
13621366
{
1363-
"type": "object"
1367+
"$ref": "#/$defs/LocalizedExpression"
13641368
},
13651369
{
13661370
"type": "null"
@@ -2081,6 +2085,12 @@
20812085
"type": "string"
20822086
},
20832087
"call": {
2088+
"anyOf": [
2089+
{},
2090+
{
2091+
"$ref": "#/$defs/LocalizedExpression"
2092+
}
2093+
],
20842094
"description": "Function to call.\n ",
20852095
"title": "Call"
20862096
},
@@ -3370,6 +3380,12 @@
33703380
"type": "string"
33713381
},
33723382
"data": {
3383+
"anyOf": [
3384+
{},
3385+
{
3386+
"$ref": "#/$defs/LocalizedExpression"
3387+
}
3388+
],
33733389
"description": "Value defined.",
33743390
"title": "Data"
33753391
},
@@ -4928,6 +4944,14 @@
49284944
"type": "string"
49294945
},
49304946
"for": {
4947+
"additionalProperties": {
4948+
"anyOf": [
4949+
{},
4950+
{
4951+
"$ref": "#/$defs/LocalizedExpression"
4952+
}
4953+
]
4954+
},
49314955
"description": "Arrays to iterate over.\n ",
49324956
"title": "For",
49334957
"type": "object"
@@ -6859,6 +6883,12 @@
68596883
"type": "string"
68606884
},
68616885
"if": {
6886+
"anyOf": [
6887+
{},
6888+
{
6889+
"$ref": "#/$defs/LocalizedExpression"
6890+
}
6891+
],
68626892
"description": "Condition.\n ",
68636893
"title": "If"
68646894
},
@@ -8999,7 +9029,10 @@
89999029
{
90009030
"type": "string"
90019031
},
9002-
{}
9032+
{},
9033+
{
9034+
"$ref": "#/$defs/LocalizedExpression"
9035+
}
90039036
],
90049037
"title": "Model"
90059038
},
@@ -9270,8 +9303,9 @@
92709303
{
92719304
"$ref": "#/$defs/LitellmParameters"
92729305
},
9306+
{},
92739307
{
9274-
"type": "object"
9308+
"$ref": "#/$defs/LocalizedExpression"
92759309
},
92769310
{
92779311
"type": "null"
@@ -9685,6 +9719,31 @@
96859719
"title": "LitellmParameters",
96869720
"type": "object"
96879721
},
9722+
"LocalizedExpression": {
9723+
"additionalProperties": false,
9724+
"description": "Expression with location information",
9725+
"properties": {
9726+
"expr": {
9727+
"title": "Expr"
9728+
},
9729+
"location": {
9730+
"anyOf": [
9731+
{
9732+
"$ref": "#/$defs/LocationType"
9733+
},
9734+
{
9735+
"type": "null"
9736+
}
9737+
],
9738+
"default": null
9739+
}
9740+
},
9741+
"required": [
9742+
"expr"
9743+
],
9744+
"title": "LocalizedExpression",
9745+
"type": "object"
9746+
},
96889747
"LocationType": {
96899748
"additionalProperties": false,
96909749
"properties": {
@@ -12398,6 +12457,9 @@
1239812457
"read": {
1239912458
"anyOf": [
1240012459
{},
12460+
{
12461+
"$ref": "#/$defs/LocalizedExpression"
12462+
},
1240112463
{
1240212464
"type": "null"
1240312465
}
@@ -13917,6 +13979,12 @@
1391713979
"title": "Repeat"
1391813980
},
1391913981
"until": {
13982+
"anyOf": [
13983+
{},
13984+
{
13985+
"$ref": "#/$defs/LocalizedExpression"
13986+
}
13987+
],
1392013988
"description": "Condition of the loop.\n ",
1392113989
"title": "Until"
1392213990
},

src/pdl/pdl_ast.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,6 @@
1616

1717
ScopeType: TypeAlias = dict[str, Any]
1818

19-
ExpressionType: TypeAlias = Any
20-
# (
21-
# str
22-
# | int
23-
# | float
24-
# | bool
25-
# | None
26-
# | list["ExpressionType"]
27-
# | dict[str, "ExpressionType"]
28-
# )
29-
3019

3120
Message: TypeAlias = dict[str, Any]
3221
Messages: TypeAlias = list[Message]
@@ -64,6 +53,28 @@ class LocationType(BaseModel):
6453
empty_block_location = LocationType(file="", path=[], table={})
6554

6655

56+
class LocalizedExpression(BaseModel):
57+
"""Expression with location information"""
58+
59+
model_config = ConfigDict(
60+
extra="forbid", use_attribute_docstrings=True, arbitrary_types_allowed=True
61+
)
62+
expr: Any
63+
location: Optional[LocationType] = None
64+
65+
66+
ExpressionType: TypeAlias = Any | LocalizedExpression
67+
# (
68+
# str
69+
# | int
70+
# | float
71+
# | bool
72+
# | None
73+
# | list["ExpressionType"]
74+
# | dict[str, "ExpressionType"]
75+
# )
76+
77+
6778
class Parser(BaseModel):
6879
model_config = ConfigDict(extra="forbid")
6980
description: Optional[str] = None
@@ -96,7 +107,11 @@ class ContributeValue(BaseModel):
96107
class Block(BaseModel):
97108
"""Common fields for all PDL blocks."""
98109

99-
model_config = ConfigDict(extra="forbid", use_attribute_docstrings=True)
110+
model_config = ConfigDict(
111+
extra="forbid",
112+
use_attribute_docstrings=True,
113+
arbitrary_types_allowed=True,
114+
)
100115

101116
description: Optional[str] = None
102117
"""Documentation associated to the block.
@@ -265,7 +280,7 @@ class ModelBlock(Block):
265280
class BamModelBlock(ModelBlock):
266281
platform: Literal[ModelPlatform.BAM]
267282
prompt_id: Optional[str] = None
268-
parameters: Optional[BamTextGenerationParameters | dict] = None
283+
parameters: Optional[BamTextGenerationParameters | ExpressionType] = None
269284
moderations: Optional[ModerationParameters] = None
270285
data: Optional[PromptTemplateData] = None
271286
constraints: Any = None # TODO
@@ -275,7 +290,7 @@ class LitellmModelBlock(ModelBlock):
275290
"""Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""
276291

277292
platform: Literal[ModelPlatform.LITELLM] = ModelPlatform.LITELLM
278-
parameters: Optional[LitellmParameters | dict] = None
293+
parameters: Optional[LitellmParameters | ExpressionType] = None
279294

280295

281296
class CodeBlock(Block):

src/pdl/pdl_ast_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def iter_block_children(f: Callable[[BlocksType], None], block: BlockType) -> No
3939
f(blocks)
4040
match block:
4141
case FunctionBlock():
42-
if block.returns is not None:
43-
f(block.returns)
42+
f(block.returns)
4443
case CallBlock():
4544
if block.trace is not None:
4645
f(block.trace)
@@ -208,7 +207,7 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:
208207

209208
def map_blocks(f: MappedFunctions, blocks: BlocksType) -> BlocksType:
210209
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
211-
# is a list of blocks
210+
# Is a list of blocks
212211
blocks = [f.f_block(block) for block in blocks]
213212
else:
214213
blocks = f.f_block(blocks)

src/pdl/pdl_compilers/to_regex.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
IncludeBlock,
2020
LitellmModelBlock,
2121
LitellmParameters,
22+
LocalizedExpression,
2223
ModelBlock,
2324
ReadBlock,
2425
RepeatBlock,
@@ -273,10 +274,14 @@ def compile_block(
273274
"include_stop_sequence", False
274275
)
275276
else:
276-
stop_sequences = block.parameters.stop_sequences or []
277+
if isinstance(block.parameters, LocalizedExpression):
278+
parameters = block.parameters.expr
279+
else:
280+
parameters = block.parameters
281+
stop_sequences = parameters.stop_sequences or []
277282
include_stop_sequence = (
278-
block.parameters.include_stop_sequence is None
279-
or block.parameters.include_stop_sequence
283+
parameters.include_stop_sequence is None
284+
or parameters.include_stop_sequence
280285
)
281286
case LitellmModelBlock():
282287
if block.parameters is None:
@@ -285,6 +290,8 @@ def compile_block(
285290
else:
286291
if isinstance(block.parameters, LitellmParameters):
287292
parameters = block.parameters.model_dump()
293+
elif isinstance(block.parameters, LocalizedExpression):
294+
parameters = block.parameters.expr
288295
else:
289296
parameters = block.parameters
290297
stop_sequences = parameters.get("stop", [])

src/pdl/pdl_dumper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def blocks_to_dict(
253253
) -> DumpedBlockType | list[DumpedBlockType]:
254254
result: DumpedBlockType | list[DumpedBlockType]
255255
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
256+
# Is a list of blocks
256257
result = [block_to_dict(block, json_compatible) for block in blocks]
257258
else:
258259
result = block_to_dict(blocks, json_compatible)

0 commit comments

Comments
 (0)