|
1 |
| -module TestUtils |
| 1 | +module TestExtUtils |
| 2 | + |
| 3 | +################################################### |
| 4 | +# These used to be in DPPL/src/test_utils.jl ###### |
| 5 | +################################################### |
2 | 6 |
|
3 | 7 | using AbstractMCMC
|
4 | 8 | using DynamicPPL
|
@@ -1097,4 +1101,123 @@ function DynamicPPL.dot_tilde_observe(
|
1097 | 1101 | return logp * context.mod, vi
|
1098 | 1102 | end
|
1099 | 1103 |
|
| 1104 | + |
| 1105 | + |
| 1106 | +################################################### |
| 1107 | +# These used to be in DPPL/test/test_util.jl ###### |
| 1108 | +################################################### |
| 1109 | + |
| 1110 | +# default model |
| 1111 | +@model function gdemo_d() |
| 1112 | + s ~ InverseGamma(2, 3) |
| 1113 | + m ~ Normal(0, sqrt(s)) |
| 1114 | + 1.5 ~ Normal(m, sqrt(s)) |
| 1115 | + 2.0 ~ Normal(m, sqrt(s)) |
| 1116 | + return s, m |
| 1117 | +end |
| 1118 | +const gdemo_default = gdemo_d() |
| 1119 | + |
| 1120 | +function test_model_ad(model, logp_manual) |
| 1121 | + vi = VarInfo(model) |
| 1122 | + x = DynamicPPL.getall(vi) |
| 1123 | + |
| 1124 | + # Log probabilities using the model. |
| 1125 | + ℓ = DynamicPPL.LogDensityFunction(model, vi) |
| 1126 | + logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) |
| 1127 | + |
| 1128 | + # Check that both functions return the same values. |
| 1129 | + lp = logp_manual(x) |
| 1130 | + @test logp_model(x) ≈ lp |
| 1131 | + |
| 1132 | + # Gradients based on the manual implementation. |
| 1133 | + grad = ForwardDiff.gradient(logp_manual, x) |
| 1134 | + |
| 1135 | + y, back = Tracker.forward(logp_manual, x) |
| 1136 | + @test Tracker.data(y) ≈ lp |
| 1137 | + @test Tracker.data(back(1)[1]) ≈ grad |
| 1138 | + |
| 1139 | + y, back = Zygote.pullback(logp_manual, x) |
| 1140 | + @test y ≈ lp |
| 1141 | + @test back(1)[1] ≈ grad |
| 1142 | + |
| 1143 | + # Gradients based on the model. |
| 1144 | + @test ForwardDiff.gradient(logp_model, x) ≈ grad |
| 1145 | + |
| 1146 | + y, back = Tracker.forward(logp_model, x) |
| 1147 | + @test Tracker.data(y) ≈ lp |
| 1148 | + @test Tracker.data(back(1)[1]) ≈ grad |
| 1149 | + |
| 1150 | + y, back = Zygote.pullback(logp_model, x) |
| 1151 | + @test y ≈ lp |
| 1152 | + @test back(1)[1] ≈ grad |
| 1153 | +end |
| 1154 | + |
| 1155 | +""" |
| 1156 | + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) |
| 1157 | +
|
| 1158 | +Test `setval!` on `model` and `chain`. |
| 1159 | +
|
| 1160 | +Worth noting that this only supports models containing symbols of the forms |
| 1161 | +`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. |
| 1162 | +""" |
| 1163 | +function test_setval!(model, chain; sample_idx=1, chain_idx=1) |
| 1164 | + var_info = VarInfo(model) |
| 1165 | + spl = SampleFromPrior() |
| 1166 | + θ_old = var_info[spl] |
| 1167 | + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) |
| 1168 | + θ_new = var_info[spl] |
| 1169 | + @test θ_old != θ_new |
| 1170 | + vals = DynamicPPL.values_as(var_info, OrderedDict) |
| 1171 | + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) |
| 1172 | + for (n, v) in mapreduce(collect, vcat, iters) |
| 1173 | + n = string(n) |
| 1174 | + if Symbol(n) ∉ keys(chain) |
| 1175 | + # Assume it's a group |
| 1176 | + chain_val = vec( |
| 1177 | + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] |
| 1178 | + ) |
| 1179 | + v_true = vec(v) |
| 1180 | + else |
| 1181 | + chain_val = chain[sample_idx, n, chain_idx] |
| 1182 | + v_true = v |
| 1183 | + end |
| 1184 | + |
| 1185 | + @test v_true == chain_val |
| 1186 | + end |
| 1187 | +end |
| 1188 | + |
| 1189 | +""" |
| 1190 | + short_varinfo_name(vi::AbstractVarInfo) |
| 1191 | +
|
| 1192 | +Return string representing a short description of `vi`. |
| 1193 | +""" |
| 1194 | +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = |
| 1195 | + "threadsafe($(short_varinfo_name(vi.varinfo)))" |
| 1196 | +function short_varinfo_name(vi::TypedVarInfo) |
| 1197 | + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" |
| 1198 | + return "TypedVarInfo" |
| 1199 | +end |
| 1200 | +short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" |
| 1201 | +short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" |
| 1202 | +short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" |
| 1203 | +short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" |
| 1204 | +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) |
| 1205 | + return "SimpleVarInfo{<:VarNamedVector}" |
| 1206 | +end |
| 1207 | + |
| 1208 | +# convenient functions for testing model.jl |
| 1209 | +# function to modify the representation of values based on their length |
| 1210 | +function modify_value_representation(nt::NamedTuple) |
| 1211 | + modified_nt = NamedTuple() |
| 1212 | + for (key, value) in zip(keys(nt), values(nt)) |
| 1213 | + if length(value) == 1 # Scalar value |
| 1214 | + modified_value = value[1] |
| 1215 | + else # Non-scalar value |
| 1216 | + modified_value = value |
| 1217 | + end |
| 1218 | + modified_nt = merge(modified_nt, (key => modified_value,)) |
| 1219 | + end |
| 1220 | + return modified_nt |
1100 | 1221 | end
|
| 1222 | + |
| 1223 | +end # module TestExtUtils |
0 commit comments