@@ -113,53 +113,48 @@ def get_model_definition(model_key):
113
113
break
114
114
for submodel in submodels :
115
115
lines = [get_model_definition (submodel ), * lines ]
116
- return "<br>" .join (lines )
117
-
118
-
119
- def html (_args ):
120
- ## Here you can register known errors that have been reported on GitHub /
121
- ## have otherwise been documented. They will be turned into links in the table.
122
-
123
- ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337"
124
- ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995"
125
- MOONCAKE_THREADED = "https://github.com/chalk-lab/Mooncake.jl/issues/570"
126
- ENZYME_DEMO_INCORRECT = "https://github.com/EnzymeAD/Enzyme.jl/issues/2387"
127
- KNOWN_ERRORS = {
128
- ("assume_mvnormal" , "EnzymeForward" ): ENZYME_FWD_BLAS ,
129
- ("assume_wishart" , "EnzymeForward" ): ENZYME_FWD_BLAS ,
130
- ("multithreaded" , "Mooncake" ): MOONCAKE_THREADED ,
131
- ("dot_assume_observe_index" , "EnzymeForward" ): ENZYME_DEMO_INCORRECT ,
132
- ("dot_assume_observe_index" , "EnzymeReverse" ): ENZYME_DEMO_INCORRECT ,
133
- }
116
+ return "\n " .join (lines )
134
117
135
118
119
+ def try_float (value ):
136
120
try :
137
- results = os .environ ["RESULTS_JSON" ]
138
- print ("-------- $RESULTS_JSON --------" )
139
- print (results )
140
- print ("------------- END -------------" )
141
- # results is a list of dicts that looks something like this.
142
- # [
143
- # {"model_name": "model1",
144
- # "results": {
145
- # "AD1": "result1",
146
- # "AD2": "result2"
147
- # }
148
- # },
149
- # {"model_name": "model2",
150
- # "results": {
151
- # "AD1": "result3",
152
- # "AD2": "result4"
153
- # }
154
- # }
155
- # ]
156
- # We do some processing to turn it into a dict of dicts
157
- results = json .loads (results )
158
- results = {entry ["model_name" ]: entry ["results" ] for entry in results }
159
- except KeyError as e :
160
- print ("RESULTS_JSON environment variable not set" )
161
- exit (1 )
121
+ return float (value )
122
+ except ValueError :
123
+ return value
124
+
162
125
126
+ def html (_args ):
127
+ results = os .environ ["RESULTS_JSON" ]
128
+ print ("-------- $RESULTS_JSON --------" )
129
+ print (results )
130
+ print ("------------- END -------------" )
131
+ # results is a list of dicts that looks something like this.
132
+ # [
133
+ # {"model_name": "model1",
134
+ # "results": {
135
+ # "AD1": "result1",
136
+ # "AD2": "result2"
137
+ # }
138
+ # },
139
+ # {"model_name": "model2",
140
+ # "results": {
141
+ # "AD1": "result3",
142
+ # "AD2": "result4"
143
+ # }
144
+ # }
145
+ # ]
146
+ # We do some processing to turn it into a dict of dicts, then dump it
147
+ # to the website
148
+ results = json .loads (results )
149
+ new_data = {}
150
+ for entry in results :
151
+ model_name = entry ["model_name" ]
152
+ results = {k : try_float (v ) for k , v in entry ["results" ].items ()}
153
+ new_data [model_name ] = results
154
+ with open ("web/src/data/adtests.json" , "w" ) as f :
155
+ json .dump (new_data , f , indent = 2 )
156
+
157
+ # Process Manifest
163
158
try :
164
159
manifest = os .environ ["MANIFEST" ]
165
160
print ("-------- $MANIFEST --------" )
@@ -169,209 +164,18 @@ def html(_args):
169
164
except KeyError as e :
170
165
print ("MANIFEST environment variable not set, reading from Manifest.toml" )
171
166
manifest = get_manifest_dict ()
172
-
173
- # You can also process this with pandas. I don't do that here because
174
- # (1) extra dependency
175
- # (2) df.to_html() doesn't have enough customisation for our purposes.
176
- #
177
- # import pandas as pd
178
- # results_flattened = [
179
- # {"model_name": entry["model_name"], **entry["results"]}
180
- # for entry in json.loads(results)
181
- # ]
182
- # df = pd.DataFrame.from_records(results_flattened)
183
-
184
- adtypes = sorted (list (results .values ())[0 ].keys ())
185
- models = sorted (results .keys ())
186
-
187
- # Create the directory if it doesn't exist
188
- os .makedirs ("html" , exist_ok = True )
189
- with open ("html/index.html" , "w" ) as f :
190
- f .write (
191
- """<!DOCTYPE html>
192
- <html>
193
- <head><title>Turing AD tests</title>
194
- <link rel="stylesheet" type="text/css" href="main.css">
195
- </head>
196
- <body><main>
197
- <h1>Turing AD tests</h1>
198
-
199
- <p><a href="https://turinglang.org/docs">Turing.jl documentation</a> | <a href="https://github.com/TuringLang/Turing.jl">Turing.jl GitHub</a> | <a href="https://github.com/TuringLang/ADTests">Source code for these tests</a></p>
200
-
201
- <p>This page is intended as a brief overview of how different AD backends
202
- perform on a variety of Turing.jl models.
203
- Note that the inclusion of any AD backend here does not imply an endorsement
204
- from the Turing team; this table is purely for information.
205
- </p>
206
-
207
- <ul>
208
- <li>The definitions of the models and AD types below can be found on <a
209
- href="https://github.com/TuringLang/ADTests" target="_blank">GitHub</a>.</li>
210
- <li><b>Numbers</b> indicate the time taken to calculate the gradient of the log
211
- density of the model using the specified AD type, divided by the time taken to
212
- calculate the log density itself (in AD speak, the primal). Basically:
213
- <b>smaller means faster.</b></li>
214
- <li>'<span class="wrong">wrong</span>' means that AD ran but the result was not
215
- correct. If this happens you should be very wary! Note that this is done by
216
- comparing against the result obtained using ForwardDiff, i.e., ForwardDiff is
217
- by definition always 'correct'.</li>
218
- <li>'<span class="error">error</span>' means that AD didn't run.</li>
219
- <li>Some of the 'wrong' or 'error' entries have question marks next to them.
220
- These will link to a GitHub issue or other page that describes the problem.
221
- </ul>
222
-
223
- <h2>Results</h2>
224
-
225
- <p>(New: You can also hover over the model names to see their definitions.)</p>
226
- """ )
227
-
228
- # Table header
229
- f .write ('<table id="results"><thead>' )
230
- f .write ("<tr>" )
231
- f .write ('<th class="right">Model name \\ AD type</th>' )
232
- for adtype in adtypes :
233
- f .write (f'<th class="right">{ adtype } </th>' )
234
- f .write ("</tr></thead><tbody>" )
235
- # Table body
236
- for model_name in models :
237
- ad_results = results [model_name ]
238
- f .write ("\n <tr>" )
239
- f .write (f'<td>{ model_name } <div class="model-definition"><pre>{ get_model_definition (model_name )} </pre></div></td>' )
240
- for adtype in adtypes :
241
- ad_result = ad_results [adtype ]
242
- try :
243
- float (ad_result )
244
- f .write (f'<td>{ ad_result } </td>' )
245
- except ValueError :
246
- # Not a float, embed the class into the html
247
- error_url = KNOWN_ERRORS .get ((model_name , adtype ), None )
248
- span = f'<span class="{ ad_result } ">{ ad_result } '
249
- if error_url is not None :
250
- span = f'<a class="issue" href="{ error_url } " target="_blank">(?)</a> { span } '
251
- f .write (f'<td>{ span } </td>' )
252
- f .write ("</tr>" )
253
- f .write ("\n </tbody></table>" )
254
- f .write ("<h2>Manifest</h2><p>The tests above were run with the following package versions:</p>" )
255
- f .write ("<table id='manifest'><thead><tr><th>Package</th><th>Version</th>" )
256
- for package , version in manifest .items ():
257
- version_string = "" if version is None else f"v{ version } "
258
- f .write (f"<tr><td>{ package } </td><td>{ version_string } </td></tr>" )
259
- f .write ("</table>" )
260
- f .write ("</main></body></html>" )
261
-
262
- with open ("html/main.css" , "w" ) as f :
263
- f .write (
264
- """
265
- @import url('https://fonts.googleapis.com/css2?family=Fira+Code:[email protected] &family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
266
- html {
267
- font-family: "Fira Sans", sans-serif;
268
- box-sizing: border-box;
269
- font-size: 16px;
270
- line-height: 1.6;
271
- background-color: #f1f2e3;
272
- }
273
- *, *:before, *:after {
274
- box-sizing: inherit;
275
- }
276
-
277
- body {
278
- display: flex;
279
- align-items: center;
280
- margin: 0px 0px 50px 0px;
281
- }
282
-
283
- main {
284
- margin: auto;
285
- max-width: 1250px;
286
- }
287
-
288
- table {
289
- border: 1px solid black;
290
- border-collapse: collapse;
291
- }
292
-
293
- table#results {
294
- text-align: right;
295
- }
296
-
297
- td, th {
298
- border: 1px solid black;
299
- padding: 0px 10px;
300
- white-space: nowrap;
301
- }
302
-
303
- th {
304
- background-color: #ececec;
305
- text-align: left;
306
- }
307
-
308
- th.right {
309
- text-align: right;
310
- }
311
-
312
- td {
313
- font-family: "Fira Code", monospace;
314
- }
315
-
316
- tr > td:first-child {
317
- font-family: "Fira Sans", sans-serif;
318
- font-weight: 700;
319
- background-color: #ececec;
320
- position: relative;
321
- }
322
-
323
- tr > td:first-child:hover {
324
- background-color: #f6f6f6;
325
- }
326
-
327
- tr > td:first-child:hover > div.model-definition {
328
- display: block;
329
- }
330
-
331
- tr > th:first-child {
332
- font-family: "Fira Sans", sans-serif;
333
- font-weight: 700;
334
- background-color: #d1d1d1;
335
- }
336
-
337
- span.err, span.error {
338
- color: #ff0000;
339
- }
340
-
341
- span.incorrect, span.wrong {
342
- color: #ff0000;
343
- background-color: #ffcccc;
344
- }
345
-
346
- a.issue {
347
- color: #880000;
348
- text-decoration: none;
349
- }
350
-
351
- a.issue:hover {
352
- background-color: #ffcccc;
353
- transition: background-color 0.3s ease;
354
- }
355
-
356
- a.issue:visited {
357
- color: #880000;
358
- }
359
-
360
- div.model-definition {
361
- background-color: #f6f6f6;
362
- border: 1px solid black;
363
- border-radius: 5px;
364
- padding: 0 10px;
365
- z-index: 5;
366
- font-size: 0.9em;
367
- text-align: left;
368
- font-weight: normal;
369
- position: absolute;
370
- left: 100%;
371
- top: 0;
372
- display: none;
373
- }
374
- """ )
167
+ with open ("web/src/data/manifest.json" , "w" ) as f :
168
+ json .dump (manifest , f , indent = 2 )
169
+
170
+ # Process model definitions
171
+ model_keys = list (new_data .keys ())
172
+ # technically we can also get it this way
173
+ # model_keys = run_and_capture([*JULIA_COMMAND, "--list-model-keys"]).splitlines()
174
+ model_definitions = {}
175
+ for model_key in model_keys :
176
+ model_definitions [model_key ] = get_model_definition (model_key )
177
+ with open ("web/src/data/model_definitions.json" , "w" ) as f :
178
+ json .dump (model_definitions , f , indent = 2 )
375
179
376
180
377
181
def parse_arguments ():
0 commit comments