Skip to content

Commit 457dc99

Browse files
committed
Merge branch 'main' into pdl-199
2 parents dc6192f + 4500f2f commit 457dc99

File tree

9 files changed

+255
-57
lines changed

9 files changed

+255
-57
lines changed

pdl-live/src/pdl_ast.d.ts

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,7 +2094,7 @@ export interface CallBlock {
20942094
result?: unknown;
20952095
location?: LocationType | null;
20962096
kind?: Kind18;
2097-
call: Call;
2097+
call: unknown;
20982098
args?: Args;
20992099
trace?: Trace6;
21002100
}
@@ -2383,7 +2383,7 @@ export interface DataBlock {
23832383
result?: unknown;
23842384
location?: LocationType | null;
23852385
kind?: Kind13;
2386-
data: Data;
2386+
data: unknown;
23872387
raw?: Raw;
23882388
}
23892389
/**
@@ -2439,7 +2439,7 @@ export interface IfBlock {
24392439
result?: unknown;
24402440
location?: LocationType | null;
24412441
kind?: Kind12;
2442-
if: If;
2442+
if: unknown;
24432443
then: Then;
24442444
else?: Else;
24452445
if_result?: IfResult;
@@ -2556,7 +2556,7 @@ export interface RepeatUntilBlock {
25562556
location?: LocationType | null;
25572557
kind?: Kind10;
25582558
repeat: Repeat1;
2559-
until: Until;
2559+
until: unknown;
25602560
join?: Join1;
25612561
trace?: Trace2;
25622562
}
@@ -3187,26 +3187,6 @@ export interface JoinArray {
31873187
export interface JoinLastOf {
31883188
as: As2;
31893189
}
3190-
/**
3191-
* Condition of the loop.
3192-
*
3193-
*/
3194-
export interface Until {
3195-
[k: string]: unknown;
3196-
}
3197-
/**
3198-
* Condition.
3199-
*
3200-
*/
3201-
export interface If {
3202-
[k: string]: unknown;
3203-
}
3204-
/**
3205-
* Value defined.
3206-
*/
3207-
export interface Data {
3208-
[k: string]: unknown;
3209-
}
32103190
export interface BamTextGenerationParameters {
32113191
beam_width?: BeamWidth;
32123192
decoding_method?: DecodingMethod | null;
@@ -3319,13 +3299,6 @@ export interface LitellmParameters {
33193299
max_retries?: MaxRetries;
33203300
[k: string]: unknown;
33213301
}
3322-
/**
3323-
* Function to call.
3324-
*
3325-
*/
3326-
export interface Call {
3327-
[k: string]: unknown;
3328-
}
33293302
/**
33303303
* Arguments of the function with their values.
33313304
*

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"jsonschema~=4.0",
1616
"litellm~=1.49",
1717
"termcolor~=2.0",
18-
"ipython~=8.0"
18+
"ipython~=8.0",
1919
]
2020
authors = [
2121
{ name="Mandana Vaziri", email="[email protected]" },

src/pdl/pdl-schema.json

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,10 @@
670670
{
671671
"type": "string"
672672
},
673-
{}
673+
{},
674+
{
675+
"$ref": "#/$defs/LocalizedExpression"
676+
}
674677
],
675678
"title": "Model"
676679
},
@@ -870,8 +873,9 @@
870873
{
871874
"$ref": "#/$defs/BamTextGenerationParameters"
872875
},
876+
{},
873877
{
874-
"type": "object"
878+
"$ref": "#/$defs/LocalizedExpression"
875879
},
876880
{
877881
"type": "null"
@@ -1428,6 +1432,12 @@
14281432
"type": "string"
14291433
},
14301434
"call": {
1435+
"anyOf": [
1436+
{},
1437+
{
1438+
"$ref": "#/$defs/LocalizedExpression"
1439+
}
1440+
],
14311441
"description": "Function to call.\n ",
14321442
"title": "Call"
14331443
},
@@ -2225,6 +2235,12 @@
22252235
"type": "string"
22262236
},
22272237
"data": {
2238+
"anyOf": [
2239+
{},
2240+
{
2241+
"$ref": "#/$defs/LocalizedExpression"
2242+
}
2243+
],
22282244
"description": "Value defined.",
22292245
"title": "Data"
22302246
},
@@ -3209,6 +3225,14 @@
32093225
"type": "string"
32103226
},
32113227
"for": {
3228+
"additionalProperties": {
3229+
"anyOf": [
3230+
{},
3231+
{
3232+
"$ref": "#/$defs/LocalizedExpression"
3233+
}
3234+
]
3235+
},
32123236
"description": "Arrays to iterate over.\n ",
32133237
"title": "For",
32143238
"type": "object"
@@ -4402,6 +4426,12 @@
44024426
"type": "string"
44034427
},
44044428
"if": {
4429+
"anyOf": [
4430+
{},
4431+
{
4432+
"$ref": "#/$defs/LocalizedExpression"
4433+
}
4434+
],
44054435
"description": "Condition.\n ",
44064436
"title": "If"
44074437
},
@@ -5725,7 +5755,10 @@
57255755
{
57265756
"type": "string"
57275757
},
5728-
{}
5758+
{},
5759+
{
5760+
"$ref": "#/$defs/LocalizedExpression"
5761+
}
57295762
],
57305763
"title": "Model"
57315764
},
@@ -5914,8 +5947,9 @@
59145947
{
59155948
"$ref": "#/$defs/LitellmParameters"
59165949
},
5950+
{},
59175951
{
5918-
"type": "object"
5952+
"$ref": "#/$defs/LocalizedExpression"
59195953
},
59205954
{
59215955
"type": "null"
@@ -6329,6 +6363,31 @@
63296363
"title": "LitellmParameters",
63306364
"type": "object"
63316365
},
6366+
"LocalizedExpression": {
6367+
"additionalProperties": false,
6368+
"description": "Expression with location information",
6369+
"properties": {
6370+
"expr": {
6371+
"title": "Expr"
6372+
},
6373+
"location": {
6374+
"anyOf": [
6375+
{
6376+
"$ref": "#/$defs/LocationType"
6377+
},
6378+
{
6379+
"type": "null"
6380+
}
6381+
],
6382+
"default": null
6383+
}
6384+
},
6385+
"required": [
6386+
"expr"
6387+
],
6388+
"title": "LocalizedExpression",
6389+
"type": "object"
6390+
},
63326391
"LocationType": {
63336392
"additionalProperties": false,
63346393
"properties": {
@@ -8060,6 +8119,9 @@
80608119
"read": {
80618120
"anyOf": [
80628121
{},
8122+
{
8123+
"$ref": "#/$defs/LocalizedExpression"
8124+
},
80638125
{
80648126
"type": "null"
80658127
}
@@ -9005,6 +9067,12 @@
90059067
"title": "Repeat"
90069068
},
90079069
"until": {
9070+
"anyOf": [
9071+
{},
9072+
{
9073+
"$ref": "#/$defs/LocalizedExpression"
9074+
}
9075+
],
90089076
"description": "Condition of the loop.\n ",
90099077
"title": "Until"
90109078
},

src/pdl/pdl_ast.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,6 @@
1616

1717
ScopeType: TypeAlias = dict[str, Any]
1818

19-
ExpressionType: TypeAlias = Any
20-
# (
21-
# str
22-
# | int
23-
# | float
24-
# | bool
25-
# | None
26-
# | list["ExpressionType"]
27-
# | dict[str, "ExpressionType"]
28-
# )
29-
3019

3120
Message: TypeAlias = dict[str, Any]
3221
Messages: TypeAlias = list[Message]
@@ -64,6 +53,28 @@ class LocationType(BaseModel):
6453
empty_block_location = LocationType(file="", path=[], table={})
6554

6655

56+
class LocalizedExpression(BaseModel):
57+
"""Expression with location information"""
58+
59+
model_config = ConfigDict(
60+
extra="forbid", use_attribute_docstrings=True, arbitrary_types_allowed=True
61+
)
62+
expr: Any
63+
location: Optional[LocationType] = None
64+
65+
66+
ExpressionType: TypeAlias = Any | LocalizedExpression
67+
# (
68+
# str
69+
# | int
70+
# | float
71+
# | bool
72+
# | None
73+
# | list["ExpressionType"]
74+
# | dict[str, "ExpressionType"]
75+
# )
76+
77+
6778
class Parser(BaseModel):
6879
model_config = ConfigDict(extra="forbid")
6980
description: Optional[str] = None
@@ -96,7 +107,11 @@ class ContributeValue(BaseModel):
96107
class Block(BaseModel):
97108
"""Common fields for all PDL blocks."""
98109

99-
model_config = ConfigDict(extra="forbid", use_attribute_docstrings=True)
110+
model_config = ConfigDict(
111+
extra="forbid",
112+
use_attribute_docstrings=True,
113+
arbitrary_types_allowed=True,
114+
)
100115

101116
description: Optional[str] = None
102117
"""Documentation associated to the block.
@@ -265,7 +280,7 @@ class ModelBlock(Block):
265280
class BamModelBlock(ModelBlock):
266281
platform: Literal[ModelPlatform.BAM]
267282
prompt_id: Optional[str] = None
268-
parameters: Optional[BamTextGenerationParameters | dict] = None
283+
parameters: Optional[BamTextGenerationParameters | ExpressionType] = None
269284
moderations: Optional[ModerationParameters] = None
270285
data: Optional[PromptTemplateData] = None
271286
constraints: Any = None # TODO
@@ -275,7 +290,7 @@ class LitellmModelBlock(ModelBlock):
275290
"""Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""
276291

277292
platform: Literal[ModelPlatform.LITELLM] = ModelPlatform.LITELLM
278-
parameters: Optional[LitellmParameters | dict] = None
293+
parameters: Optional[LitellmParameters | ExpressionType] = None
279294

280295

281296
class CodeBlock(Block):

src/pdl/pdl_ast_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def iter_block_children(f: Callable[[BlockType], None], block: BlockType) -> Non
3838
f(blocks)
3939
match block:
4040
case FunctionBlock():
41-
if block.returns is not None:
42-
f(block.returns)
41+
f(block.returns)
4342
case CallBlock():
4443
if block.trace is not None:
4544
f(block.trace)

src/pdl/pdl_compilers/to_regex.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
IncludeBlock,
1919
LitellmModelBlock,
2020
LitellmParameters,
21+
LocalizedExpression,
2122
ModelBlock,
2223
ReadBlock,
2324
RepeatBlock,
@@ -273,10 +274,14 @@ def compile_block(
273274
"include_stop_sequence", False
274275
)
275276
else:
276-
stop_sequences = block.parameters.stop_sequences or []
277+
if isinstance(block.parameters, LocalizedExpression):
278+
parameters = block.parameters.expr
279+
else:
280+
parameters = block.parameters
281+
stop_sequences = parameters.stop_sequences or []
277282
include_stop_sequence = (
278-
block.parameters.include_stop_sequence is None
279-
or block.parameters.include_stop_sequence
283+
parameters.include_stop_sequence is None
284+
or parameters.include_stop_sequence
280285
)
281286
case LitellmModelBlock():
282287
if block.parameters is None:
@@ -285,6 +290,8 @@ def compile_block(
285290
else:
286291
if isinstance(block.parameters, LitellmParameters):
287292
parameters = block.parameters.model_dump()
293+
elif isinstance(block.parameters, LocalizedExpression):
294+
parameters = block.parameters.expr
288295
else:
289296
parameters = block.parameters
290297
stop_sequences = parameters.get("stop", [])

0 commit comments

Comments
 (0)