|
| 1 | +import io |
| 2 | +import os |
| 3 | +import pathlib |
| 4 | +import random |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Optional |
| 7 | + |
| 8 | +from pytest import CaptureFixture, MonkeyPatch |
| 9 | + |
| 10 | +from pdl import pdl |
| 11 | +from pdl.pdl_ast import ScopeType |
| 12 | +from pdl.pdl_dumper import block_to_dict |
| 13 | +from pdl.pdl_lazy import PdlDict |
| 14 | +from pdl.pdl_parser import PDLParseError |
| 15 | + |
| 16 | +# test_examples_run.py runs the examples and compares the results |
| 17 | +# to the expected results in tests/results/examples |
| 18 | + |
# When True, a program whose output matches no stored result also has its
# actual output written to <stem>.<RESULTS_VERSION>.result (it is still
# reported as a wrong result for that run).
UPDATE_RESULTS = True
# Version tag embedded in the name of result files written by UPDATE_RESULTS.
RESULTS_VERSION = 1
# Raw value of the OLLAMA_GHACTIONS_RESULTS environment variable ("" if unset).
OLLAMA_GHACTIONS_RESULTS_ENV_VAR = os.getenv("OLLAMA_GHACTIONS_RESULTS", "")
# Enabled only when the variable is "true" (case-insensitive, spaces ignored).
OLLAMA_GHACTIONS_RESULTS = OLLAMA_GHACTIONS_RESULTS_ENV_VAR.lower().strip() == "true"
| 25 | + |
# Programs that the test never executes. Entries are stored as path strings
# because membership is checked with str(pdl_file_name).
TO_SKIP = set(
    map(
        str,
        (
            # Requires dataset dependency
            pathlib.Path("examples", "cldk", "cldk-assistant.pdl"),
            pathlib.Path("examples", "gsm8k", "gsm8.pdl"),
            pathlib.Path("examples", "gsm8k", "gsm8k-plan.pdl"),
            # Requires installation dependencies
            pathlib.Path("examples", "intrinsics", "demo-hallucination.pdl"),
            # Skip RAG examples
            pathlib.Path("examples", "rag", "pdf_index.pdl"),
            pathlib.Path("examples", "rag", "pdf_query.pdl"),
            # (This is glue to Python, it doesn't "run" alone)
            pathlib.Path("examples", "rag", "rag_library1.pdl"),
            # Skip structure decoding example (Jing doesn't have WATSONX API KEY)
            pathlib.Path("examples", "tutorial", "structured_decoding.pdl"),
            # Output result includes trace (and thus timing) for some reason.
            # Investigate why
            pathlib.Path("examples", "react", "react_call.pdl"),  # Very non-deterministic
            pathlib.Path("pdl-live-react", "demos", "error.pdl"),
            pathlib.Path("pdl-live-react", "demos", "demo1.pdl"),
            pathlib.Path("pdl-live-react", "demos", "demo2.pdl"),
            # For now, skip the granite-io examples
            pathlib.Path("examples", "granite-io", "granite_io_hallucinations.pdl"),
            pathlib.Path("examples", "granite-io", "granite_io_openai.pdl"),
            pathlib.Path("examples", "granite-io", "granite_io_thinking.pdl"),
            pathlib.Path("examples", "granite-io", "granite_io_transformers.pdl"),
        ),
    )
)
| 55 | + |
| 56 | + |
@dataclass
class InputsType:
    """Optional inputs handed to a program when the test executes it."""

    # Text installed as sys.stdin before the program runs; None leaves
    # stdin untouched.
    stdin: Optional[str] = None
    # Initial scope passed to the interpreter; None keeps the test's default
    # empty scope.
    scope: Optional[ScopeType] = None
