Skip to content

Commit 66d3434

Browse files
committed
Add hoverable model source code
1 parent d16f464 commit 66d3434

File tree

3 files changed

+93
-37
lines changed

3 files changed

+93
-37
lines changed

ad.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@ def run_ad(args):
9393
append_to_github_output("results", results)
9494

9595

96+
def get_model_definition(model_key):
97+
"""Get the model definition from the Julia script."""
98+
lines = []
99+
record = False
100+
with open("models.jl", "r") as file:
101+
for line in file:
102+
line = line.rstrip()
103+
if line.startswith(f"@model function {model_key}"):
104+
record = True
105+
if record:
106+
lines.append(line)
107+
if record and line.strip() == "end":
108+
break
109+
return "<br>".join(lines)
110+
111+
96112
def html(_args):
97113
## Here you can register known errors that have been reported on GitHub /
98114
## have otherwise been documented. They will be turned into links in the table.
@@ -211,7 +227,7 @@ def html(_args):
211227
for model_name in models:
212228
ad_results = results[model_name]
213229
f.write("\n<tr>")
214-
f.write(f"<td>{model_name}</td>")
230+
f.write(f'<td>{model_name}<div class="model-definition"><pre>{get_model_definition(model_name)}</pre></div></td>')
215231
for adtype in adtypes:
216232
ad_result = ad_results[adtype]
217233
try:
@@ -291,6 +307,7 @@ def html(_args):
291307
font-family: "Fira Sans", sans-serif;
292308
font-weight: 700;
293309
background-color: #ececec;
310+
position: relative;
294311
}
295312
296313
tr > th:first-child {
@@ -321,6 +338,25 @@ def html(_args):
321338
a.issue:visited {
322339
color: #880000;
323340
}
341+
342+
div.model-definition {
343+
background-color: #f6f6f6;
344+
border: 1px solid black;
345+
border-radius: 5px;
346+
padding: 0 10px;
347+
z-index: 5;
348+
font-size: 0.9em;
349+
text-align: left;
350+
font-weight: normal;
351+
position: absolute;
352+
left: 100%;
353+
top: 0;
354+
display: none;
355+
}
356+
357+
td:hover > div.model-definition {
358+
display: block;
359+
}
324360
""")
325361

326362

main.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
4747
# If reached here - nothing went wrong
4848
@printf("%.3f", result.time_vs_primal)
4949
catch e
50+
@show e
5051
if e isa ADIncorrectException
5152
# First check for completely incorrect ones
5253
for (a, b) in zip(e.grad_expected, e.grad_actual)

models.jl

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Models
22

33
using DynamicPPL
44
using Distributions
5-
using LinearAlgebra: I
5+
using LinearAlgebra:I
66

77
export MODELS
88

@@ -94,8 +94,6 @@ end
9494
end
9595
add_model!(MODELS, observe_submodel())
9696

97-
# This one fails with Enzyme ...
98-
9997
@model function dot_assume_observe_index(x=[1.5, 2.0, 2.5], ::Type{TV}=Vector{Float64}) where {TV}
10098
a = TV(undef, length(x))
10199
a .~ Normal()
@@ -105,46 +103,67 @@ add_model!(MODELS, observe_submodel())
105103
end
106104
add_model!(MODELS, dot_assume_observe_index())
107105

108-
# Add models with different distributions
109-
110-
DISTRIBUTIONS = Dict(
111-
# Univariate
112-
:assume_normal => Normal(),
113-
:assume_beta => Beta(2, 2),
114-
# Multivariate
115-
:assume_mvnormal => MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]),
116-
:assume_dirichlet => Dirichlet([1.0, 5.0]),
117-
# Matrixvariate
118-
:assume_wishart => Wishart(7, [1.0 0.5; 0.5 1.0]),
119-
:assume_lkjcholu => LKJCholesky(5, 1.0, 'U'),
120-
)
121-
122-
for (name, distribution) in DISTRIBUTIONS
123-
@eval begin
124-
@model function $name()
125-
a ~ $distribution
126-
end
127-
add_model!(MODELS, $name())
106+
@model function assume_normal()
107+
a ~ Normal()
108+
end
109+
add_model!(MODELS, assume_normal())
110+
111+
@model function assume_beta()
112+
a ~ Beta(2, 2)
113+
end
114+
add_model!(MODELS, assume_beta())
115+
116+
@model function assume_mvnormal()
117+
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
118+
end
119+
add_model!(MODELS, assume_mvnormal())
120+
121+
@model function assume_dirichlet()
122+
a ~ Dirichlet([1.0, 5.0])
123+
end
124+
add_model!(MODELS, assume_dirichlet())
125+
126+
@model function assume_wishart()
127+
a ~ Wishart(7, [1.0 0.5; 0.5 1.0])
128+
end
129+
add_model!(MODELS, assume_wishart())
130+
131+
@model function assume_lkjcholu()
132+
a ~ LKJCholesky(5, 1.0, 'U')
133+
end
134+
add_model!(MODELS, assume_lkjcholu())
135+
136+
@model function n010(::Type{TV}=Vector{Float64}) where {TV}
137+
a = TV(undef, 10)
138+
for i in eachindex(a)
139+
a[i] ~ Normal()
128140
end
129141
end
142+
add_model!(MODELS, n010())
130143

131-
# Add models with different sizes
144+
@model function n050(::Type{TV}=Vector{Float64}) where {TV}
145+
a = TV(undef, 50)
146+
for i in eachindex(a)
147+
a[i] ~ Normal()
148+
end
149+
end
150+
add_model!(MODELS, n050())
132151

133-
NS = [10, 50, 100, 500]
152+
@model function n100(::Type{TV}=Vector{Float64}) where {TV}
153+
a = TV(undef, 100)
154+
for i in eachindex(a)
155+
a[i] ~ Normal()
156+
end
157+
end
158+
add_model!(MODELS, n100())
134159

135-
for n in NS
136-
# pad to make sure they sort correctly alphabetically
137-
name = Symbol("n$(lpad(n, 3, "0"))")
138-
@eval begin
139-
@model function $name(::Type{TV}=Vector{Float64}) where {TV}
140-
a = TV(undef, $n)
141-
for i in eachindex(a)
142-
a[i] ~ Normal()
143-
end
144-
end
145-
add_model!(MODELS, $name())
160+
@model function n500(::Type{TV}=Vector{Float64}) where {TV}
161+
a = TV(undef, 500)
162+
for i in eachindex(a)
163+
a[i] ~ Normal()
146164
end
147165
end
166+
add_model!(MODELS, n500())
148167

149168
@model function multithreaded(x)
150169
a ~ Normal()

0 commit comments

Comments
 (0)