Skip to content

Commit 9cd202c

Browse files
authored
Merge pull request #12 from TuringLang/py/python-rewrite
Rewrite Bash in Python
2 parents 0a8cd55 + 5c9e60d commit 9cd202c

File tree

4 files changed

+323
-299
lines changed

4 files changed

+323
-299
lines changed

.github/workflows/generate_website.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ on:
44
push:
55
branches:
66
- main
7+
pull_request:
78
workflow_dispatch:
89

910
permissions:
@@ -32,10 +33,13 @@ jobs:
3233

3334
- uses: julia-actions/julia-buildpkg@v1
3435

35-
# This step sends model_keys and adtype_keys to GITHUB_OUTPUT
36+
- uses: astral-sh/setup-uv@v5
37+
with:
38+
python-version: "3.13"
39+
3640
- name: Setup keys
3741
id: keys
38-
run: ./ad.sh setup
42+
run: uv run ad.py setup
3943

4044
run-models:
4145
runs-on: ubuntu-latest
@@ -61,9 +65,13 @@ jobs:
6165

6266
- uses: julia-actions/julia-buildpkg@v1
6367

64-
- name: Run AD
68+
- uses: astral-sh/setup-uv@v5
69+
with:
70+
python-version: "3.13"
71+
72+
- name: Run given model with all adtypes
6573
id: run
66-
run: ./ad.sh run-model ${{ matrix.model }}
74+
run: uv run ad.py run --model ${{ matrix.model }}
6775
env:
6876
ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }}
6977

@@ -77,6 +85,7 @@ jobs:
7785
7886
collect-results:
7987
runs-on: ubuntu-latest
88+
if: github.event_name != 'pull_request'
8089
needs: run-models
8190

8291
steps:
@@ -86,9 +95,7 @@ jobs:
8695
with:
8796
python-version: "3.13"
8897

89-
- run: |
90-
uv python install
91-
uv run collate.py
98+
- run: uv run ad.py html
9299
env:
93100
RESULTS_JSON: ${{ needs.run-models.outputs.json }}
94101