| 61 | + |
| 62 | + |
# Programs that need stdin text and/or an initial scope, keyed by the string
# form of their path (the same form used for the TO_SKIP lookup).
TESTS_WITH_INPUT: dict[str, InputsType] = {
    str(pathlib.Path("examples", "tutorial", "programs", "chatbot.pdl")): InputsType(
        stdin="What is APR?\nyes\n"
    ),
    str(pathlib.Path("examples", "tutorial", "input_stdin.pdl")): InputsType(
        stdin="Hello\n"
    ),
    str(pathlib.Path("examples", "tutorial", "input_stdin_multiline.pdl")): InputsType(
        stdin="Hello\nBye\n"
    ),
    str(pathlib.Path("examples", "input", "input_test1.pdl")): InputsType(
        stdin="Hello\n"
    ),
    str(pathlib.Path("examples", "input", "input_test2.pdl")): InputsType(
        stdin="Hello\n"
    ),
    str(pathlib.Path("examples", "chatbot", "chatbot.pdl")): InputsType(
        stdin="What is APR?\nyes\n"
    ),
    str(pathlib.Path("examples", "demo", "7-chatbot-roles.pdl")): InputsType(
        stdin="What is APR?\nquit\n"
    ),
    str(pathlib.Path("examples", "tutorial", "free_variables.pdl")): InputsType(
        scope=PdlDict({"something": "ABC"})
    ),
}
| 93 | + |
| 94 | + |
# Programs under tests/data/line that are expected to fail while parsing.
EXPECTED_PARSE_ERROR = [
    pathlib.Path("tests", "data", "line", file_name)
    for file_name in (
        "hello.pdl",
        "hello1.pdl",
        "hello4.pdl",
        "hello7.pdl",
        "hello8.pdl",
        "hello10.pdl",
        "hello11.pdl",
        "hello31.pdl",
    )
]
| 105 | + |
# Programs that are expected to raise a (non-parse) error while running.
EXPECTED_RUNTIME_ERROR = [
    pathlib.Path("examples", "callback", "repair_prompt.pdl"),
    pathlib.Path("examples", "tutorial", "type_list.pdl"),
    pathlib.Path("examples", "tutorial", "type_checking.pdl"),
    *(
        pathlib.Path("tests", "data", "line", file_name)
        for file_name in (
            "hello12.pdl",
            "hello13.pdl",
            "hello14.pdl",
            "hello15.pdl",
            "hello16.pdl",
            "hello17.pdl",
            "hello18.pdl",
            "hello19.pdl",
            "hello20.pdl",
            "hello21.pdl",
            "hello22.pdl",
            "hello23.pdl",
            "hello24.pdl",
            "hello25.pdl",
            "hello26.pdl",
            "hello27.pdl",
            "hello28.pdl",
            "hello29.pdl",
            "hello3.pdl",
            "hello30.pdl",
            "hello9.pdl",
        )
    ),
]
| 132 | + |
| 133 | + |
def __write_to_results_file(
    dir_name: pathlib.Path, filename: str, content: str
) -> None:
    """
    Store ``content`` as UTF-8 text in ``dir_name / filename``.

    ``dir_name`` is created first, together with any missing parent
    directories; an already existing file is overwritten.
    """

    dir_name.mkdir(parents=True, exist_ok=True)
    target = dir_name / filename
    target.write_text(content, encoding="utf-8")
| 144 | + |
| 145 | + |
def __find_and_compare_results(
    test_file_name: pathlib.Path, actual_result: str
) -> bool:
    """
    Tell whether ``actual_result`` equals any stored result of a program.

    The candidates are the files matching ``<stem>.*.result`` inside
    ``./tests/results/<parent directory of test_file_name>``. Leading and
    trailing whitespace is ignored on both sides of the comparison. Returns
    False when there is no candidate or none of them matches.
    """

    results_dir = pathlib.Path(".") / "tests" / "results" / test_file_name.parent
    wanted = str(actual_result).strip()
    return any(
        candidate.read_text(encoding="utf-8").strip() == wanted
        for candidate in results_dir.glob(test_file_name.stem + ".*.result")
    )
