Skip to content

Commit 34d1ab3

Browse files
committed
Location (WIP)
1 parent f3a737b commit 34d1ab3

File tree

6 files changed

+214
-20
lines changed

6 files changed

+214
-20
lines changed

src/pdl/pdl-schema.json

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3155,6 +3155,12 @@
31553155
"type": "string"
31563156
},
31573157
"data": {
3158+
"anyOf": [
3159+
{},
3160+
{
3161+
"$ref": "#/$defs/LocalizedExpression"
3162+
}
3163+
],
31583164
"description": "Value defined.",
31593165
"title": "Data"
31603166
},
@@ -6428,6 +6434,12 @@
64286434
"type": "string"
64296435
},
64306436
"if": {
6437+
"anyOf": [
6438+
{},
6439+
{
6440+
"$ref": "#/$defs/LocalizedExpression"
6441+
}
6442+
],
64316443
"description": "Condition.\n ",
64326444
"title": "If"
64336445
},
@@ -9035,6 +9047,31 @@
90359047
"title": "LitellmParameters",
90369048
"type": "object"
90379049
},
9050+
"LocalizedExpression": {
9051+
"additionalProperties": false,
9052+
"description": "Expression with location information\n ",
9053+
"properties": {
9054+
"expr": {
9055+
"title": "Expr"
9056+
},
9057+
"location": {
9058+
"anyOf": [
9059+
{
9060+
"$ref": "#/$defs/LocationType"
9061+
},
9062+
{
9063+
"type": "null"
9064+
}
9065+
],
9066+
"default": null
9067+
}
9068+
},
9069+
"required": [
9070+
"expr"
9071+
],
9072+
"title": "LocalizedExpression",
9073+
"type": "object"
9074+
},
90389075
"LocationType": {
90399076
"additionalProperties": false,
90409077
"properties": {
@@ -12945,6 +12982,12 @@
1294512982
"title": "Repeat"
1294612983
},
1294712984
"until": {
12985+
"anyOf": [
12986+
{},
12987+
{
12988+
"$ref": "#/$defs/LocalizedExpression"
12989+
}
12990+
],
1294812991
"description": "Condition of the loop.\n ",
1294912992
"title": "Until"
1295012993
},

src/pdl/pdl_ast.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from enum import StrEnum
55
from typing import Any, Literal, Optional, TypeAlias, TypedDict, Union
66

7+
import strictyaml
78
from genai.schema import (
89
DecodingMethod,
910
ModerationParameters,
@@ -14,17 +15,6 @@
1415

1516
ScopeType: TypeAlias = dict[str, Any]
1617

17-
ExpressionType: TypeAlias = Any
18-
# (
19-
# str
20-
# | int
21-
# | float
22-
# | bool
23-
# | None
24-
# | list["ExpressionType"]
25-
# | dict[str, "ExpressionType"]
26-
# )
27-
2818

2919
class Message(TypedDict):
3020
role: Optional[str]
@@ -63,9 +53,32 @@ class LocationType(BaseModel):
6353
table: dict[str, int]
6454

6555

56+
YamlSource: TypeAlias = strictyaml.YAML
57+
6658
empty_block_location = LocationType(file="", path=[], table={})
6759

6860

61+
class LocalizedExpression(BaseModel):
62+
"""Expression with location information"""
63+
64+
model_config = ConfigDict(extra="forbid", use_attribute_docstrings=True)
65+
expr: Any
66+
location: Optional[LocationType] = None
67+
_pdl_yaml_src: Optional[YamlSource] = None
68+
69+
70+
ExpressionType: TypeAlias = Any | LocalizedExpression
71+
# (
72+
# str
73+
# | int
74+
# | float
75+
# | bool
76+
# | None
77+
# | list["ExpressionType"]
78+
# | dict[str, "ExpressionType"]
79+
# )
80+
81+
6982
class Parser(BaseModel):
7083
model_config = ConfigDict(extra="forbid")
7184
description: Optional[str] = None
@@ -124,6 +137,7 @@ class Block(BaseModel):
124137
# Fields for internal use
125138
result: Optional[Any] = None
126139
location: Optional[LocationType] = None
140+
_pdl_yaml_src: Optional[YamlSource] = None
127141

128142

129143
class FunctionBlock(Block):

src/pdl/pdl_ast_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,18 @@
2828
)
2929

3030

31+
def is_block_list(blocks: BlocksType) -> bool:
32+
return not isinstance(blocks, str) and isinstance(blocks, Sequence)
33+
34+
3135
def iter_block_children(f: Callable[[BlockType], None], block: BlockType) -> None:
3236
if not isinstance(block, Block):
3337
return
3438
for blocks in block.defs.values():
3539
iter_blocks(f, blocks)
3640
match block:
3741
case FunctionBlock():
38-
if block.returns is not None:
39-
iter_blocks(f, block.returns)
42+
iter_blocks(f, block.returns)
4043
case CallBlock():
4144
if block.trace is not None:
4245
iter_blocks(f, block.trace)
@@ -105,7 +108,7 @@ def iter_block_children(f: Callable[[BlockType], None], block: BlockType) -> Non
105108

106109

107110
def iter_blocks(f: Callable[[BlockType], None], blocks: BlocksType) -> None:
108-
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
111+
if is_block_list(blocks):
109112
for block in blocks:
110113
f(block)
111114
else:

src/pdl/pdl_dumper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Sequence, TypeAlias
2+
from typing import Any, TypeAlias
33

44
import yaml
55

@@ -35,6 +35,7 @@
3535
RepeatUntilBlock,
3636
TextBlock,
3737
)
38+
from .pdl_ast_utils import is_block_list
3839

