Skip to content

Commit d70f6df

Browse files
committed
Refactor to use Svelte
1 parent a28f271 commit d70f6df

20 files changed

+1409
-254
lines changed

.github/workflows/generate_website.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,22 @@ jobs:
100100
RESULTS_JSON: ${{ needs.run-models.outputs.json }}
101101
MANIFEST: ${{ needs.setup-keys.outputs.manifest }}
102102

103+
- name: Set up pnpm
104+
uses: pnpm/action-setup@v4
105+
with:
106+
version: 10
107+
108+
- name: Install dependencies
109+
run: pnpm install
110+
working-directory: web
111+
112+
- name: Build website
113+
run: pnpm build --base /ADTests
114+
working-directory: web
115+
103116
- name: Upload results
104117
uses: peaceiris/actions-gh-pages@v4
105118
with:
106119
github_token: ${{ secrets.GITHUB_TOKEN }}
107-
publish_dir: ./html
120+
publish_dir: ./web/dist
108121
destination_dir: ${{ github.event_name == 'pull_request' && 'pr' || '' }}

ad.py

Lines changed: 50 additions & 246 deletions
Original file line numberDiff line numberDiff line change
@@ -113,53 +113,48 @@ def get_model_definition(model_key):
113113
break
114114
for submodel in submodels:
115115
lines = [get_model_definition(submodel), *lines]
116-
return "<br>".join(lines)
117-
118-
119-
def html(_args):
120-
## Here you can register known errors that have been reported on GitHub /
121-
## have otherwise been documented. They will be turned into links in the table.
122-
123-
ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337"
124-
ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995"
125-
MOONCAKE_THREADED = "https://github.com/chalk-lab/Mooncake.jl/issues/570"
126-
ENZYME_DEMO_INCORRECT = "https://github.com/EnzymeAD/Enzyme.jl/issues/2387"
127-
KNOWN_ERRORS = {
128-
("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS,
129-
("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS,
130-
("multithreaded", "Mooncake"): MOONCAKE_THREADED,
131-
("dot_assume_observe_index", "EnzymeForward"): ENZYME_DEMO_INCORRECT,
132-
("dot_assume_observe_index", "EnzymeReverse"): ENZYME_DEMO_INCORRECT,
133-
}
116+
return "\n".join(lines)
134117

135118

119+
def try_float(value):
136120
try:
137-
results = os.environ["RESULTS_JSON"]
138-
print("-------- $RESULTS_JSON --------")
139-
print(results)
140-
print("------------- END -------------")
141-
# results is a list of dicts that looks something like this.
142-
# [
143-
# {"model_name": "model1",
144-
# "results": {
145-
# "AD1": "result1",
146-
# "AD2": "result2"
147-
# }
148-
# },
149-
# {"model_name": "model2",
150-
# "results": {
151-
# "AD1": "result3",
152-
# "AD2": "result4"
153-
# }
154-
# }
155-
# ]
156-
# We do some processing to turn it into a dict of dicts
157-
results = json.loads(results)
158-
results = {entry["model_name"]: entry["results"] for entry in results}
159-
except KeyError as e:
160-
print("RESULTS_JSON environment variable not set")
161-
exit(1)
121+
return float(value)
122+
except ValueError:
123+
return value
124+
162125

126+
def html(_args):
127+
results = os.environ["RESULTS_JSON"]
128+
print("-------- $RESULTS_JSON --------")
129+
print(results)
130+
print("------------- END -------------")
131+
# results is a list of dicts that looks something like this.
132+
# [
133+
# {"model_name": "model1",
134+
# "results": {
135+
# "AD1": "result1",
136+
# "AD2": "result2"
137+
# }
138+
# },
139+
# {"model_name": "model2",
140+
# "results": {
141+
# "AD1": "result3",
142+
# "AD2": "result4"
143+
# }
144+
# }
145+
# ]
146+
# We do some processing to turn it into a dict of dicts, then dump it
147+
# to the website
148+
results = json.loads(results)
149+
new_data = {}
150+
for entry in results:
151+
model_name = entry["model_name"]
152+
results = {k: try_float(v) for k, v in entry["results"].items()}
153+
new_data[model_name] = results
154+
with open("web/src/data/adtests.json", "w") as f:
155+
json.dump(new_data, f, indent=2)
156+
157+
# Process Manifest
163158
try:
164159
manifest = os.environ["MANIFEST"]
165160
print("-------- $MANIFEST --------")
@@ -169,209 +164,18 @@ def html(_args):
169164
except KeyError as e:
170165
print("MANIFEST environment variable not set, reading from Manifest.toml")
171166
manifest = get_manifest_dict()
172-
173-
# You can also process this with pandas. I don't do that here because
174-
# (1) extra dependency
175-
# (2) df.to_html() doesn't have enough customisation for our purposes.
176-
#
177-
# import pandas as pd
178-
# results_flattened = [
179-
# {"model_name": entry["model_name"], **entry["results"]}
180-
# for entry in json.loads(results)
181-
# ]
182-
# df = pd.DataFrame.from_records(results_flattened)
183-
184-
adtypes = sorted(list(results.values())[0].keys())
185-
models = sorted(results.keys())
186-
187-
# Create the directory if it doesn't exist
188-
os.makedirs("html", exist_ok=True)
189-
with open("html/index.html", "w") as f:
190-
f.write(
191-
"""<!DOCTYPE html>
192-
<html>
193-
<head><title>Turing AD tests</title>
194-
<link rel="stylesheet" type="text/css" href="main.css">
195-
</head>
196-
<body><main>
197-
<h1>Turing AD tests</h1>
198-
199-
<p><a href="https://turinglang.org/docs">Turing.jl documentation</a> | <a href="https://github.com/TuringLang/Turing.jl">Turing.jl GitHub</a> | <a href="https://github.com/TuringLang/ADTests">Source code for these tests</a></p>
200-
201-
<p>This page is intended as a brief overview of how different AD backends
202-
perform on a variety of Turing.jl models.
203-
Note that the inclusion of any AD backend here does not imply an endorsement
204-
from the Turing team; this table is purely for information.
205-
</p>
206-
207-
<ul>
208-
<li>The definitions of the models and AD types below can be found on <a
209-
href="https://github.com/TuringLang/ADTests" target="_blank">GitHub</a>.</li>
210-
<li><b>Numbers</b> indicate the time taken to calculate the gradient of the log
211-
density of the model using the specified AD type, divided by the time taken to
212-
calculate the log density itself (in AD speak, the primal). Basically:
213-
<b>smaller means faster.</b></li>
214-
<li>'<span class="wrong">wrong</span>' means that AD ran but the result was not
215-
correct. If this happens you should be very wary! Note that this is done by
216-
comparing against the result obtained using ForwardDiff, i.e., ForwardDiff is
217-
by definition always 'correct'.</li>
218-
<li>'<span class="error">error</span>' means that AD didn't run.</li>
219-
<li>Some of the 'wrong' or 'error' entries have question marks next to them.
220-
These will link to a GitHub issue or other page that describes the problem.
221-
</ul>
222-
223-
<h2>Results</h2>
224-
225-
<p>(New: You can also hover over the model names to see their definitions.)</p>
226-
""")
227-
228-
# Table header
229-
f.write('<table id="results"><thead>')
230-
f.write("<tr>")
231-
f.write('<th class="right">Model name \\ AD type</th>')
232-
for adtype in adtypes:
233-
f.write(f'<th class="right">{adtype}</th>')
234-
f.write("</tr></thead><tbody>")
235-
# Table body
236-
for model_name in models:
237-
ad_results = results[model_name]
238-
f.write("\n<tr>")
239-
f.write(f'<td>{model_name}<div class="model-definition"><pre>{get_model_definition(model_name)}</pre></div></td>')
240-
for adtype in adtypes:
241-
ad_result = ad_results[adtype]
242-
try:
243-
float(ad_result)
244-
f.write(f'<td>{ad_result}</td>')
245-
except ValueError:
246-
# Not a float, embed the class into the html
247-
error_url = KNOWN_ERRORS.get((model_name, adtype), None)
248-
span = f'<span class="{ad_result}">{ad_result}'
249-
if error_url is not None:
250-
span = f'<a class="issue" href="{error_url}" target="_blank">(?)</a> {span}'
251-
f.write(f'<td>{span}</td>')
252-
f.write("</tr>")
253-
f.write("\n</tbody></table>")
254-
f.write("<h2>Manifest</h2><p>The tests above were run with the following package versions:</p>")
255-
f.write("<table id='manifest'><thead><tr><th>Package</th><th>Version</th>")
256-
for package, version in manifest.items():
257-
version_string = "" if version is None else f"v{version}"
258-
f.write(f"<tr><td>{package}</td><td>{version_string}</td></tr>")
259-
f.write("</table>")
260-
f.write("</main></body></html>")
261-
262-
with open("html/main.css", "w") as f:
263-
f.write(
264-
"""
265-
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:[email protected]&family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
266-
html {
267-
font-family: "Fira Sans", sans-serif;
268-
box-sizing: border-box;
269-
font-size: 16px;
270-
line-height: 1.6;
271-
background-color: #f1f2e3;
272-
}
273-
*, *:before, *:after {
274-
box-sizing: inherit;
275-
}
276-
277-
body {
278-
display: flex;
279-
align-items: center;
280-
margin: 0px 0px 50px 0px;
281-
}
282-
283-
main {
284-
margin: auto;
285-
max-width: 1250px;
286-
}
287-
288-
table {
289-
border: 1px solid black;
290-
border-collapse: collapse;
291-
}
292-
293-
table#results {
294-
text-align: right;
295-
}
296-
297-
td, th {
298-
border: 1px solid black;
299-
padding: 0px 10px;
300-
white-space: nowrap;
301-
}
302-
303-
th {
304-
background-color: #ececec;
305-
text-align: left;
306-
}
307-
308-
th.right {
309-
text-align: right;
310-
}
311-
312-
td {
313-
font-family: "Fira Code", monospace;
314-
}
315-
316-
tr > td:first-child {
317-
font-family: "Fira Sans", sans-serif;
318-
font-weight: 700;
319-
background-color: #ececec;
320-
position: relative;
321-
}
322-
323-
tr > td:first-child:hover {
324-
background-color: #f6f6f6;
325-
}
326-
327-
tr > td:first-child:hover > div.model-definition {
328-
display: block;
329-
}
330-
331-
tr > th:first-child {
332-
font-family: "Fira Sans", sans-serif;
333-
font-weight: 700;
334-
background-color: #d1d1d1;
335-
}
336-
337-
span.err, span.error {
338-
color: #ff0000;
339-
}
340-
341-
span.incorrect, span.wrong {
342-
color: #ff0000;
343-
background-color: #ffcccc;
344-
}
345-
346-
a.issue {
347-
color: #880000;
348-
text-decoration: none;
349-
}
350-
351-
a.issue:hover {
352-
background-color: #ffcccc;
353-
transition: background-color 0.3s ease;
354-
}
355-
356-
a.issue:visited {
357-
color: #880000;
358-
}
359-
360-
div.model-definition {
361-
background-color: #f6f6f6;
362-
border: 1px solid black;
363-
border-radius: 5px;
364-
padding: 0 10px;
365-
z-index: 5;
366-
font-size: 0.9em;
367-
text-align: left;
368-
font-weight: normal;
369-
position: absolute;
370-
left: 100%;
371-
top: 0;
372-
display: none;
373-
}
374-
""")
167+
with open("web/src/data/manifest.json", "w") as f:
168+
json.dump(manifest, f, indent=2)
169+
170+
# Process model definitions
171+
model_keys = list(new_data.keys())
172+
# technically we can also get it this way
173+
# model_keys = run_and_capture([*JULIA_COMMAND, "--list-model-keys"]).splitlines()
174+
model_definitions = {}
175+
for model_key in model_keys:
176+
model_definitions[model_key] = get_model_definition(model_key)
177+
with open("web/src/data/model_definitions.json", "w") as f:
178+
json.dump(model_definitions, f, indent=2)
375179

376180

377181
def parse_arguments():

main.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ import Zygote
1313

1414
# AD backends to test.
1515
ADTYPES = Dict(
16-
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
16+
# "FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
1717
"ForwardDiff" => AutoForwardDiff(),
18-
"ReverseDiff" => AutoReverseDiff(; compile=false),
19-
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
20-
"Mooncake" => AutoMooncake(; config=nothing),
21-
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
22-
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
23-
"Zygote" => AutoZygote(),
18+
# "ReverseDiff" => AutoReverseDiff(; compile=false),
19+
# "ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
20+
# "Mooncake" => AutoMooncake(; config=nothing),
21+
# "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
22+
# "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
23+
# "Zygote" => AutoZygote(),
2424
)
2525

2626
# Models to test.

web/.gitignore

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Logs
2+
logs
3+
*.log
4+
npm-debug.log*
5+
yarn-debug.log*
6+
yarn-error.log*
7+
pnpm-debug.log*
8+
lerna-debug.log*
9+
10+
node_modules
11+
dist
12+
dist-ssr
13+
*.local
14+
15+
# Editor directories and files
16+
.vscode/*
17+
!.vscode/extensions.json
18+
.idea
19+
.DS_Store
20+
*.suo
21+
*.ntvs*
22+
*.njsproj
23+
*.sln
24+
*.sw?
25+
26+
src/data/*.json

web/.vscode/extensions.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"recommendations": ["svelte.svelte-vscode"]
3+
}

0 commit comments

Comments
 (0)