|
| 1 | +""" |
| 2 | +ad.py |
| 3 | +----- |
| 4 | +
|
| 5 | +Top-level Python script which orchestrates the Julia AD tests. |
| 6 | +
|
| 7 | +Usage: |
| 8 | +
|
| 9 | + python ad.py setup |
| 10 | + python ad.py run --model <model_key> |
| 11 | + python ad.py html |
| 12 | +""" |
| 13 | + |
| 14 | +import json |
| 15 | +import os |
| 16 | +import subprocess as sp |
| 17 | +import tomllib |
| 18 | +import argparse |
| 19 | +from pathlib import Path |
| 20 | +from warnings import warn |
| 21 | + |
# Common prefix of every Julia invocation made by this script; callers append
# the flags for the specific action (listing keys, running a model, ...).
JULIA_COMMAND = [
    "julia",
    "--color=yes",
    "--project=.",
    "main.jl",
]
| 23 | + |
def run_and_capture(command):
    """Execute ``command`` and return what it wrote to stdout, stripped.

    stdout is captured as text; stderr is not redirected. A non-zero exit
    status raises ``subprocess.CalledProcessError`` (``check=True``).
    """
    completed = sp.run(command, stdout=sp.PIPE, check=True, text=True)
    captured = completed.stdout
    return captured.strip()
| 28 | + |
def append_to_github_output(key, value):
    """Record ``key=<value as JSON>`` in the file named by $GITHUB_OUTPUT.

    The line is appended to that file. When the environment variable is not
    set, a notice followed by the line itself is printed to stdout instead.
    """
    line = "{}={}".format(key, json.dumps(value))
    target = os.environ.get("GITHUB_OUTPUT")

    if target is None:
        print("GITHUB_OUTPUT not set")
        print(line)
        return

    with open(target, "a") as sink:
        sink.write(line + "\n")
| 39 | + |
def setup(_args):
    """Ask Julia for the model and AD type keys and publish both lists.

    Each line printed by Julia is treated as one key. The lists are written
    via ``append_to_github_output`` under ``model_keys`` and ``adtype_keys``.
    ``_args`` is accepted for the sub-command interface but is not used.
    """
    queries = (
        ("model_keys", "--list-model-keys"),
        ("adtype_keys", "--list-adtype-keys"),
    )
    # Run every query before writing anything, so a failing Julia call
    # leaves no partial output behind.
    listings = {
        output_key: run_and_capture([*JULIA_COMMAND, flag]).splitlines()
        for output_key, flag in queries
    }
    for output_key, keys in listings.items():
        append_to_github_output(output_key, keys)
    # TODO: Save the Manifest.toml file or at least a mapping of packages ->
    # versions, see #9
| 47 | + |
def run_ad(args):
    """Run one model against every AD type and publish the outcomes.

    ``args.model`` is the key of the model to run. The AD type keys are read
    as a JSON list from ``$ADTYPE_KEYS``; when that variable is unset a
    warning is issued and Julia is queried for them instead.

    For each AD type, Julia is invoked with ``--run <model> <adtype>`` and the
    last line of its stdout is taken as the result. A non-zero exit status,
    or an invocation that prints nothing at all, is recorded as ``"error"``.
    The resulting ``{adtype: result}`` mapping is printed and then written via
    ``append_to_github_output`` under the key ``results``.
    """
    model_key = args.model

    # Get adtypes
    try:
        adtypes = json.loads(os.environ["ADTYPE_KEYS"])
    except KeyError:
        warn("ADTYPE_KEYS environment variable not set; running Julia to get adtypes")
        adtypes = run_and_capture([*JULIA_COMMAND, "--list-adtype-keys"]).splitlines()

    results = {}

    # Run tests
    for adtype in adtypes:
        print(f"Running {model_key} with {adtype}...")
        try:
            output_lines = run_and_capture(
                [*JULIA_COMMAND, "--run", model_key, adtype]
            ).splitlines()
        except sp.CalledProcessError:
            result = "error"
        else:
            # A successful exit with empty stdout has no "last line"; report
            # it as a failure instead of crashing with IndexError.
            result = output_lines[-1] if output_lines else "error"

        print(f" ... {model_key} with {adtype} ==> {result}")
        results[adtype] = result

    print(results)

    # Save results
    append_to_github_output("results", results)