| 163 | + |
| 164 | + |
def test_valid_programs(capsys: CaptureFixture[str], monkeypatch: MonkeyPatch) -> None:
    """
    Execute the selected PDL programs and check their results.

    For each program that is not in TO_SKIP, the inputs registered in
    TESTS_WITH_INPUT (stdin text and/or an initial scope) are installed, the
    random module is seeded with 11 and the program is run. Its result is
    compared with the stored ``*.result`` files. On a mismatch the actual
    result may be written to disk (see OLLAMA_GHACTIONS_RESULTS and
    UPDATE_RESULTS) and the program is recorded as a wrong result. Parse and
    runtime errors are collected, and the test asserts at the end that no
    unexpected error, unexpected success or wrong result occurred.
    """
    actual_parse_error: set[str] = set()
    actual_runtime_error: set[str] = set()
    wrong_results: dict[str, dict[str, str]] = {}
    # Only used to decide whether a runtime error deserves the extra
    # "<file>: <error>" line in the output.
    known_runtime_failures = {str(p) for p in EXPECTED_RUNTIME_ERROR}

    # NOTE(review): the run is currently restricted to a single program; the
    # full run would be: files = pathlib.Path(".").glob("**/*.pdl")
    files = [
        pathlib.Path("examples", "demo", "4-function.pdl"),
    ]

    for pdl_file_name in files:
        program_key = str(pdl_file_name)
        scope: ScopeType = PdlDict({})
        if program_key in TO_SKIP:
            continue

        inputs = TESTS_WITH_INPUT.get(program_key)
        if inputs is not None:
            if inputs.stdin is not None:
                # Feed the canned answers to programs that read from stdin.
                monkeypatch.setattr("sys.stdin", io.StringIO(inputs.stdin))
            if inputs.scope is not None:
                scope = inputs.scope

        try:
            random.seed(11)
            output = pdl.exec_file(
                pdl_file_name,
                scope=scope,
                output="all",
                config=pdl.InterpreterConfig(batch=0),
            )
            result = output["result"]
            actual = str(result)

            # The return value is discarded: the call is made only so that a
            # failure while dumping the trace is reported as a runtime error.
            block_to_dict(output["trace"], json_compatible=True)
            result_dir_name = (
                pathlib.Path(".") / "tests" / "results" / pdl_file_name.parent
            )

            if __find_and_compare_results(pdl_file_name, actual):
                continue

            if OLLAMA_GHACTIONS_RESULTS:
                print(
                    "-------------------- Updating result from running Ollama on GitHub Actions -------------------- "
                )
                __write_to_results_file(
                    result_dir_name,
                    f"{pdl_file_name.stem}.ollama_ghactions.result",
                    actual,
                )

                # Compare again now that the new file exists: a match means
                # the program passes, anything else is a failure.
                if __find_and_compare_results(pdl_file_name, actual):
                    continue
                wrong_results[program_key] = {
                    "actual": actual,
                }

            if UPDATE_RESULTS:
                __write_to_results_file(
                    result_dir_name,
                    f"{pdl_file_name.stem}.{str(RESULTS_VERSION)}.result",
                    actual,
                )

            wrong_results[program_key] = {
                "actual": actual,
            }
        except PDLParseError:
            actual_parse_error.add(program_key)
        except Exception as exc:
            if program_key not in known_runtime_failures:
                print(f"{pdl_file_name}: {exc}")  # unexpected error: breakpoint
            actual_runtime_error.add(program_key)
            print(exc)

    # Parse errors
    # NOTE(review): the expected set is empty here rather than derived from
    # EXPECTED_PARSE_ERROR; confirm this is intended while `files` is
    # restricted to a single program.
    expected_parse_error: set[str] = set()
    unexpected_parse_error = sorted(actual_parse_error - expected_parse_error)
    assert (
        not unexpected_parse_error
    ), f"Unexpected parse error: {unexpected_parse_error}"

    # Runtime errors
    # NOTE(review): likewise empty rather than derived from
    # EXPECTED_RUNTIME_ERROR.
    expected_runtime_error: set[str] = set()
    unexpected_runtime_error = sorted(actual_runtime_error - expected_runtime_error)
    assert (
        not unexpected_runtime_error
    ), f"Unexpected runtime error: {unexpected_runtime_error}"

    # Unexpected valid
    unexpected_valid = sorted(
        (expected_parse_error - actual_parse_error)
        | (expected_runtime_error - actual_runtime_error)
    )
    assert not unexpected_valid, f"Unexpected valid: {unexpected_valid}"
    # Unexpected results
    assert not wrong_results, f"Wrong results: {wrong_results}"
0 commit comments