3940
yaml.SafeDumper.org_represent_str = yaml.SafeDumper.represent_str # type: ignore
4041

@@ -235,7 +236,7 @@ def blocks_to_dict(
235236
result: (
236237
int | float | str | dict[str, Any] | list[int | float | str | dict[str, Any]]
237238
)
238-
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
239+
if is_block_list(blocks):
239240
result = [block_to_dict(block, json_compatible) for block in blocks]
240241
else:
241242
result = block_to_dict(blocks, json_compatible)

src/pdl/pdl_interpreter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# from itertools import batched
1010
from pathlib import Path
11-
from typing import Any, Generator, Optional, Sequence, TypeVar
11+
from typing import Any, Generator, Optional, TypeVar
1212

1313
import litellm
1414
import yaml
@@ -65,6 +65,7 @@
6565
TextBlock,
6666
empty_block_location,
6767
)
68+
from .pdl_ast_utils import is_block_list
6869
from .pdl_dumper import blocks_to_dict
6970
from .pdl_llms import BamModel, LitellmModel
7071
from .pdl_location_utils import append, get_loc_string
@@ -799,7 +800,7 @@ def step_blocks(
799800
background: Messages
800801
trace: BlocksType
801802
results = []
802-
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
803+
if is_block_list(blocks):
803804
iteration_state = state.with_yield_result(
804805
state.yield_result and iteration_type != IterationType.ARRAY
805806
)

src/pdl/pdl_parser.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
import json
22
from pathlib import Path
3+
from typing import Any, Optional
34

5+
import strictyaml
46
import yaml
57
from pydantic import ValidationError
68

7-
from .pdl_ast import LocationType, PDLException, Program
8-
from .pdl_location_utils import get_line_map
9+
from .pdl_ast import (
10+
Block,
11+
BlocksType,
12+
CallBlock,
13+
DataBlock,
14+
FunctionBlock,
15+
LastOfBlock,
16+
LocationType,
17+
ModelBlock,
18+
PDLException,
19+
Program,
20+
YamlSource,
21+
)
22+
from .pdl_ast_utils import is_block_list
23+
from .pdl_location_utils import append, get_line_map
924
from .pdl_schema_error_analyzer import analyze_errors
1025

1126

@@ -25,6 +40,7 @@ def parse_str(pdl_str: str, file_name: str = "") -> tuple[Program, LocationType]
2540
loc = LocationType(path=[], file=file_name, table=line_table)
2641
try:
2742
prog = Program.model_validate(prog_yaml)
43+
set_program_location(prog, pdl_str)
2844
except ValidationError as exc:
2945
pdl_schema_file = Path(__file__).parent / "pdl-schema.json"
3046
with open(pdl_schema_file, "r", encoding="utf-8") as schema_fp:
@@ -35,3 +51,119 @@ def parse_str(pdl_str: str, file_name: str = "") -> tuple[Program, LocationType]
3551
errors = ["The file do not respect the schema."]
3652
raise PDLParseError(errors) from exc
3753
return prog, loc
54+
55+
56+
def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):
57+
loc = strictyaml.dirty_load(pdl_str)
58+
set_location(prog.root, loc)
59+
60+
61+
def set_location(
62+
pdl: Any,
63+
loc: YamlSource,
64+
):
65+
if hasattr(pdl, "_pdl_yaml_src"):
66+
pdl._pdl_yaml_src = loc
67+
if isinstance(loc.data, dict):
68+
for x, v in loc.items():
69+
if hasattr(pdl, x.data):
70+
set_location(getattr(pdl, x.data), v)
71+
elif isinstance(pdl, list) and isinstance(loc.data, list):
72+
for data_i, loc_i in zip(pdl, loc):
73+
set_location(data_i, loc_i)
74+
75+
76+
# def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):
77+
# line_table = get_line_map(pdl_str)
78+
# loc = LocationType(path=[], file=file_name, table=line_table)
79+
# return Program(set_blocks_location(prog.root, loc))
80+
81+
# def set_blocks_location(
82+
# blocks: BlocksType,
83+
# loc: YAML,
84+
# ):
85+
# if is_block_list(blocks):
86+
# return [set_block_location(block, append(loc, f"[{i}]")) for i, block in enumerate(blocks)]
87+
# return set_block_location(blocks, loc)
88+
89+
90+
# def set_block_location(
91+
# block: BlocksType,
92+
# loc: LocationType,
93+
# ):
94+
# if not isinstance(block, Block):
95+
# return DataBlock(data=block, location=loc)
96+
# block = block.model_copy(update={"location": loc})
97+
# defs_loc = append(loc, "defs")
98+
# block.defs = {x: set_block_location(b, append(defs_loc, x)) for x, b in block.defs }
99+
# if block.parser is not None:
100+
# block.parser = set_parser_location(block.parser)
101+
# if block.fallback is not None:
102+
# block.fallback = set_block_location(block.fallback, append(loc, "fallback"))
103+
# match block:
104+
# case FunctionBlock():
105+
# block.returns = set_blocks_location(block.returns, append(loc, "return"))
106+
# case CallBlock():
107+
# block.args = {x: set_expr_location(expr) for x, expr in block.args.items()}
108+
# case ModelBlock():
109+
# if block.input is not None:
110+
# iter_blocks(f, block.input)
111+
# case CodeBlock():
112+
# iter_blocks(f, block.code)
113+
# case GetBlock():
114+
# pass
115+
# case DataBlock():
116+
# pass
117+
# case TextBlock():
118+
# iter_blocks(f, block.text)
119+
# case LastOfBlock():
120+
# iter_blocks(f, block.lastOf)
121+
# case ArrayBlock():
122+
# iter_blocks(f, block.array)
123+
# case ObjectBlock():
124+
# if isinstance(block.object, dict):
125+
# body = list(block.object.values())
126+
# else:
127+
# body = block.object
128+
# iter_blocks(f, body)
129+
# case MessageBlock():
130+
# iter_blocks(f, block.content)
131+
# case IfBlock():
132+
# iter_blocks(f, block.then)
133+
# if block.elses is not None:
134+
# iter_blocks(f, block.elses)
135+
# case RepeatBlock():
136+
# iter_blocks(f, block.repeat)
137+
# if block.trace is not None:
138+
# for trace in block.trace:
139+
# iter_blocks(f, trace)
140+
# case RepeatUntilBlock():
141+
# iter_blocks(f, block.repeat)
142+
# if block.trace is not None:
143+
# for trace in block.trace:
144+
# iter_blocks(f, trace)
145+
# case ForBlock():
146+
# iter_blocks(f, block.repeat)
147+
# if block.trace is not None:
148+
# for trace in block.trace:
149+
# iter_blocks(f, trace)
150+
# case ErrorBlock():
151+
# iter_blocks(f, block.program)
152+
# case ReadBlock():
153+
# pass
154+
# case IncludeBlock():
155+
# if block.trace is not None:
156+
# iter_blocks(f, block.trace)
157+
# case EmptyBlock():
158+
# pass
159+
# case _:
160+
# assert (
161+
# False
162+
# ), f"Internal error (missing case iter_block_children({type(block)}))"
163+
# match (block.parser):
164+
# case "json" | "yaml" | RegexParser():
165+
# pass
166+
# case PdlParser():
167+
# iter_blocks(f, block.parser.pdl)
168+
# if block.fallback is not None:
169+
# iter_blocks(f, block.fallback)

0 commit comments

Comments
 (0)