Skip to content

Commit 472dfe7

Browse files
torfjeldedevmotion
andcommitted
Remove sort from setval! and others (#251)
This PR addresses #250 by simply removing the `sort` as mentioned in the issue. This is also related to TuringLang/Turing.jl#1626, which if addressed, will fix a lot of issues (unless the user then decides to sort their chain themselves). Co-authored-by: David Widmann <[email protected]>
1 parent 5ffbbd2 commit 472dfe7

File tree

6 files changed

+68
-31
lines changed

6 files changed

+68
-31
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.20"
3+
version = "0.11.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -9,7 +9,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
12-
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1312
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1413
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1514

@@ -20,6 +19,5 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
2019
ChainRulesCore = "0.9.7"
2120
Distributions = "0.23.8, 0.24, 0.25"
2221
MacroTools = "0.5.6"
23-
NaturalSort = "1"
2422
ZygoteRules = "0.2"
2523
julia = "1.3"

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using Bijectors
77

88
using AbstractMCMC: AbstractMCMC
99
using ChainRulesCore: ChainRulesCore
10-
using NaturalSort: NaturalSort
1110
using MacroTools: MacroTools
1211
using ZygoteRules: ZygoteRules
1312

src/varinfo.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,7 @@ end
13251325

13261326
"""
13271327
setval!(vi::AbstractVarInfo, x)
1328+
setval!(vi::AbstractVarInfo, values, keys)
13281329
setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
13291330
13301331
Set the values in `vi` to the provided values and leave those which are not present in
@@ -1379,20 +1380,18 @@ julia> var_info[@varname(x[1])] # [✓] unchanged
13791380
-0.22312984965118443
13801381
```
13811382
"""
1382-
setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x))
1383+
setval!(vi::AbstractVarInfo, x) = setval!(vi, values(x), keys(x))
1384+
setval!(vi::AbstractVarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys)
13831385
function setval!(
13841386
vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
13851387
)
1386-
return _apply!(
1387-
_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains)
1388-
)
1388+
return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
13891389
end
13901390

13911391
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
13921392
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
13931393
if !isempty(indices)
1394-
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
1395-
val = reduce(vcat, values[sorted_indices])
1394+
val = reduce(vcat, values[indices])
13961395
setval!(vi, val, vn)
13971396
settrans!(vi, false, vn)
13981397
end
@@ -1402,6 +1401,7 @@ end
14021401

14031402
"""
14041403
setval_and_resample!(vi::AbstractVarInfo, x)
1404+
setval_and_resample!(vi::AbstractVarInfo, values, keys)
14051405
setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx)
14061406
14071407
Set the values in `vi` to the provided values and those which are not present
@@ -1458,24 +1458,21 @@ julia> var_info[@varname(x[1])] # [✓] changed
14581458
- [`setval!`](@ref)
14591459
"""
14601460
function setval_and_resample!(vi::AbstractVarInfo, x)
1461-
return _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x))
1461+
return setval_and_resample!(vi, values(x), keys(x))
1462+
end
1463+
function setval_and_resample!(vi::AbstractVarInfo, values, keys)
1464+
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
14621465
end
14631466
function setval_and_resample!(
14641467
vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
14651468
)
1466-
return _apply!(
1467-
_setval_and_resample_kernel!,
1468-
vi,
1469-
chains.value[sample_idx, :, chain_idx],
1470-
keys(chains),
1471-
)
1469+
return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
14721470
end
14731471

14741472
function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
14751473
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
14761474
if !isempty(indices)
1477-
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
1478-
val = reduce(vcat, values[sorted_indices])
1475+
val = reduce(vcat, values[indices])
14791476
setval!(vi, val, vn)
14801477
settrans!(vi, false, vn)
14811478
else

test/runtests.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,26 @@ include("test_util.jl")
6060

