Skip to content

Commit acac44d

Browse files
authored
* [email protected] * Replace removed VarName constructor * Hotfix for Julia 1.10 claiming that istrans is type unstable * Use getsym() * Revert "Hotfix for Julia 1.10 claiming that istrans is type unstable" This reverts commit 2982129. * Reapply "Hotfix for Julia 1.10 claiming that istrans is type unstable" This reverts commit 67bb055.
1 parent 92f6eea commit acac44d

File tree

7 files changed

+40
-33
lines changed

7 files changed

+40
-33
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.14
4+
5+
Added compatibility with [email protected].
6+
37
## 0.36.13
48

59
Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.13"
3+
version = "0.36.14"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4646
[compat]
4747
ADTypes = "1"
4848
AbstractMCMC = "5"
49-
AbstractPPL = "0.11"
49+
AbstractPPL = "0.11, 0.12"
5050
Accessors = "0.1"
5151
BangBang = "0.4.1"
5252
Bijectors = "0.13.18, 0.14, 0.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1414
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1515

1616
[compat]
17-
AbstractPPL = "0.11"
17+
AbstractPPL = "0.11, 0.12"
1818
Accessors = "0.1"
1919
DataStructures = "0.18"
2020
Distributions = "0.25"

src/abstract_varinfo.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
481481
"""
482482
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
483483
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
484-
return !isempty(vns) && all(Base.Fix1(istrans, vi), vns)
484+
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
485+
# In theory that should work perfectly fine. For unbeknownst reasons,
486+
# Julia 1.10 fails to infer its return type correctly. Thus we use this
487+
# slightly longer definition.
488+
isempty(vns) && return false
489+
for vn in vns
490+
istrans(vi, vn) || return false
491+
end
492+
return true
485493
end
486494

487495
"""

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
362362
# Attempt to split into `parent` and `child` optic.
363363
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
364364
o = optic === nothing ? identity : optic
365-
haskey(dict, VarName(vn, o))
365+
haskey(dict, VarName{getsym(vn)}(o))
366366
end
367367
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
368368
keyoptic = parent === nothing ? identity : parent
@@ -372,7 +372,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
372372
BangBang.setindex!!(dict, val, vn)
373373
else
374374
# Split exists ⟹ trying to set an existing key.
375-
vn_key = VarName(vn, keyoptic)
375+
vn_key = VarName{getsym(vn)}(keyoptic)
376376
BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
377377
end
378378
return Accessors.@set vi.values = dict_new

src/utils.jl

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ x
597597
"""
598598
function parent(vn::VarName)
599599
p = parent(getoptic(vn))
600-
return p === nothing ? VarName(vn, identity) : VarName(vn, p)
600+
return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p)
601601
end
602602

603603
"""
@@ -712,7 +712,7 @@ ERROR: Could not find x.a[2] in x.a[1]
712712
function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
713713
_, child, issuccess = splitoptic(getoptic(vn_child)) do optic
714714
o = optic === nothing ? identity : optic
715-
VarName(vn_child, o) == vn_parent
715+
o == getoptic(vn_parent)
716716
end
717717

718718
issuccess || error("Could not find $vn_parent in $vn_child")
@@ -907,7 +907,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
907907
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
908908
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
909909
o = optic === nothing ? identity : optic
910-
haskey(vals, VarName(vn, o))
910+
haskey(vals, VarName{getsym(vn)}(o))
911911
end
912912
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
913913
keyoptic = parent === nothing ? identity : parent
@@ -916,7 +916,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
916916
issuccess || return false
917917

918918
# At this point we just need to check that we `canview` the value.
919-
value = vals[VarName(vn, keyoptic)]
919+
value = vals[VarName{getsym(vn)}(keyoptic)]
920920

