|
1 |
| -module TestExtUtils |
2 |
| - |
3 |
| -################################################### |
4 |
| -# These used to be in DPPL/src/test_utils.jl ###### |
5 |
| -################################################### |
| 1 | +module TestUtils |
6 | 2 |
|
7 | 3 | using AbstractMCMC
|
8 | 4 | using DynamicPPL
|
@@ -1101,123 +1097,4 @@ function DynamicPPL.dot_tilde_observe(
|
1101 | 1097 | return logp * context.mod, vi
|
1102 | 1098 | end
|
1103 | 1099 |
|
1104 |
| - |
1105 |
| - |
1106 |
| -################################################### |
1107 |
| -# These used to be in DPPL/test/test_util.jl ###### |
1108 |
| -################################################### |
1109 |
| - |
1110 |
| -# default model |
1111 |
| -@model function gdemo_d() |
1112 |
| - s ~ InverseGamma(2, 3) |
1113 |
| - m ~ Normal(0, sqrt(s)) |
1114 |
| - 1.5 ~ Normal(m, sqrt(s)) |
1115 |
| - 2.0 ~ Normal(m, sqrt(s)) |
1116 |
| - return s, m |
1117 |
| -end |
1118 |
| -const gdemo_default = gdemo_d() |
1119 |
| - |
1120 |
| -function test_model_ad(model, logp_manual) |
1121 |
| - vi = VarInfo(model) |
1122 |
| - x = DynamicPPL.getall(vi) |
1123 |
| - |
1124 |
| - # Log probabilities using the model. |
1125 |
| - ℓ = DynamicPPL.LogDensityFunction(model, vi) |
1126 |
| - logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) |
1127 |
| - |
1128 |
| - # Check that both functions return the same values. |
1129 |
| - lp = logp_manual(x) |
1130 |
| - @test logp_model(x) ≈ lp |
1131 |
| - |
1132 |
| - # Gradients based on the manual implementation. |
1133 |
| - grad = ForwardDiff.gradient(logp_manual, x) |
1134 |
| - |
1135 |
| - y, back = Tracker.forward(logp_manual, x) |
1136 |
| - @test Tracker.data(y) ≈ lp |
1137 |
| - @test Tracker.data(back(1)[1]) ≈ grad |
1138 |
| - |
1139 |
| - y, back = Zygote.pullback(logp_manual, x) |
1140 |
| - @test y ≈ lp |
1141 |
| - @test back(1)[1] ≈ grad |
1142 |
| - |
1143 |
| - # Gradients based on the model. |
1144 |
| - @test ForwardDiff.gradient(logp_model, x) ≈ grad |
1145 |
| - |
1146 |
| - y, back = Tracker.forward(logp_model, x) |
1147 |
| - @test Tracker.data(y) ≈ lp |
1148 |
| - @test Tracker.data(back(1)[1]) ≈ grad |
1149 |
| - |
1150 |
| - y, back = Zygote.pullback(logp_model, x) |
1151 |
| - @test y ≈ lp |
1152 |
| - @test back(1)[1] ≈ grad |
1153 |
| -end |
1154 |
| - |
1155 |
| -""" |
1156 |
| - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) |
1157 |
| -
|
1158 |
| -Test `setval!` on `model` and `chain`. |
1159 |
| -
|
1160 |
| -Worth noting that this only supports models containing symbols of the forms |
1161 |
| -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. |
1162 |
| -""" |
1163 |
| -function test_setval!(model, chain; sample_idx=1, chain_idx=1) |
1164 |
| - var_info = VarInfo(model) |
1165 |
| - spl = SampleFromPrior() |
1166 |
| - θ_old = var_info[spl] |
1167 |
| - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) |
1168 |
| - θ_new = var_info[spl] |
1169 |
| - @test θ_old != θ_new |
1170 |
| - vals = DynamicPPL.values_as(var_info, OrderedDict) |
1171 |
| - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) |
1172 |
| - for (n, v) in mapreduce(collect, vcat, iters) |
1173 |
| - n = string(n) |
1174 |
| - if Symbol(n) ∉ keys(chain) |
1175 |
| - # Assume it's a group |
1176 |
| - chain_val = vec( |
1177 |
| - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] |
1178 |
| - ) |
1179 |
| - v_true = vec(v) |
1180 |
| - else |
1181 |
| - chain_val = chain[sample_idx, n, chain_idx] |
1182 |
| - v_true = v |
1183 |
| - end |
1184 |
| - |
1185 |
| - @test v_true == chain_val |
1186 |
| - end |
1187 |
| -end |
1188 |
| - |
1189 |
| -""" |
1190 |
| - short_varinfo_name(vi::AbstractVarInfo) |
1191 |
| -
|
1192 |
| -Return string representing a short description of `vi`. |
1193 |
| -""" |
1194 |
| -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = |
1195 |
| - "threadsafe($(short_varinfo_name(vi.varinfo)))" |
1196 |
| -function short_varinfo_name(vi::TypedVarInfo) |
1197 |
| - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" |
1198 |
| - return "TypedVarInfo" |
1199 |
| -end |
1200 |
| -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" |
1201 |
| -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" |
1202 |
| -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" |
1203 |
| -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" |
1204 |
| -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) |
1205 |
| - return "SimpleVarInfo{<:VarNamedVector}" |
1206 |
| -end |
1207 |
| - |
1208 |
| -# convenient functions for testing model.jl |
1209 |
| -# function to modify the representation of values based on their length |
1210 |
| -function modify_value_representation(nt::NamedTuple) |
1211 |
| - modified_nt = NamedTuple() |
1212 |
| - for (key, value) in zip(keys(nt), values(nt)) |
1213 |
| - if length(value) == 1 # Scalar value |
1214 |
| - modified_value = value[1] |
1215 |
| - else # Non-scalar value |
1216 |
| - modified_value = value |
1217 |
| - end |
1218 |
| - modified_nt = merge(modified_nt, (key => modified_value,)) |
1219 |
| - end |
1220 |
| - return modified_nt |
1221 | 1100 | end
|
1222 |
| - |
1223 |
| -end # module TestExtUtils |
0 commit comments