Skip to content

Commit 5aae897

Browse files
authored
Improve schema of expressions (#453)
1 parent b25e51b commit 5aae897

File tree

6 files changed

+143
-76
lines changed

6 files changed

+143
-76
lines changed

.github/workflows/build.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@ jobs:
4343
pipdeptree -fl
4444
- name: pre-commit checks
4545
run: pre-commit run -a
46+
- name: run tests
47+
run: py.test -v --capture=tee-sys --ignore=tests/test_examples_run.py --ignore=tests/test_schema.py tests
48+
if: matrix.python-version == '3.11'
4649
- name: run tests
4750
run: py.test -v --capture=tee-sys --ignore=tests/test_examples_run.py tests
51+
if: matrix.python-version != '3.11'
4852

4953
viewer:
5054
name: Build PDL live viewer

docs/viewer.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ hide:
5252
.pdl_repeat {
5353
background-color: rgb(251, 201, 86);
5454
}
55-
.pdl_repeat_until {
56-
background-color: rgb(243, 209, 77);
57-
}
58-
.pdl_for {
59-
background-color: rgb(245, 241, 133);
60-
}
6155
.pdl_read {
6256
background-color: rgb(243, 77, 113);
6357
}

src/pdl/pdl-schema.json

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,10 @@
860860
"anyOf": [
861861
{},
862862
{
863-
"$ref": "#/$defs/LocalizedExpression"
863+
"type": "string"
864+
},
865+
{
866+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
864867
}
865868
],
866869
"description": "Function to call.\n ",
@@ -870,7 +873,10 @@
870873
"anyOf": [
871874
{},
872875
{
873-
"$ref": "#/$defs/LocalizedExpression"
876+
"type": "string"
877+
},
878+
{
879+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
874880
}
875881
],
876882
"default": {},
@@ -1768,7 +1774,10 @@
17681774
"anyOf": [
17691775
{},
17701776
{
1771-
"$ref": "#/$defs/LocalizedExpression"
1777+
"type": "string"
1778+
},
1779+
{
1780+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
17721781
}
17731782
],
17741783
"description": "Value defined.",
@@ -3705,9 +3714,14 @@
37053714
},
37063715
"if": {
37073716
"anyOf": [
3708-
{},
37093717
{
3710-
"$ref": "#/$defs/LocalizedExpression"
3718+
"type": "boolean"
3719+
},
3720+
{
3721+
"type": "string"
3722+
},
3723+
{
3724+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
37113725
}
37123726
],
37133727
"description": "Condition.\n ",
@@ -5144,9 +5158,8 @@
51445158
{
51455159
"type": "string"
51465160
},
5147-
{},
51485161
{
5149-
"$ref": "#/$defs/LocalizedExpression"
5162+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
51505163
}
51515164
],
51525165
"title": "Model"
@@ -5324,9 +5337,14 @@
53245337
{
53255338
"$ref": "#/$defs/LitellmParameters"
53265339
},
5327-
{},
53285340
{
5329-
"$ref": "#/$defs/LocalizedExpression"
5341+
"type": "object"
5342+
},
5343+
{
5344+
"type": "string"
5345+
},
5346+
{
5347+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
53305348
},
53315349
{
53325350
"type": "null"
@@ -5740,9 +5758,8 @@
57405758
"title": "LitellmParameters",
57415759
"type": "object"
57425760
},
5743-
"LocalizedExpression": {
5761+
"LocalizedExpression_TypeVar_": {
57445762
"additionalProperties": false,
5745-
"description": "Expression with location information",
57465763
"properties": {
57475764
"expr": {
57485765
"title": "Expr"
@@ -6143,7 +6160,10 @@
61436160
"anyOf": [
61446161
{},
61456162
{
6146-
"$ref": "#/$defs/LocalizedExpression"
6163+
"type": "string"
6164+
},
6165+
{
6166+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
61476167
}
61486168
],
61496169
"description": "Matched expression.\n ",
@@ -6203,9 +6223,14 @@
62036223
},
62046224
"if": {
62056225
"anyOf": [
6206-
{},
62076226
{
6208-
"$ref": "#/$defs/LocalizedExpression"
6227+
"type": "boolean"
6228+
},
6229+
{
6230+
"type": "string"
6231+
},
6232+
{
6233+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
62096234
},
62106235
{
62116236
"type": "null"
@@ -7946,9 +7971,11 @@
79467971
},
79477972
"read": {
79487973
"anyOf": [
7949-
{},
79507974
{
7951-
"$ref": "#/$defs/LocalizedExpression"
7975+
"type": "string"
7976+
},
7977+
{
7978+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
79527979
},
79537980
{
79547981
"type": "null"
@@ -8382,9 +8409,15 @@
83828409
{
83838410
"additionalProperties": {
83848411
"anyOf": [
8385-
{},
83868412
{
8387-
"$ref": "#/$defs/LocalizedExpression"
8413+
"items": {},
8414+
"type": "array"
8415+
},
8416+
{
8417+
"type": "string"
8418+
},
8419+
{
8420+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
83888421
}
83898422
]
83908423
},
@@ -8400,9 +8433,14 @@
84008433
},
84018434
"while": {
84028435
"anyOf": [
8403-
{},
84048436
{
8405-
"$ref": "#/$defs/LocalizedExpression"
8437+
"type": "boolean"
8438+
},
8439+
{
8440+
"type": "string"
8441+
},
8442+
{
8443+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
84068444
}
84078445
],
84088446
"default": true,
@@ -8486,9 +8524,14 @@
84868524
},
84878525
"until": {
84888526
"anyOf": [
8489-
{},
84908527
{
8491-
"$ref": "#/$defs/LocalizedExpression"
8528+
"type": "boolean"
8529+
},
8530+
{
8531+
"type": "string"
8532+
},
8533+
{
8534+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
84928535
}
84938536
],
84948537
"default": false,
@@ -8497,9 +8540,14 @@
84978540
},
84988541
"max_iterations": {
84998542
"anyOf": [
8500-
{},
85018543
{
8502-
"$ref": "#/$defs/LocalizedExpression"
8544+
"type": "integer"
8545+
},
8546+
{
8547+
"type": "string"
8548+
},
8549+
{
8550+
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
85038551
},
85048552
{
85058553
"type": "null"

src/pdl/pdl_ast.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
"""
33

44
from enum import StrEnum
5-
from typing import Any, Literal, Mapping, Optional, Sequence, TypeAlias, Union
5+
from typing import (
6+
Any,
7+
Generic,
8+
Literal,
9+
Mapping,
10+
Optional,
11+
Sequence,
12+
TypeAlias,
13+
TypeVar,
14+
Union,
15+
)
616

717
from pydantic import BaseModel, ConfigDict, Field, RootModel
818
from pydantic.json_schema import SkipJsonSchema
@@ -51,26 +61,24 @@ class LocationType(BaseModel):
5161
empty_block_location = LocationType(file="", path=[], table={})
5262

5363

54-
class LocalizedExpression(BaseModel):
64+
LocalizedExpressionT = TypeVar("LocalizedExpressionT")
65+
66+
67+
class LocalizedExpression(BaseModel, Generic[LocalizedExpressionT]):
5568
"""Expression with location information"""
5669

5770
model_config = ConfigDict(
58-
extra="forbid", use_attribute_docstrings=True, arbitrary_types_allowed=True
71+
extra="forbid",
72+
use_attribute_docstrings=True,
73+
arbitrary_types_allowed=True,
74+
model_title_generator=(lambda _: "LocalizedExpression"),
5975
)
60-
expr: Any
76+
expr: LocalizedExpressionT
6177
location: Optional[LocationType] = None
6278

6379

64-
ExpressionType: TypeAlias = Any | LocalizedExpression
65-
# (
66-
# str
67-
# | int
68-
# | float
69-
# | bool
70-
# | None
71-
# | list["ExpressionType"]
72-
# | dict[str, "ExpressionType"]
73-
# )
80+
ExpressionTypeT = TypeVar("ExpressionTypeT")
81+
ExpressionType: TypeAlias = ExpressionTypeT | str | LocalizedExpression[ExpressionTypeT]
7482

7583

7684
class Pattern(BaseModel):
@@ -316,7 +324,7 @@ class ModelPlatform(StrEnum):
316324

317325
class ModelBlock(Block):
318326
kind: Literal[BlockKind.MODEL] = BlockKind.MODEL
319-
model: str | ExpressionType
327+
model: ExpressionType[str]
320328
input: Optional["BlockType"] = None
321329
trace: Optional["BlockType"] = None
322330
modelResponse: Optional[str] = None
@@ -326,7 +334,7 @@ class LitellmModelBlock(ModelBlock):
326334
"""Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""
327335

328336
platform: Literal[ModelPlatform.LITELLM] = ModelPlatform.LITELLM
329-
parameters: Optional[LitellmParameters | ExpressionType] = None
337+
parameters: Optional[LitellmParameters | ExpressionType[dict]] = None
330338

331339

332340
class CodeBlock(Block):
@@ -353,7 +361,7 @@ class DataBlock(Block):
353361
"""Arbitrary JSON value."""
354362

355363
kind: Literal[BlockKind.DATA] = BlockKind.DATA
356-
data: ExpressionType
364+
data: ExpressionType[Any]
357365
"""Value defined."""
358366
raw: bool = False
359367
"""Do not evaluate expressions inside strings."""
@@ -403,7 +411,7 @@ class IfBlock(Block):
403411
"""Conditional control structure."""
404412

405413
kind: Literal[BlockKind.IF] = BlockKind.IF
406-
condition: ExpressionType = Field(alias="if")
414+
condition: ExpressionType[bool] = Field(alias="if")
407415
"""Condition.
408416
"""
409417
then: "BlockType"
@@ -423,7 +431,7 @@ class MatchCase(BaseModel):
423431
case: Optional[PatternType] = None
424432
"""Value to match.
425433
"""
426-
if_: Optional[ExpressionType] = Field(default=None, alias="if")
434+
if_: Optional[ExpressionType[bool]] = Field(default=None, alias="if")
427435
"""Boolean condition to satisfy.
428436
"""
429437
then: "BlockType"
@@ -435,7 +443,7 @@ class MatchBlock(Block):
435443
"""Match control structure."""
436444

437445
kind: Literal[BlockKind.MATCH] = BlockKind.MATCH
438-
match_: ExpressionType = Field(alias="match")
446+
match_: ExpressionType[Any] = Field(alias="match")
439447
"""Matched expression.
440448
"""
441449
with_: list[MatchCase] = Field(alias="with")
@@ -484,19 +492,19 @@ class RepeatBlock(Block):
484492
"""Repeat the execution of a block."""
485493

486494
kind: Literal[BlockKind.REPEAT] = BlockKind.REPEAT
487-
fors: Optional[dict[str, ExpressionType]] = Field(default=None, alias="for")
495+
fors: Optional[dict[str, ExpressionType[list]]] = Field(default=None, alias="for")
488496
"""Arrays to iterate over.
489497
"""
490-
while_: ExpressionType = Field(default=True, alias="while")
498+
while_: ExpressionType[bool] = Field(default=True, alias="while")
491499
"""Condition to stay at the beginning of the loop.
492500
"""
493501
repeat: "BlockType"
494502
"""Body of the loop.
495503
"""
496-
until: ExpressionType = False
504+
until: ExpressionType[bool] = False
497505
"""Condition to exit at the end of the loop.
498506
"""
499-
max_iterations: Optional[ExpressionType] = None
507+
max_iterations: Optional[ExpressionType[int]] = None
500508
"""Maximal number of iterations to perform.
501509
"""
502510
join: JoinType = JoinText()
@@ -510,7 +518,7 @@ class ReadBlock(Block):
510518
"""Read from a file or standard input."""
511519

512520
kind: Literal[BlockKind.READ] = BlockKind.READ
513-
read: ExpressionType | None
521+
read: ExpressionType[str] | None
514522
"""Name of the file to read. If `None`, read the standard input.
515523
"""
516524
message: Optional[str] = None

src/pdl/pdl_compilers/to_regex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def compile_block(
270270
parameters = block.parameters.model_dump()
271271
elif isinstance(block.parameters, LocalizedExpression):
272272
parameters = block.parameters.expr
273+
elif isinstance(block.parameters, str):
274+
parameters = {}
273275
else:
274276
parameters = block.parameters
275277
stop_sequences = parameters.get("stop", [])

0 commit comments

Comments
 (0)