921921
return canview(child, value)
922922
end
@@ -936,7 +936,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
936936
# Split the optic into the key / `parent` and the extraction optic / `child`.
937937
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
938938
o = optic === nothing ? identity : optic
939-
haskey(values, VarName(vn, o))
939+
haskey(values, VarName{getsym(vn)}(o))
940940
end
941941
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
942942
keyoptic = parent === nothing ? identity : parent
@@ -949,7 +949,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
949949

950950
# TODO: Should we also check that we `canview` the extracted `value`
951951
# rather than just let it fail upon `get` call?
952-
value = values[VarName(vn, keyoptic)]
952+
value = values[VarName{getsym(vn)}(keyoptic)]
953953
return child(value)
954954
end
955955

@@ -1067,20 +1067,21 @@ x.z[2][1]
10671067
varname_leaves(vn::VarName, ::Real) = [vn]
10681068
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
10691069
return (
1070-
VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
1070+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
10711071
I in CartesianIndices(val)
10721072
)
10731073
end
10741074
function varname_leaves(vn::VarName, val::AbstractArray)
10751075
return Iterators.flatten(
1076-
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I])
1077-
for I in CartesianIndices(val)
1076+
varname_leaves(
1077+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I]
1078+
) for I in CartesianIndices(val)
10781079
)
10791080
end
10801081
function varname_leaves(vn::VarName, val::NamedTuple)
1081-
iter = Iterators.map(keys(val)) do sym
1082-
optic = Accessors.PropertyLens{sym}()
1083-
varname_leaves(VarName(vn, optic getoptic(vn)), optic(val))
1082+
iter = Iterators.map(keys(val)) do k
1083+
optic = Accessors.PropertyLens{k}()
1084+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), optic(val))
10841085
end
10851086
return Iterators.flatten(iter)
10861087
end
@@ -1110,7 +1111,7 @@ julia> foreach(println, varname_and_value_leaves(@varname(x), x))
11101111
(x.z[2][1], 3.0)
11111112
```
11121113
1113-
There are also some special handling for certain types:
1114+
There is also some special handling for certain types:
11141115
11151116
```jldoctest varname-and-value-leaves
11161117
julia> using LinearAlgebra
@@ -1229,7 +1230,7 @@ function varname_and_value_leaves_inner(
12291230
)
12301231
return (
12311232
Leaf(
1232-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1233+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
12331234
val[I],
12341235
) for I in CartesianIndices(val)
12351236
)
@@ -1238,14 +1239,14 @@ end
12381239
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
12391240
return Iterators.flatten(
12401241
varname_and_value_leaves_inner(
1241-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1242+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
12421243
val[I],
12431244
) for I in CartesianIndices(val)
12441245
)
12451246
end
1246-
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
1247-
iter = Iterators.map(keys(val)) do sym
1248-
optic = DynamicPPL.Accessors.PropertyLens{sym}()
1247+
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
1248+
iter = Iterators.map(keys(val)) do k
1249+
optic = Accessors.PropertyLens{k}()
12491250
varname_and_value_leaves_inner(
12501251
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
12511252
)
@@ -1264,20 +1265,14 @@ function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
12641265
end
12651266
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
12661267
return (
1267-
Leaf(
1268-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1269-
x[I],
1270-
)
1268+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
12711269
# Iteration over the lower-triangular indices.
12721270
for I in CartesianIndices(x) if I[1] >= I[2]
12731271
)
12741272
end
12751273
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
12761274
return (
1277-
Leaf(
1278-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1279-
x[I],
1280-
)
1275+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
12811276
# Iteration over the upper-triangular indices.
12821277
for I in CartesianIndices(x) if I[1] <= I[2]
12831278
)

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3030
[compat]
3131
ADTypes = "1"
3232
AbstractMCMC = "5"
33-
AbstractPPL = "0.11"
33+
AbstractPPL = "0.11, 0.12"
3434
Accessors = "0.1"
3535
Aqua = "0.8"
3636
Bijectors = "0.15.1"

0 commit comments

Comments
 (0)