6161
if GROUP == "All" || GROUP == "Downstream"
6262
@testset "turing" begin
63-
# activate separate test environment
64-
Pkg.activate(DIRECTORY_Turing_tests)
65-
Pkg.develop(PackageSpec(; path=DIRECTORY_DynamicPPL))
66-
Pkg.instantiate()
63+
try
64+
# activate separate test environment
65+
Pkg.activate(DIRECTORY_Turing_tests)
66+
Pkg.develop(PackageSpec(; path=DIRECTORY_DynamicPPL))
67+
Pkg.instantiate()
6768

68-
# make sure that the new environment is considered `using` and `import` statements
69-
# (not added automatically on Julia 1.3, see e.g. PR #209)
70-
if !(joinpath(DIRECTORY_Turing_tests, "Project.toml") in Base.load_path())
71-
pushfirst!(LOAD_PATH, DIRECTORY_Turing_tests)
72-
end
69+
# make sure that the new environment is considered `using` and `import` statements
70+
# (not added automatically on Julia 1.3, see e.g. PR #209)
71+
if !(joinpath(DIRECTORY_Turing_tests, "Project.toml") in Base.load_path())
72+
pushfirst!(LOAD_PATH, DIRECTORY_Turing_tests)
73+
end
7374

74-
include(joinpath("turing", "runtests.jl"))
75+
include(joinpath("turing", "runtests.jl"))
76+
catch err
77+
err isa Pkg.Resolve.ResolverError || rethrow()
78+
# If we can't resolve that means this is incompatible by SemVer and this is fine
79+
# It means we marked this as a breaking change, so we don't need to worry about
80+
# Mistakenly introducing a breaking change, as we have intentionally made one
81+
@info "Not compatible with this release. No problem." exception = err
82+
end
7583
end
7684
end
7785
end

test/turing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
66

77
[compat]
8-
DynamicPPL = "0.10"
8+
DynamicPPL = "0.11"
99
Turing = "0.15"
1010
julia = "1.3"

test/varinfo.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,20 @@
193193
end
194194
@test vicopy[s_vns] == vi[s_vns]
195195

196+
# Ordering is NOT preserved => fails for multivariate model.
196197
DynamicPPL.setval!(
197198
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
198199
)
199-
@test vicopy[m_vns] == 1:5
200+
if model == model_uv
201+
@test vicopy[m_vns] == 1:5
202+
else
203+
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
204+
end
200205
@test vicopy[s_vns] == vi[s_vns]
201206

207+
DynamicPPL.setval!(
208+
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
209+
)
202210
DynamicPPL.setval!(vicopy, (s=42,))
203211
@test vicopy[m_vns] == 1:5
204212
@test vicopy[s_vns] == 42
@@ -222,10 +230,23 @@
222230
end
223231
@test vicopy[s_vns] != vi[s_vns]
224232

233+
# Ordering is NOT preserved.
225234
DynamicPPL.setval_and_resample!(
226235
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
227236
)
228237
model(vicopy)
238+
if model == model_uv
239+
@test vicopy[m_vns] == 1:5
240+
else
241+
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
242+
end
243+
@test vicopy[s_vns] != vi[s_vns]
244+
245+
# Correct ordering.
246+
DynamicPPL.setval_and_resample!(
247+
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
248+
)
249+
model(vicopy)
229250
@test vicopy[m_vns] == 1:5
230251
@test vicopy[s_vns] != vi[s_vns]
231252

@@ -235,5 +256,19 @@
235256
@test vicopy[s_vns] == 42
236257
end
237258
end
259+
260+
# https://github.com/TuringLang/DynamicPPL.jl/issues/250
261+
@model function demo()
262+
return x ~ filldist(MvNormal([1, 100], 1), 2)
263+
end
264+
265+
vi = VarInfo(demo())
266+
vals_prev = vi.metadata.x.vals
267+
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
268+
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
269+
@test vals_prev == vi.metadata.x.vals
270+
271+
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
272+
@test vals_prev == vi.metadata.x.vals
238273
end
239274
end

0 commit comments

Comments
 (0)