Skip to content

Commit 03f4951

Browse files
committed
Remove dependency to strictyaml
1 parent 6cde7d8 commit 03f4951

File tree

6 files changed

+70
-31
lines changed

6 files changed

+70
-31
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ dependencies = [
1616
"litellm~=1.49",
1717
"termcolor~=2.0",
1818
"ipython~=8.0",
19-
"strictyaml~=1.7.3"
2019
]
2120
authors = [
2221
{ name="Mandana Vaziri", email="[email protected]" },

src/pdl/pdl-schema.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,8 +1362,9 @@
13621362
{
13631363
"$ref": "#/$defs/BamTextGenerationParameters"
13641364
},
1365+
{},
13651366
{
1366-
"type": "object"
1367+
"$ref": "#/$defs/LocalizedExpression"
13671368
},
13681369
{
13691370
"type": "null"
@@ -9302,8 +9303,9 @@
93029303
{
93039304
"$ref": "#/$defs/LitellmParameters"
93049305
},
9306+
{},
93059307
{
9306-
"type": "object"
9308+
"$ref": "#/$defs/LocalizedExpression"
93079309
},
93089310
{
93099311
"type": "null"

src/pdl/pdl_ast.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from enum import StrEnum
55
from typing import Any, Literal, Optional, Sequence, TypeAlias, Union
66

7-
import strictyaml
87
from genai.schema import (
98
DecodingMethod,
109
ModerationParameters,
@@ -51,8 +50,6 @@ class LocationType(BaseModel):
5150
table: dict[str, int]
5251

5352

54-
YamlSource: TypeAlias = strictyaml.YAML
55-
5653
empty_block_location = LocationType(file="", path=[], table={})
5754

5855

@@ -64,7 +61,6 @@ class LocalizedExpression(BaseModel):
6461
)
6562
expr: Any
6663
location: Optional[LocationType] = None
67-
pdl_yaml_src: Optional[YamlSource] = None
6864

6965

7066
ExpressionType: TypeAlias = Any | LocalizedExpression
@@ -146,7 +142,6 @@ class Block(BaseModel):
146142
# Fields for internal use
147143
result: Optional[Any] = None
148144
location: Optional[LocationType] = None
149-
pdl_yaml_src: Optional[YamlSource] = None
150145

151146

152147
class FunctionBlock(Block):

src/pdl/pdl_interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def step_call_model(
10481048
],
10491049
]:
10501050
# evaluate model name
1051-
model, concrete_block = process_expr_of(block, "model", scope, loc)
1051+
_, concrete_block = process_expr_of(block, "model", scope, loc)
10521052
# evaluate model params
10531053
match concrete_block:
10541054
case BamModelBlock():

src/pdl/pdl_parser.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import json
22
from pathlib import Path
3-
from typing import Any
43

5-
import strictyaml
64
import yaml
75
from pydantic import ValidationError
86

97
from .pdl_analysis import unused_program
10-
from .pdl_ast import LocationType, PDLException, Program, YamlSource
8+
from .pdl_ast import LocationType, PDLException, Program
119
from .pdl_location_utils import get_line_map
1210
from .pdl_schema_error_analyzer import analyze_errors
1311

@@ -28,7 +26,7 @@ def parse_str(pdl_str: str, file_name: str = "") -> tuple[Program, LocationType]
2826
loc = LocationType(path=[], file=file_name, table=line_table)
2927
try:
3028
prog = Program.model_validate(prog_yaml)
31-
set_program_location(prog, pdl_str)
29+
# set_program_location(prog, pdl_str)
3230
unused_program(prog)
3331
except ValidationError as exc:
3432
pdl_schema_file = Path(__file__).parent / "pdl-schema.json"
@@ -42,24 +40,24 @@ def parse_str(pdl_str: str, file_name: str = "") -> tuple[Program, LocationType]
4240
return prog, loc
4341

4442

45-
def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):
46-
loc = strictyaml.dirty_load(pdl_str, allow_flow_style=True)
47-
set_location(prog.root, loc)
48-
49-
50-
def set_location(
51-
pdl: Any,
52-
loc: YamlSource,
53-
):
54-
if hasattr(pdl, "pdl_yaml_src"):
55-
pdl.pdl_yaml_src = loc
56-
if isinstance(loc.data, dict):
57-
for x, v in loc.items():
58-
if hasattr(pdl, x.data):
59-
set_location(getattr(pdl, x.data), v)
60-
elif isinstance(pdl, list) and isinstance(loc.data, list):
61-
for data_i, loc_i in zip(pdl, loc):
62-
set_location(data_i, loc_i)
43+
# def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):
44+
# loc = strictyaml.dirty_load(pdl_str, allow_flow_style=True)
45+
# set_location(prog.root, loc)
46+
47+
48+
# def set_location(
49+
# pdl: Any,
50+
# loc: YamlSource,
51+
# ):
52+
# if hasattr(pdl, "pdl_yaml_src"):
53+
# pdl.pdl_yaml_src = loc
54+
# if isinstance(loc.data, dict):
55+
# for x, v in loc.items():
56+
# if hasattr(pdl, x.data):
57+
# set_location(getattr(pdl, x.data), v)
58+
# elif isinstance(pdl, list) and isinstance(loc.data, list):
59+
# for data_i, loc_i in zip(pdl, loc):
60+
# set_location(data_i, loc_i)
6361

6462

6563
# def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):

tests/test_ast_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pathlib
2+
3+
from pdl.pdl_ast_utils import MappedFunctions, iter_block_children, map_block_children
4+
from pdl.pdl_parser import PDLParseError, parse_file
5+
6+
7+
class Counter:
8+
def __init__(self):
9+
self.cpt = 0
10+
11+
def incr(self, *args):
12+
self.cpt += 1
13+
14+
class IterCounter:
15+
def __init__(self):
16+
self.cpt = 0
17+
18+
def count(self, ast):
19+
self.cpt += 1
20+
iter_block_children(self.count, ast)
21+
22+
23+
class MapCounter:
24+
def __init__(self):
25+
self.cpt = 0
26+
27+
def count(map_self, ast):
28+
map_self.cpt += 1
29+
class C(MappedFunctions):
30+
def f_block(c_self, block):
31+
return map_self.count(block)
32+
_ = map_block_children(C(), ast)
33+
return ast
34+
35+
def test_ast_iterators() -> None:
36+
for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"):
37+
try:
38+
ast, _ = parse_file(yaml_file_name)
39+
iter_cpt = IterCounter()
40+
iter_cpt.count(ast.root)
41+
map_cpt = MapCounter()
42+
map_cpt.count(ast.root)
43+
assert iter_cpt.cpt == map_cpt.cpt, yaml_file_name
44+
except PDLParseError:
45+
pass

0 commit comments

Comments
 (0)