Skip to content

Commit 57a53e1

Browse files
committed
Merge branch 'main' into breaking
2 parents f20e86c + acac44d commit 57a53e1

File tree

8 files changed

+45
-33
lines changed

8 files changed

+45
-33
lines changed

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ And a couple of more internal changes:
5959
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument
6060
- The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched.
6161

62+
## 0.36.14
63+
64+
Added compatibility with [email protected].
65+
66+
## 0.36.13
67+
68+
Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method.
69+
6270
## 0.36.12
6371

6472
Removed several unexported functions.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4747
[compat]
4848
ADTypes = "1"
4949
AbstractMCMC = "5"
50-
AbstractPPL = "0.11"
50+
AbstractPPL = "0.11, 0.12"
5151
Accessors = "0.1"
5252
BangBang = "0.4.1"
5353
Bijectors = "0.13.18, 0.14, 0.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1414
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1515

1616
[compat]
17-
AbstractPPL = "0.11"
17+
AbstractPPL = "0.11, 0.12"
1818
Accessors = "0.1"
1919
DataStructures = "0.18"
2020
Distributions = "0.25"

docs/src/api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
160160
@addlogprob!
161161
```
162162

163-
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).
163+
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.
164164

165165
```@docs
166+
returned(::DynamicPPL.Model, ::MCMCChains.Chains)
166167
returned(::DynamicPPL.Model, ::NamedTuple)
167168
```
168169

src/abstract_varinfo.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
725725
"""
726726
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
727727
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
728-
return !isempty(vns) && all(Base.Fix1(istrans, vi), vns)
728+
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
729+
# In theory that should work perfectly fine. For unbeknownst reasons,
730+
# Julia 1.10 fails to infer its return type correctly. Thus we use this
731+
# slightly longer definition.
732+
isempty(vns) && return false
733+
for vn in vns
734+
istrans(vi, vn) || return false
735+
end
736+
return true
729737
end
730738

731739
"""

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
356356
# Attempt to split into `parent` and `child` optic.
357357
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
358358
o = optic === nothing ? identity : optic
359-
haskey(dict, VarName(vn, o))
359+
haskey(dict, VarName{getsym(vn)}(o))
360360
end
361361
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
362362
keyoptic = parent === nothing ? identity : parent
@@ -366,7 +366,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
366366
BangBang.setindex!!(dict, val, vn)
367367
else
368368
# Split exists ⟹ trying to set an existing key.
369-
vn_key = VarName(vn, keyoptic)
369+
vn_key = VarName{getsym(vn)}(keyoptic)
370370
BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
371371
end
372372
return Accessors.@set vi.values = dict_new

src/utils.jl

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ x
588588
"""
589589
function parent(vn::VarName)
590590
p = parent(getoptic(vn))
591-
return p === nothing ? VarName(vn, identity) : VarName(vn, p)
591+
return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p)
592592
end
593593

594594
"""
@@ -703,7 +703,7 @@ ERROR: Could not find x.a[2] in x.a[1]
703703
function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
704704
_, child, issuccess = splitoptic(getoptic(vn_child)) do optic
705705
o = optic === nothing ? identity : optic
706-
VarName(vn_child, o) == vn_parent
706+
o == getoptic(vn_parent)
707707
end
708708

709709
issuccess || error("Could not find $vn_parent in $vn_child")
@@ -898,7 +898,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
898898
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
899899
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
900900
o = optic === nothing ? identity : optic
901-
haskey(vals, VarName(vn, o))
901+
haskey(vals, VarName{getsym(vn)}(o))
902902
end
903903
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
904904
keyoptic = parent === nothing ? identity : parent
@@ -907,7 +907,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
907907
issuccess || return false
908908

909909
# At this point we just need to check that we `canview` the value.
910-
value = vals[VarName(vn, keyoptic)]
910+
value = vals[VarName{getsym(vn)}(keyoptic)]
911911

912912
return canview(child, value)
913913
end
@@ -927,7 +927,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
927927
# Split the optic into the key / `parent` and the extraction optic / `child`.
928928
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
929929
o = optic === nothing ? identity : optic
930-
haskey(values, VarName(vn, o))
930+
haskey(values, VarName{getsym(vn)}(o))
931931
end
932932
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
933933
keyoptic = parent === nothing ? identity : parent
@@ -940,7 +940,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
940940

941941
# TODO: Should we also check that we `canview` the extracted `value`
942942
# rather than just let it fail upon `get` call?
943-
value = values[VarName(vn, keyoptic)]
943+
value = values[VarName{getsym(vn)}(keyoptic)]
944944
return child(value)
945945
end
946946

@@ -1058,20 +1058,21 @@ x.z[2][1]
10581058
varname_leaves(vn::VarName, ::Real) = [vn]
10591059
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
10601060
return (
1061-
VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
1061+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
10621062
I in CartesianIndices(val)
10631063
)
10641064
end
10651065
function varname_leaves(vn::VarName, val::AbstractArray)
10661066
return Iterators.flatten(
1067-
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I])
1068-
for I in CartesianIndices(val)
1067+
varname_leaves(
1068+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I]
1069+
) for I in CartesianIndices(val)
10691070
)
10701071
end
10711072
function varname_leaves(vn::VarName, val::NamedTuple)
1072-
iter = Iterators.map(keys(val)) do sym
1073-
optic = Accessors.PropertyLens{sym}()
1074-
varname_leaves(VarName(vn, optic getoptic(vn)), optic(val))
1073+
iter = Iterators.map(keys(val)) do k
1074+
optic = Accessors.PropertyLens{k}()
1075+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), optic(val))
10751076
end
10761077
return Iterators.flatten(iter)
10771078
end
@@ -1101,7 +1102,7 @@ julia> foreach(println, varname_and_value_leaves(@varname(x), x))
11011102
(x.z[2][1], 3.0)
11021103
```
11031104
1104-
There are also some special handling for certain types:
1105+
There is also some special handling for certain types:
11051106
11061107
```jldoctest varname-and-value-leaves
11071108
julia> using LinearAlgebra
@@ -1220,7 +1221,7 @@ function varname_and_value_leaves_inner(
12201221
)
12211222
return (
12221223
Leaf(
1223-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1224+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
12241225
val[I],
12251226
) for I in CartesianIndices(val)
12261227
)
@@ -1229,14 +1230,14 @@ end
12291230
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
12301231
return Iterators.flatten(
12311232
varname_and_value_leaves_inner(
1232-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1233+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
12331234
val[I],
12341235
) for I in CartesianIndices(val)
12351236
)
12361237
end
1237-
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
1238-
iter = Iterators.map(keys(val)) do sym
1239-
optic = DynamicPPL.Accessors.PropertyLens{sym}()
1238+
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
1239+
iter = Iterators.map(keys(val)) do k
1240+
optic = Accessors.PropertyLens{k}()
12401241
varname_and_value_leaves_inner(
12411242
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
12421243
)
@@ -1255,20 +1256,14 @@ function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
12551256
end
12561257
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
12571258
return (
1258-
Leaf(
1259-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1260-
x[I],
1261-
)
1259+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
12621260
# Iteration over the lower-triangular indices.
12631261
for I in CartesianIndices(x) if I[1] >= I[2]
12641262
)
12651263
end
12661264
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
12671265
return (
1268-
Leaf(
1269-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1270-
x[I],
1271-
)
1266+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
12721267
# Iteration over the upper-triangular indices.
12731268
for I in CartesianIndices(x) if I[1] <= I[2]
12741269
)

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3030
[compat]
3131
ADTypes = "1"
3232
AbstractMCMC = "5"
33-
AbstractPPL = "0.11"
33+
AbstractPPL = "0.11, 0.12"
3434
Accessors = "0.1"
3535
Aqua = "0.8"
3636
Bijectors = "0.15.1"

0 commit comments

Comments
 (0)