|
1 |
| -module TestUtils |
| 1 | +module TestExtUtils |
| 2 | + |
| 3 | +################################################### |
| 4 | +# These used to be in DPPL/src/test_utils.jl ###### |
| 5 | +################################################### |
2 | 6 |
|
3 | 7 | using AbstractMCMC
|
4 | 8 | using DynamicPPL
|
@@ -1097,4 +1101,121 @@ function DynamicPPL.dot_tilde_observe(
|
1097 | 1101 | return logp * context.mod, vi
|
1098 | 1102 | end
|
1099 | 1103 |
|
| 1104 | +################################################### |
| 1105 | +# These used to be in DPPL/test/test_util.jl ###### |
| 1106 | +################################################### |
| 1107 | + |
| 1108 | +# default model |
| 1109 | +@model function gdemo_d() |
| 1110 | + s ~ InverseGamma(2, 3) |
| 1111 | + m ~ Normal(0, sqrt(s)) |
| 1112 | + 1.5 ~ Normal(m, sqrt(s)) |
| 1113 | + 2.0 ~ Normal(m, sqrt(s)) |
| 1114 | + return s, m |
| 1115 | +end |
| 1116 | +const gdemo_default = gdemo_d() |
| 1117 | + |
| 1118 | +function test_model_ad(model, logp_manual) |
| 1119 | + vi = VarInfo(model) |
| 1120 | + x = DynamicPPL.getall(vi) |
| 1121 | + |
| 1122 | + # Log probabilities using the model. |
| 1123 | + ℓ = DynamicPPL.LogDensityFunction(model, vi) |
| 1124 | + logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) |
| 1125 | + |
| 1126 | + # Check that both functions return the same values. |
| 1127 | + lp = logp_manual(x) |
| 1128 | + @test logp_model(x) ≈ lp |
| 1129 | + |
| 1130 | + # Gradients based on the manual implementation. |
| 1131 | + grad = ForwardDiff.gradient(logp_manual, x) |
| 1132 | + |
| 1133 | + y, back = Tracker.forward(logp_manual, x) |
| 1134 | + @test Tracker.data(y) ≈ lp |
| 1135 | + @test Tracker.data(back(1)[1]) ≈ grad |
| 1136 | + |
| 1137 | + y, back = Zygote.pullback(logp_manual, x) |
| 1138 | + @test y ≈ lp |
| 1139 | + @test back(1)[1] ≈ grad |
| 1140 | + |
| 1141 | + # Gradients based on the model. |
| 1142 | + @test ForwardDiff.gradient(logp_model, x) ≈ grad |
| 1143 | + |
| 1144 | + y, back = Tracker.forward(logp_model, x) |
| 1145 | + @test Tracker.data(y) ≈ lp |
| 1146 | + @test Tracker.data(back(1)[1]) ≈ grad |
| 1147 | + |
| 1148 | + y, back = Zygote.pullback(logp_model, x) |
| 1149 | + @test y ≈ lp |
| 1150 | + @test back(1)[1] ≈ grad |
| 1151 | +end |
| 1152 | + |
| 1153 | +""" |
| 1154 | + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) |
| 1155 | +
|
| 1156 | +Test `setval!` on `model` and `chain`. |
| 1157 | +
|
| 1158 | +Worth noting that this only supports models containing symbols of the forms |
| 1159 | +`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. |
| 1160 | +""" |
| 1161 | +function test_setval!(model, chain; sample_idx=1, chain_idx=1) |
| 1162 | + var_info = VarInfo(model) |
| 1163 | + spl = SampleFromPrior() |
| 1164 | + θ_old = var_info[spl] |
| 1165 | + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) |
| 1166 | + θ_new = var_info[spl] |
| 1167 | + @test θ_old != θ_new |
| 1168 | + vals = DynamicPPL.values_as(var_info, OrderedDict) |
| 1169 | + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) |
| 1170 | + for (n, v) in mapreduce(collect, vcat, iters) |
| 1171 | + n = string(n) |
| 1172 | + if Symbol(n) ∉ keys(chain) |
| 1173 | + # Assume it's a group |
| 1174 | + chain_val = vec( |
| 1175 | + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] |
| 1176 | + ) |
| 1177 | + v_true = vec(v) |
| 1178 | + else |
| 1179 | + chain_val = chain[sample_idx, n, chain_idx] |
| 1180 | + v_true = v |
| 1181 | + end |
| 1182 | + |
| 1183 | + @test v_true == chain_val |
| 1184 | + end |
1100 | 1185 | end
|
| 1186 | + |
| 1187 | +""" |
| 1188 | + short_varinfo_name(vi::AbstractVarInfo) |
| 1189 | +
|
| 1190 | +Return string representing a short description of `vi`. |
| 1191 | +""" |
| 1192 | +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = |
| 1193 | + "threadsafe($(short_varinfo_name(vi.varinfo)))" |
| 1194 | +function short_varinfo_name(vi::TypedVarInfo) |
| 1195 | + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" |
| 1196 | + return "TypedVarInfo" |
| 1197 | +end |
| 1198 | +short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" |
| 1199 | +short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" |
| 1200 | +short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" |
| 1201 | +short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" |
| 1202 | +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) |
| 1203 | + return "SimpleVarInfo{<:VarNamedVector}" |
| 1204 | +end |
| 1205 | + |
| 1206 | +# convenient functions for testing model.jl |
| 1207 | +# function to modify the representation of values based on their length |
| 1208 | +function modify_value_representation(nt::NamedTuple) |
| 1209 | + modified_nt = NamedTuple() |
| 1210 | + for (key, value) in zip(keys(nt), values(nt)) |
| 1211 | + if length(value) == 1 # Scalar value |
| 1212 | + modified_value = value[1] |
| 1213 | + else # Non-scalar value |
| 1214 | + modified_value = value |
| 1215 | + end |
| 1216 | + modified_nt = merge(modified_nt, (key => modified_value,)) |
| 1217 | + end |
| 1218 | + return modified_nt |
| 1219 | +end |
| 1220 | + |
| 1221 | +end # module TestExtUtils |
0 commit comments