| 76 | + |
| 77 | + |
# Static top part of html/index.html: document head, introduction and legend.
_INDEX_HTML_PREAMBLE = """<!DOCTYPE html>
<html>
<head><meta charset="utf-8"><title>Turing AD tests</title>
<link rel="stylesheet" type="text/css" href="main.css">
</head>
<body><main>
<h1>Turing AD tests</h1>

<p><a href="https://turinglang.org/docs">Turing.jl documentation</a> | <a href="https://github.com/TuringLang/Turing.jl">Turing.jl GitHub</a> | <a href="https://github.com/TuringLang/ADTests">Source code for these tests</a></p>

<p>This page is intended as a brief overview of how different AD backends
perform on a variety of Turing.jl models.
Note that the inclusion of any AD backend here does not imply an endorsement
from the Turing team; this table is purely for information.
</p>

<ul>
<li>The definitions of the models and AD types below can be found on <a
href="https://github.com/TuringLang/ADTests" target="_blank">GitHub</a>.</li>
<li><b>Numbers</b> indicate the time taken to calculate the gradient of the log
density of the model using the specified AD type, divided by the time taken to
calculate the log density itself (in AD speak, the primal). Basically:
<b>smaller means faster.</b></li>
<li>'<span class="wrong">wrong</span>' means that AD ran but the result was not
correct. If this happens you should be very wary! Note that this is done by
comparing against the result obtained using ForwardDiff, i.e., ForwardDiff is
by definition always 'correct'.</li>
<li>'<span class="error">error</span>' means that AD didn't run.</li>
<li>Some of the 'wrong' or 'error' entries have question marks next to them.
These will link to a GitHub issue or other page that describes the problem.</li>
</ul>

<h2>Results</h2>
"""

# Stylesheet written verbatim to html/main.css.
_MAIN_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@300..700&family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
html {
    font-family: "Fira Sans", sans-serif;
    box-sizing: border-box;
    font-size: 16px;
    line-height: 1.6;
    background-color: #f1f2e3;
}
*, *:before, *:after {
    box-sizing: inherit;
}

body {
    display: flex;
    align-items: center;
    margin: 0px 0px 50px 0px;
}

main {
    margin: auto;
    max-width: 1250px;
}

table#results {
    text-align: right;
    border: 1px solid black;
    border-collapse: collapse;
}

td, th {
    border: 1px solid black;
    padding: 0px 10px;
}

th {
    background-color: #ececec;
    text-align: right;
}

td {
    font-family: "Fira Code", monospace;
}

tr > td:first-child {
    font-family: "Fira Sans", sans-serif;
    font-weight: 700;
    background-color: #ececec;
}

tr > th:first-child {
    font-family: "Fira Sans", sans-serif;
    font-weight: 700;
    background-color: #d1d1d1;
}

span.err, span.error {
    color: #ff0000;
}

span.incorrect, span.wrong {
    color: #ff0000;
    background-color: #ffcccc;
}

a.issue {
    color: #880000;
    text-decoration: none;
}

a.issue:hover {
    background-color: #ffcccc;
    transition: background-color 0.3s ease;
}