ad.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
"""
2+
ad.py
3+
-----
4+
5+
Top-level Python script which orchestrates the Julia AD tests.
6+
7+
Usage:
8+
9+
python ad.py setup
10+
python ad.py run --model <model_key>
11+
python ad.py html
12+
"""
13+
14+
import json
15+
import os
16+
import subprocess as sp
17+
import tomllib
18+
import argparse
19+
from pathlib import Path
20+
from warnings import warn
21+
22+
JULIA_COMMAND = ["julia", "--color=yes", "--project=.", "main.jl"]
23+
24+
def run_and_capture(command):
25+
"""Run a command and capture its output."""
26+
result = sp.run(command, text=True, check=True, stdout=sp.PIPE)
27+
return result.stdout.strip()
28+
29+
def append_to_github_output(key, value):
30+
"""Append a key-value pair to the file specified by $GITHUB_OUTPUT."""
31+
pair = f"{key}={json.dumps(value)}"
32+
try:
33+
fname = os.environ["GITHUB_OUTPUT"]
34+
with open(fname, "a") as f:
35+
print(pair, file=f)
36+
except KeyError:
37+
print(f"GITHUB_OUTPUT not set")
38+
print(pair)
39+
40+
def setup(_args):
41+
models = run_and_capture([*JULIA_COMMAND, "--list-model-keys"]).splitlines()
42+
adtypes = run_and_capture([*JULIA_COMMAND, "--list-adtype-keys"]).splitlines()
43+
append_to_github_output("model_keys", models)
44+
append_to_github_output("adtype_keys", adtypes)
45+
# TODO: Save the Manifest.toml file or at least a mapping of packages ->
46+
# versions, see #9
47+
48+
def run_ad(args):
49+
model_key = args.model
50+
51+
# Get adtypes
52+
try:
53+
adtypes = json.loads(os.environ["ADTYPE_KEYS"])
54+
except KeyError:
55+
warn("ADTYPE_KEYS environment variable not set; running Julia to get adtypes")
56+
adtypes = run_and_capture([*JULIA_COMMAND, "--list-adtype-keys"]).splitlines()
57+
58+
results = {}
59+
60+
# Run tests
61+
for adtype in adtypes:
62+
print(f"Running {model_key} with {adtype}...")
63+
try:
64+
output = run_and_capture([*JULIA_COMMAND, "--run", model_key, adtype])
65+
result = output.splitlines()[-1]
66+
except sp.CalledProcessError as e:
67+
result = "error"
68+
69+
print(f" ... {model_key} with {adtype} ==> {result}")
70+
results[adtype] = result
71+
72+
print(results)
73+
74+
# Save results
75+
append_to_github_output("results", results)
76+
77+
78+
def html(_args):
79+
## Here you can register known errors that have been reported on GitHub /
80+
## have otherwise been documented. They will be turned into links in the table.
81+
82+
ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337"
83+
ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995"
84+
KNOWN_ERRORS = {
85+
("assume_beta", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
86+
("assume_dirichlet", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
87+
("assume_lkjcholu", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
88+
("assume_normal", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
89+
("assume_wishart", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
90+
("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS,
91+
("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS,
92+
}
93+
94+
results = os.environ.get("RESULTS_JSON", None)
95+
96+
if results is None:
97+
print("RESULTS_JSON not set")
98+
exit(1)
99+
else:
100+
print("-------- $RESULTS_JSON --------")
101+
print(results)
102+
print("------------- END -------------")
103+
# results is a list of dicts that looks something like this.
104+
# [
105+
# {"model_name": "model1",
106+
# "results": {
107+
# "AD1": "result1",
108+
# "AD2": "result2"
109+
# }
110+
# },
111+
# {"model_name": "model2",
112+
# "results": {
113+
# "AD1": "result3",
114+
# "AD2": "result4"
115+
# }
116+
# }
117+
# ]
118+
# We do some processing to turn it into a dict of dicts
119+
results = json.loads(results)
120+
results = {entry["model_name"]: entry["results"] for entry in results}
121+
122+
# You can also process this with pandas. I don't do that here because
123+
# (1) extra dependency
124+
# (2) df.to_html() doesn't have enough customisation for our purposes.
125+
#
126+
# import pandas as pd
127+
# results_flattened = [
128+
# {"model_name": entry["model_name"], **entry["results"]}
129+
# for entry in json.loads(results)
130+
# ]
131+
# df = pd.DataFrame.from_records(results_flattened)
132+
133+
adtypes = sorted(list(results.values())[0].keys())
134+
models = sorted(results.keys())
135+
136+
# Create the directory if it doesn't exist
137+
os.makedirs("html", exist_ok=True)
138+
with open("html/index.html", "w") as f:
139+
f.write(
140+
"""<!DOCTYPE html>
141+
<html>
142+
<head><title>Turing AD tests</title>
143+
<link rel="stylesheet" type="text/css" href="main.css">
144+
</head>
145+
<body><main>
146+
<h1>Turing AD tests</h1>
147+
148+
<p><a href="https://turinglang.org/docs">Turing.jl documentation</a> | <a href="https://github.com/TuringLang/Turing.jl">Turing.jl GitHub</a> | <a href="https://github.com/TuringLang/ADTests">Source code for these tests</a></p>
149+
150+
<p>This page is intended as a brief overview of how different AD backends
151+
perform on a variety of Turing.jl models.
152+
Note that the inclusion of any AD backend here does not imply an endorsement
153+
from the Turing team; this table is purely for information.
154+
</p>
155+
156+
<ul>
157+
<li>The definitions of the models and AD types below can be found on <a
158+
href="https://github.com/TuringLang/ADTests" target="_blank">GitHub</a>.</li>
159+
<li><b>Numbers</b> indicate the time taken to calculate the gradient of the log
160+
density of the model using the specified AD type, divided by the time taken to
161+
calculate the log density itself (in AD speak, the primal). Basically:
162+
<b>smaller means faster.</b></li>
163+
<li>'<span class="wrong">wrong</span>' means that AD ran but the result was not
164+
correct. If this happens you should be very wary! Note that this is done by
165+
comparing against the result obtained using ForwardDiff, i.e., ForwardDiff is
166+
by definition always 'correct'.</li>
167+
<li>'<span class="error">error</span>' means that AD didn't run.</li>
168+
<li>Some of the 'wrong' or 'error' entries have question marks next to them.
169+
These will link to a GitHub issue or other page that describes the problem.
170+
</ul>
171+
172+
<h2>Results</h2>
173+
""")
174+
175+
# Table header
176+
f.write('<table id="results"><thead>')
177+
f.write("<tr>")
178+
f.write("<th>Model name \\ AD type</th>")
179+
for adtype in adtypes:
180+
f.write(f"<th>{adtype}</th>")
181+
f.write("</tr></thead><tbody>")
182+
# Table body
183+
for model_name in models:
184+
ad_results = results[model_name]
185+
f.write("\n<tr>")
186+
f.write(f"<td>{model_name}</td>")
187+
for adtype in adtypes:
188+
ad_result = ad_results[adtype]
189+
try:
190+
float(ad_result)
191+
f.write(f'<td>{ad_result}</td>')
192+
except ValueError:
193+
# Not a float, embed the class into the html
194+
error_url = KNOWN_ERRORS.get((model_name, adtype), None)
195+
span = f'<span class="{ad_result}">{ad_result}'
196+
if error_url is not None:
197+
span = f'<a class="issue" href="{error_url}" target="_blank">(?)</a> {span}'
198+
f.write(f'<td>{span}</td>')
199+
f.write("</tr>")
200+
f.write("\n</tbody></table></main></body></html>")
201+
202+
with open("html/main.css", "w") as f:
203+
f.write(
204+
"""
205+
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:[email protected]&family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
206+
html {
207+
font-family: "Fira Sans", sans-serif;
208+
box-sizing: border-box;
209+
font-size: 16px;
210+
line-height: 1.6;
211+
background-color: #f1f2e3;
212+
}
213+
*, *:before, *:after {
214+
box-sizing: inherit;
215+
}
216+
217+
body {
218+
display: flex;
219+
align-items: center;
220+
margin: 0px 0px 50px 0px;
221+
}
222+
223+
main {
224+
margin: auto;
225+
max-width: 1250px;
226+
}
227+
228+
table#results {
229+
text-align: right;
230+
border: 1px solid black;
231+
border-collapse: collapse;
232+
}
233+
234+
td, th {
235+
border: 1px solid black;
236+
padding: 0px 10px;
237+
}
238+
239+
th {
240+
background-color: #ececec;
241+
text-align: right;
242+
}
243+
244+
td {
245+
font-family: "Fira Code", monospace;
246+
}
247+
248+
tr > td:first-child {
249+
font-family: "Fira Sans", sans-serif;
250+
font-weight: 700;
251+
background-color: #ececec;
252+
}
253+
254+
tr > th:first-child {
255+
font-family: "Fira Sans", sans-serif;
256+
font-weight: 700;
257+
background-color: #d1d1d1;
258+
}
259+
260+
span.err, span.error {
261+
color: #ff0000;
262+
}
263+
264+
span.incorrect, span.wrong {
265+
color: #ff0000;
266+
background-color: #ffcccc;
267+
}
268+
269+
a.issue {
270+
color: #880000;
271+
text-decoration: none;
272+
}
273+
274+
a.issue:hover {
275+
background-color: #ffcccc;
276+
transition: background-color 0.3s ease;
277+
}
278+
279+
a.issue:visited {
280+
color: #880000;
281+
}
282+
""")
283+
284+
285+
def parse_arguments():
286+
parser = argparse.ArgumentParser(description="Run AD tests")
287+
subparsers = parser.add_subparsers(required=True)
288+
289+
# Setup
290+
parser_setup = subparsers.add_parser("setup", help="Setup by saving model keys, adtype keys, and Manifest")
291+
parser_setup.set_defaults(func=setup)
292+
293+
# Run a given model with all adtypes
294+
parser_run = subparsers.add_parser("run", help="Run a given model with all adtypes")
295+
parser_run.add_argument(
296+
"--model", type=str, help="Key of the model to run"
297+
)
298+
parser_run.set_defaults(func=run_ad)
299+
300+
# Generate HTML page
301+
parser_html = subparsers.add_parser("html", help="Generate HTML page")
302+
parser_html.set_defaults(func=html)
303+
304+
return parser.parse_args()
305+
306+
307+
if __name__ == "__main__":
308+
args = parse_arguments()
309+
args.func(args)

0 commit comments

Comments
 (0)