Skip to content

Commit 0d4cbba

Browse files
committed
Replace removed VarName constructor
1 parent 87d8e19 commit 0d4cbba

File tree

2 files changed

+41
-43
lines changed

2 files changed

+41
-43
lines changed

src/simple_varinfo.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,15 @@ function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarN
356356
return vi
357357
end
358358

359-
function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName)
359+
function BangBang.setindex!!(
360+
vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName{sym}
361+
) where {sym}
360362
# For dictlike objects, we treat the entire `vn` as a _key_ to set.
361363
dict = values_as(vi)
362364
# Attempt to split into `parent` and `child` optic.
363365
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
364366
o = optic === nothing ? identity : optic
365-
haskey(dict, VarName(vn, o))
367+
haskey(dict, VarName{sym}(o))
366368
end
367369
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
368370
keyoptic = parent === nothing ? identity : parent
@@ -372,7 +374,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
372374
BangBang.setindex!!(dict, val, vn)
373375
else
374376
# Split exists ⟹ trying to set an existing key.
375-
vn_key = VarName(vn, keyoptic)
377+
vn_key = VarName{sym}(keyoptic)
376378
BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
377379
end
378380
return Accessors.@set vi.values = dict_new

src/utils.jl

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,9 @@ julia> (parent ∘ parent ∘ parent)(@varname(x.a[1]))
595595
x
596596
```
597597
"""
598-
function parent(vn::VarName)
598+
function parent(vn::VarName{sym}) where {sym}
599599
p = parent(getoptic(vn))
600-
return p === nothing ? VarName(vn, identity) : VarName(vn, p)
600+
return p === nothing ? VarName{sym}(identity) : VarName{sym}(p)
601601
end
602602

603603
"""
@@ -712,7 +712,7 @@ ERROR: Could not find x.a[2] in x.a[1]
712712
function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
713713
_, child, issuccess = splitoptic(getoptic(vn_child)) do optic
714714
o = optic === nothing ? identity : optic
715-
VarName(vn_child, o) == vn_parent
715+
o == getoptic(vn_parent)
716716
end
717717

718718
issuccess || error("Could not find $vn_parent in $vn_child")
@@ -898,7 +898,7 @@ end
898898

899899
# For `dictlike` we need to check wether `vn` is "immediately" present, or
900900
# if some ancestor of `vn` is present in `dictlike`.
901-
function hasvalue(vals::AbstractDict, vn::VarName)
901+
function hasvalue(vals::AbstractDict, vn::VarName{sym}) where {sym}
902902
# First we check if `vn` is present as is.
903903
haskey(vals, vn) && return true
904904

@@ -907,7 +907,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
907907
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
908908
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
909909
o = optic === nothing ? identity : optic
910-
haskey(vals, VarName(vn, o))
910+
haskey(vals, VarName{sym}(o))
911911
end
912912
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
913913
keyoptic = parent === nothing ? identity : parent
@@ -916,7 +916,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
916916
issuccess || return false
917917

918918
# At this point we just need to check that we `canview` the value.
919-
value = vals[VarName(vn, keyoptic)]
919+
value = vals[VarName{sym}(keyoptic)]
920920

921921
return canview(child, value)
922922
end
@@ -927,7 +927,7 @@ end
927927
Return value corresponding to `vn` in `values` by also looking
928928
in the the actual values of the dict.
929929
"""
930-
function nested_getindex(values::AbstractDict, vn::VarName)
930+
function nested_getindex(values::AbstractDict, vn::VarName{sym}) where {sym}
931931
maybeval = get(values, vn, nothing)
932932
if maybeval !== nothing
933933
return maybeval
@@ -936,7 +936,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
936936
# Split the optic into the key / `parent` and the extraction optic / `child`.
937937
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
938938
o = optic === nothing ? identity : optic
939-
haskey(values, VarName(vn, o))
939+
haskey(values, VarName{sym}(o))
940940
end
941941
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
942942
keyoptic = parent === nothing ? identity : parent
@@ -949,7 +949,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
949949

950950
# TODO: Should we also check that we `canview` the extracted `value`
951951
# rather than just let it fail upon `get` call?
952-
value = values[VarName(vn, keyoptic)]
952+
value = values[VarName{sym}(keyoptic)]
953953
return child(value)
954954
end
955955

@@ -1065,22 +1065,24 @@ x.z[2][1]
10651065
```
10661066
"""
10671067
varname_leaves(vn::VarName, ::Real) = [vn]
1068-
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
1068+
function varname_leaves(
1069+
vn::VarName{sym}, val::AbstractArray{<:Union{Real,Missing}}
1070+
) where {sym}
10691071
return (
1070-
VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
1072+
VarName{sym}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
10711073
I in CartesianIndices(val)
10721074
)
10731075
end
1074-
function varname_leaves(vn::VarName, val::AbstractArray)
1076+
function varname_leaves(vn::VarName{sym}, val::AbstractArray) where {sym}
10751077
return Iterators.flatten(
1076-
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I])
1078+
varname_leaves(VarName{sym}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I])
10771079
for I in CartesianIndices(val)
10781080
)
10791081
end
1080-
function varname_leaves(vn::VarName, val::NamedTuple)
1081-
iter = Iterators.map(keys(val)) do sym
1082-
optic = Accessors.PropertyLens{sym}()
1083-
varname_leaves(VarName(vn, optic getoptic(vn)), optic(val))
1082+
function varname_leaves(vn::VarName{sym}, val::NamedTuple) where {sym}
1083+
iter = Iterators.map(keys(val)) do k
1084+
optic = Accessors.PropertyLens{k}()
1085+
varname_leaves(VarName{sym}(optic getoptic(vn)), optic(val))
10841086
end
10851087
return Iterators.flatten(iter)
10861088
end
@@ -1225,30 +1227,26 @@ value(leaf::Leaf) = leaf.value
12251227
# Leaf-types.
12261228
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
12271229
function varname_and_value_leaves_inner(
1228-
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
1229-
)
1230+
vn::VarName{sym}, val::AbstractArray{<:Union{Real,Missing}}
1231+
) where {sym}
12301232
return (
12311233
Leaf(
1232-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1233-
val[I],
1234+
VarName{sym}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)), val[I]
12341235
) for I in CartesianIndices(val)
12351236
)
12361237
end
12371238
# Containers.
1238-
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
1239+
function varname_and_value_leaves_inner(vn::VarName{sym}, val::AbstractArray) where {sym}
12391240
return Iterators.flatten(
12401241
varname_and_value_leaves_inner(
1241-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1242-
val[I],
1242+
VarName{sym}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)), val[I]
12431243
) for I in CartesianIndices(val)
12441244
)
12451245
end
1246-
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
1247-
iter = Iterators.map(keys(val)) do sym
1248-
optic = DynamicPPL.Accessors.PropertyLens{sym}()
1249-
varname_and_value_leaves_inner(
1250-
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
1251-
)
1246+
function varname_and_value_leaves_inner(vn::VarName{sym}, val::NamedTuple) where {sym}
1247+
iter = Iterators.map(keys(val)) do k
1248+
optic = Accessors.PropertyLens{k}()
1249+
varname_and_value_leaves_inner(VarName{sym}(optic getoptic(vn)), optic(val))
12521250
end
12531251

12541252
return Iterators.flatten(iter)
@@ -1262,22 +1260,20 @@ function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
12621260
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() vn, x.U)
12631261
end
12641262
end
1265-
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
1263+
function varname_and_value_leaves_inner(
1264+
vn::VarName{sym}, x::LinearAlgebra.LowerTriangular
1265+
) where {sym}
12661266
return (
1267-
Leaf(
1268-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1269-
x[I],
1270-
)
1267+
Leaf(VarName{sym}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)), x[I])
12711268
# Iteration over the lower-triangular indices.
12721269
for I in CartesianIndices(x) if I[1] >= I[2]
12731270
)
12741271
end
1275-
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
1272+
function varname_and_value_leaves_inner(
1273+
vn::VarName{sym}, x::LinearAlgebra.UpperTriangular
1274+
) where {sym}
12761275
return (
1277-
Leaf(
1278-
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
1279-
x[I],
1280-
)
1276+
Leaf(VarName{sym}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)), x[I])
12811277
# Iteration over the upper-triangular indices.
12821278
for I in CartesianIndices(x) if I[1] <= I[2]
12831279
)

0 commit comments

Comments
 (0)