Skip to content

Commit e9a6ea7

Browse files
authored
Merge pull request #16 from TuringLang/py/hover
Add hoverable model source code
2 parents d16f464 + 780a7d0 commit e9a6ea7

File tree

3 files changed

+109
-42
lines changed

3 files changed

+109
-42
lines changed

ad.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,40 @@ def run_ad(args):
9393
append_to_github_output("results", results)
9494

9595

96+
def get_model_definition(model_key):
97+
"""Get the model definition from the Julia script."""
98+
lines = []
99+
submodels = []
100+
record = False
101+
with open("models.jl", "r") as file:
102+
for line in file:
103+
line = line.rstrip()
104+
if line.startswith(f"@model function {model_key}"):
105+
record = True
106+
if record:
107+
lines.append(line)
108+
109+
if "to_submodel" in line:
110+
submodel_name = line.split("to_submodel(")[1].split("(")[0]
111+
submodels.append(submodel_name)
112+
if line == "end":
113+
break
114+
for submodel in submodels:
115+
lines = [get_model_definition(submodel), *lines]
116+
return "<br>".join(lines)
117+
118+
96119
def html(_args):
97120
## Here you can register known errors that have been reported on GitHub /
98121
## have otherwise been documented. They will be turned into links in the table.
99122

100123
ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337"
101124
ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995"
125+
MOONCAKE_THREADED = "https://github.com/chalk-lab/Mooncake.jl/issues/570"
102126
KNOWN_ERRORS = {
103-
("assume_beta", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
104-
("assume_dirichlet", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
105-
("assume_lkjcholu", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
106-
("assume_normal", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
107-
("assume_wishart", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM,
108127
("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS,
109128
("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS,
129+
("multithreaded", "Mooncake"): MOONCAKE_THREADED,
110130
}
111131

112132

@@ -198,6 +218,8 @@ def html(_args):
198218
</ul>
199219
200220
<h2>Results</h2>
221+
222+
<p>(New: You can also hover over the model names to see their definitions.)</p>
201223
""")
202224

203225
# Table header
@@ -211,7 +233,7 @@ def html(_args):
211233
for model_name in models:
212234
ad_results = results[model_name]
213235
f.write("\n<tr>")
214-
f.write(f"<td>{model_name}</td>")
236+
f.write(f'<td>{model_name}<div class="model-definition"><pre>{get_model_definition(model_name)}</pre></div></td>')
215237
for adtype in adtypes:
216238
ad_result = ad_results[adtype]
217239
try:
@@ -272,6 +294,7 @@ def html(_args):
272294
td, th {
273295
border: 1px solid black;
274296
padding: 0px 10px;
297+
white-space: nowrap;
275298
}
276299
277300
th {
@@ -291,6 +314,7 @@ def html(_args):
291314
font-family: "Fira Sans", sans-serif;
292315
font-weight: 700;
293316
background-color: #ececec;
317+
position: relative;
294318
}
295319
296320
tr > th:first-child {
@@ -321,6 +345,29 @@ def html(_args):
321345
a.issue:visited {
322346
color: #880000;
323347
}
348+
349+
div.model-definition {
350+
background-color: #f6f6f6;
351+
border: 1px solid black;
352+
border-radius: 5px;
353+
padding: 0 10px;
354+
z-index: 5;
355+
font-size: 0.9em;
356+
text-align: left;
357+
font-weight: normal;
358+
position: absolute;
359+
left: 100%;
360+
top: 0;
361+
display: none;
362+
}
363+
364+
td:hover {
365+
background-color: #f6f6f6;
366+
}
367+
368+
td:hover > div.model-definition {
369+
display: block;
370+
}
324371
""")
325372

326373

main.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
4747
# If reached here - nothing went wrong
4848
@printf("%.3f", result.time_vs_primal)
4949
catch e
50+
@show e
5051
if e isa ADIncorrectException
5152
# First check for completely incorrect ones
5253
for (a, b) in zip(e.grad_expected, e.grad_actual)

models.jl

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Models
22

33
using DynamicPPL
44
using Distributions
5-
using LinearAlgebra: I
5+
using LinearAlgebra:I
66

77
export MODELS
88

@@ -94,8 +94,6 @@ end
9494
end
9595
add_model!(MODELS, observe_submodel())
9696

97-
# This one fails with Enzyme ...
98-
9997
@model function dot_assume_observe_index(x=[1.5, 2.0, 2.5], ::Type{TV}=Vector{Float64}) where {TV}
10098
a = TV(undef, length(x))
10199
a .~ Normal()
@@ -105,46 +103,67 @@ add_model!(MODELS, observe_submodel())
105103
end
106104
add_model!(MODELS, dot_assume_observe_index())
107105

108-
# Add models with different distributions
109-
110-
DISTRIBUTIONS = Dict(
111-
# Univariate
112-
:assume_normal => Normal(),
113-
:assume_beta => Beta(2, 2),
114-
# Multivariate
115-
:assume_mvnormal => MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]),
116-
:assume_dirichlet => Dirichlet([1.0, 5.0]),
117-
# Matrixvariate
118-
:assume_wishart => Wishart(7, [1.0 0.5; 0.5 1.0]),
119-
:assume_lkjcholu => LKJCholesky(5, 1.0, 'U'),
120-
)
121-
122-
for (name, distribution) in DISTRIBUTIONS
123-
@eval begin
124-
@model function $name()
125-
a ~ $distribution
126-
end
127-
add_model!(MODELS, $name())
106+
@model function assume_normal()
107+
a ~ Normal()
108+
end
109+
add_model!(MODELS, assume_normal())
110+
111+
@model function assume_beta()
112+
a ~ Beta(2, 2)
113+
end
114+
add_model!(MODELS, assume_beta())
115+
116+
@model function assume_mvnormal()
117+
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
118+
end
119+
add_model!(MODELS, assume_mvnormal())
120+
121+
@model function assume_dirichlet()
122+
a ~ Dirichlet([1.0, 5.0])
123+
end
124+
add_model!(MODELS, assume_dirichlet())
125+
126+
@model function assume_wishart()
127+
a ~ Wishart(7, [1.0 0.5; 0.5 1.0])
128+
end
129+
add_model!(MODELS, assume_wishart())
130+
131+
@model function assume_lkjcholu()
132+
a ~ LKJCholesky(5, 1.0, 'U')
133+
end
134+
add_model!(MODELS, assume_lkjcholu())
135+
136+
@model function n010(::Type{TV}=Vector{Float64}) where {TV}
137+
a = TV(undef, 10)
138+
for i in eachindex(a)
139+
a[i] ~ Normal()
128140
end
129141
end
142+
add_model!(MODELS, n010())
130143

131-
# Add models with different sizes
144+
@model function n050(::Type{TV}=Vector{Float64}) where {TV}
145+
a = TV(undef, 50)
146+
for i in eachindex(a)
147+
a[i] ~ Normal()
148+
end
149+
end
150+
add_model!(MODELS, n050())
132151

133-
NS = [10, 50, 100, 500]
152+
@model function n100(::Type{TV}=Vector{Float64}) where {TV}
153+
a = TV(undef, 100)
154+
for i in eachindex(a)
155+
a[i] ~ Normal()
156+
end
157+
end
158+
add_model!(MODELS, n100())
134159

135-
for n in NS
136-
# pad to make sure they sort correctly alphabetically
137-
name = Symbol("n$(lpad(n, 3, "0"))")
138-
@eval begin
139-
@model function $name(::Type{TV}=Vector{Float64}) where {TV}
140-
a = TV(undef, $n)
141-
for i in eachindex(a)
142-
a[i] ~ Normal()
143-
end
144-
end
145-
add_model!(MODELS, $name())
160+
@model function n500(::Type{TV}=Vector{Float64}) where {TV}
161+
a = TV(undef, 500)
162+
for i in eachindex(a)
163+
a[i] ~ Normal()
146164
end
147165
end
166+
add_model!(MODELS, n500())
148167

149168
@model function multithreaded(x)
150169
a ~ Normal()

0 commit comments

Comments
 (0)