diff --git a/src/pdl/pdl.py b/src/pdl/pdl.py index 09d752d2f..09d617a26 100644 --- a/src/pdl/pdl.py +++ b/src/pdl/pdl.py @@ -20,7 +20,7 @@ ) from .pdl_interpreter import InterpreterState, process_prog from .pdl_lazy import PdlDict -from .pdl_parser import parse_file, parse_str +from .pdl_parser import parse_dict, parse_file, parse_str from .pdl_runner import exec_docker from .pdl_utils import validate_scope @@ -104,7 +104,7 @@ def exec_dict( Returns: Return the final result. """ - program = Program.model_validate(prog) + program = parse_dict(prog) result = exec_program(program, config, scope, loc, output) return result diff --git a/src/pdl/pdl_parser.py b/src/pdl/pdl_parser.py index 0225cba43..4bda95f9f 100644 --- a/src/pdl/pdl_parser.py +++ b/src/pdl/pdl_parser.py @@ -1,11 +1,11 @@ import json from pathlib import Path -from typing import Optional +from typing import Any, Optional import yaml from pydantic import ValidationError -from .pdl_ast import PDLException, PdlLocationType, Program +from .pdl_ast import PDLException, PdlLocationType, Program, empty_block_location from .pdl_location_utils import get_line_map from .pdl_schema_error_analyzer import analyze_errors @@ -25,22 +25,34 @@ def parse_str( ) -> tuple[Program, PdlLocationType]: if file_name is None: file_name = "" - prog_yaml = yaml.safe_load(pdl_str) + prog_dict = yaml.safe_load(pdl_str) line_table = get_line_map(pdl_str) loc = PdlLocationType(path=[], file=file_name, table=line_table) + prog = parse_dict(prog_dict, loc) + return prog, loc + + +def parse_dict( + pdl_dict: dict[str, Any], loc: Optional[PdlLocationType] = None +) -> Program: try: - prog = Program.model_validate(prog_yaml) + prog = Program.model_validate(pdl_dict) # set_program_location(prog, pdl_str) except ValidationError as exc: pdl_schema_file = Path(__file__).parent / "pdl-schema.json" with open(pdl_schema_file, "r", encoding="utf-8") as schema_fp: schema = json.load(schema_fp) defs = schema["$defs"] - errors = analyze_errors(defs, defs["Program"], prog_yaml, loc) + if loc is None: + loc = empty_block_location + errors = analyze_errors(defs, defs["Program"], pdl_dict, loc) if errors == []: - errors = [f"The file PDL {file_name} does not respect the schema."] + if loc.file == "": + errors = ["The PDL program does not respect the schema."] + else: + errors = [f"The file PDL {loc.file} does not respect the schema."] raise PDLParseError(errors) from exc - return prog, loc + return prog # def set_program_location(prog: Program, pdl_str: str, file_name: str = ""):