a.issue:visited {
    color: #880000;
}
"""


def _render_results_table(results, known_errors):
    """Return the results ``<table>`` markup as one string.

    ``results`` maps model name -> {AD type -> result}; ``known_errors`` maps
    (model name, AD type) -> URL of a page describing a known failure.
    """
    # Function-scope import: at module level the name ``html`` refers to this
    # script's sub-command handler, not the standard-library package.
    from html import escape

    models = sorted(results)
    # Columns are the union of AD types over all models, so an empty result
    # list yields an empty table and a model lacking an entry is still shown.
    adtypes = sorted(
        {adtype for ad_results in results.values() for adtype in ad_results}
    )

    # Table header
    parts = ['<table id="results"><thead>', "<tr>", "<th>Model name \\ AD type</th>"]
    parts.extend(f"<th>{escape(adtype)}</th>" for adtype in adtypes)
    parts.append("</tr></thead><tbody>")

    # Table body
    for model_name in models:
        ad_results = results[model_name]
        parts.append("\n<tr>")
        parts.append(f"<td>{escape(model_name)}</td>")
        for adtype in adtypes:
            ad_result = ad_results.get(adtype)
            if ad_result is None:
                # No result recorded for this (model, AD type) pair.
                parts.append("<td>n/a</td>")
                continue
            text = escape(str(ad_result))
            try:
                float(ad_result)
            except ValueError:
                # Not a number: use the result text as the CSS class too, so
                # 'error' / 'wrong' pick up their styles from main.css.
                cell = f'<span class="{text}">{text}</span>'
                error_url = known_errors.get((model_name, adtype))
                if error_url is not None:
                    cell = (
                        f'<a class="issue" href="{escape(error_url)}" '
                        f'target="_blank">(?)</a> {cell}'
                    )
            else:
                cell = text
            parts.append(f"<td>{cell}</td>")
        parts.append("</tr>")
    parts.append("\n</tbody></table>")
    return "".join(parts)


def html(_args):
    """Generate ``html/index.html`` and ``html/main.css`` from $RESULTS_JSON.

    ``$RESULTS_JSON`` must hold a JSON list of objects of the form
    ``{"model_name": ..., "results": {<adtype>: <result>, ...}}``. Numeric
    results are shown as-is; any other result is wrapped in a ``<span>`` whose
    class is the result text, and is prefixed with a "(?)" link when the
    (model, AD type) pair is registered in ``KNOWN_ERRORS``. A pair with no
    recorded result is shown as "n/a".

    ``_args`` is accepted for the sub-command interface but is not used.

    Raises:
        SystemExit: with code 1 when ``$RESULTS_JSON`` is not set.
    """
    ## Here you can register known errors that have been reported on GitHub /
    ## have otherwise been documented. They will be turned into links in the table.

    ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337"
    ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995"
    KNOWN_ERRORS = {
        ("assume_beta", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
        ("assume_dirichlet", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
        ("assume_lkjcholu", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
        ("assume_normal", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
        ("assume_wishart", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
        ("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS,
        ("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS,
    }

    results_json = os.environ.get("RESULTS_JSON")
    if results_json is None:
        print("RESULTS_JSON not set")
        # Same effect as the interactive ``exit(1)`` helper, without relying
        # on the ``site`` module having installed it.
        raise SystemExit(1)

    print("-------- $RESULTS_JSON --------")
    print(results_json)
    print("------------- END -------------")

    # Turn the list of {"model_name": ..., "results": {...}} entries into a
    # dict of dicts: {model_name: {adtype: result}}.
    #
    # pandas could do this too, but it is deliberately avoided: it would be an
    # extra dependency and df.to_html() is not customisable enough here.
    results = {
        entry["model_name"]: entry["results"] for entry in json.loads(results_json)
    }

    # Render before touching the file system so a malformed entry cannot
    # leave a half-written page behind.
    table = _render_results_table(results, KNOWN_ERRORS)

    # Create the directory if it doesn't exist
    os.makedirs("html", exist_ok=True)
    with open("html/index.html", "w", encoding="utf-8") as f:
        f.write(_INDEX_HTML_PREAMBLE)
        f.write(table)
        f.write("</main></body></html>")

    with open("html/main.css", "w", encoding="utf-8") as f:
        f.write(_MAIN_CSS)
| 283 | + |
| 284 | + |
def parse_arguments(argv=None):
    """Parse the command line and return the resulting namespace.

    Three sub-commands are defined, and one of them must be given:
    ``setup``, ``run --model <model_key>`` and ``html``. The returned
    namespace carries a ``func`` attribute set to the handler of the chosen
    sub-command.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the previous behaviour.
    """
    parser = argparse.ArgumentParser(description="Run AD tests")
    subparsers = parser.add_subparsers(required=True)

    # Setup
    parser_setup = subparsers.add_parser(
        "setup", help="Setup by saving model keys, adtype keys, and Manifest"
    )
    parser_setup.set_defaults(func=setup)

    # Run a given model with all adtypes
    parser_run = subparsers.add_parser(
        "run", help="Run a given model with all adtypes"
    )
    # Required: without it ``args.model`` is None and ends up inside the
    # Julia argv, which subprocess rejects with a TypeError.
    parser_run.add_argument(
        "--model", type=str, required=True, help="Key of the model to run"
    )
    parser_run.set_defaults(func=run_ad)

    # Generate HTML page
    parser_html = subparsers.add_parser("html", help="Generate HTML page")
    parser_html.set_defaults(func=html)

    return parser.parse_args(argv)
| 305 | + |
| 306 | + |
if __name__ == "__main__":
    # Each sub-parser stores its handler as ``func``; call the selected one
    # with the parsed namespace.
    cli_args = parse_arguments()
    cli_args.func(cli_args)
0 commit comments