Skip to content

Commit 0f8c9b1

Browse files
committed
Rename trans to is_transformed
1 parent a011dd6 commit 0f8c9b1

18 files changed

+181
-173
lines changed

HISTORY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ The separation of these functions was primarily implemented to avoid performing
5454

5555
Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.
5656

57-
The only other flag, other than `"del"`, that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed. One can simply use `istrans` and a newly exported function called `settrans!!` instead.
57+
The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed in favour of more specific ones. We've also used this opportunity to name the `"trans"` flag and the corresponding `istrans` function to be more explicit. The new, exported interface consists of the `is_transformed` and `set_transformed!!` functions.
5858

5959
**Other changes**
6060

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ The [Transformations section below](#Transformations) describes the methods used
329329
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
330330

331331
```@docs
332-
istrans
333-
settrans!!
332+
is_transformed
333+
set_transformed!!
334334
```
335335

336336
```@docs

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ else
88
using ..EnzymeCore
99
end
1010

11-
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
11+
# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
1212
# only checks whether such a method exists, and never runs it.
13-
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing
13+
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
14+
nothing
1415

1516
end

ext/DynamicPPLMooncakeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module DynamicPPLMooncakeExt
22

3-
using DynamicPPL: DynamicPPL, istrans
3+
using DynamicPPL: DynamicPPL, is_transformed
44
using Mooncake: Mooncake
55

66
# This is purely an optimisation.
7-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
7+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
88

99
end # module

src/DynamicPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ export AbstractVarInfo,
7070
acclogjac!!,
7171
acclogprior!!,
7272
accloglikelihood!!,
73-
istrans,
74-
settrans!!,
73+
is_transformed,
74+
set_transformed!!,
7575
link,
7676
link!!,
7777
invlink,

src/abstract_varinfo.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ end
769769

770770
# Transformations
771771
"""
772-
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
772+
is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
773773
774774
Return `true` if `vi` is working in unconstrained space, and `false`
775775
if `vi` is assuming realizations to be in support of the corresponding distributions.
@@ -780,27 +780,27 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
780780
Not all implementations of `AbstractVarInfo` support transforming only a subset of
781781
the variables.
782782
"""
783-
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
784-
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
785-
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
783+
is_transformed(vi::AbstractVarInfo) = is_transformed(vi, collect(keys(vi)))
784+
function is_transformed(vi::AbstractVarInfo, vns::AbstractVector)
785+
# This used to be: `!isempty(vns) && all(Base.Fix1(is_transformed, vi), vns)`.
786786
# In theory that should work perfectly fine. For unbeknownst reasons,
787787
# Julia 1.10 fails to infer its return type correctly. Thus we use this
788788
# slightly longer definition.
789789
isempty(vns) && return false
790790
for vn in vns
791-
istrans(vi, vn) || return false
791+
is_transformed(vi, vn) || return false
792792
end
793793
return true
794794
end
795795

796796
"""
797-
settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
797+
set_transformed!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
798798
799-
Return `vi` with `istrans(vi, vn)` evaluating to `true`.
799+
Return `vi` with `is_transformed(vi, vn)` evaluating to `true`.
800800
801-
If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables.
801+
If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all variables.
802802
"""
803-
function settrans!! end
803+
function set_transformed!! end
804804

805805
# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
806806
# method for the case when no `vns` is provided, that would get all the keys from the
@@ -833,7 +833,7 @@ function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
833833
ctx = DynamicTransformationContext{false}()
834834
model = contextualize(model, setleafcontext(model.context, ctx))
835835
vi = last(evaluate!!(model, vi))
836-
return settrans!!(vi, t)
836+
return set_transformed!!(vi, t)
837837
end
838838
function link!!(
839839
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
@@ -846,7 +846,7 @@ function link!!(
846846
if hasacc(vi, Val(:LogJacobian))
847847
vi = acclogjac!!(vi, logjac)
848848
end
849-
return settrans!!(vi, t)
849+
return set_transformed!!(vi, t)
850850
end
851851

852852
"""
@@ -896,7 +896,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
896896
ctx = DynamicTransformationContext{true}()
897897
model = contextualize(model, setleafcontext(model.context, ctx))
898898
vi = last(evaluate!!(model, vi))
899-
return settrans!!(vi, NoTransformation())
899+
return set_transformed!!(vi, NoTransformation())
900900
end
901901
function invlink!!(
902902
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
@@ -912,7 +912,7 @@ function invlink!!(
912912
if hasacc(vi, Val(:LogJacobian))
913913
vi = acclogjac!!(vi, inv_logjac)
914914
end
915-
return settrans!!(vi, NoTransformation())
915+
return set_transformed!!(vi, NoTransformation())
916916
end
917917

918918
"""
@@ -1020,7 +1020,7 @@ function unflatten end
10201020
"""
10211021
to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10221022
1023-
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
1023+
Return reconstructed `val`, possibly linked if `is_transformed(vi, vn)` is `true`.
10241024
"""
10251025
function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10261026
f = to_maybe_linked_internal_transform(vi, vn, dist)
@@ -1030,7 +1030,7 @@ end
10301030
"""
10311031
from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10321032
1033-
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
1033+
Return reconstructed `val`, possibly invlinked if `is_transformed(vi, vn)` is `true`.
10341034
"""
10351035
function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10361036
f = from_maybe_linked_internal_transform(vi, vn, dist)
@@ -1087,14 +1087,14 @@ in `varinfo` to a representation compatible with `dist`.
10871087
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
10881088
"""
10891089
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
1090-
return if istrans(varinfo, vn)
1090+
return if is_transformed(varinfo, vn)
10911091
from_linked_internal_transform(varinfo, vn, dist)
10921092
else
10931093
from_internal_transform(varinfo, vn, dist)
10941094
end
10951095
end
10961096
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
1097-
return if istrans(varinfo, vn)
1097+
return if is_transformed(varinfo, vn)
10981098
from_linked_internal_transform(varinfo, vn)
10991099
else
11001100
from_internal_transform(varinfo, vn)

src/contexts/init.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ function tilde_assume!!(
165165
# If the VarInfo alrady had a value for this variable, we will
166166
# keep the same linked status as in the original VarInfo. If not, we
167167
# check the rest of the VarInfo to see if other variables are linked.
168-
# istrans(vi) returns true if vi is nonempty and all variables in vi
168+
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
169169
# are linked.
170-
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
170+
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
171171
f = if insert_transformed_value
172172
link_transform(dist)
173173
else
@@ -183,7 +183,7 @@ function tilde_assume!!(
183183
end
184184
# Neither of these set the `trans` flag so we have to do it manually if
185185
# necessary.
186-
insert_transformed_value && settrans!!(vi, true, vn)
186+
insert_transformed_value && set_transformed!!(vi, true, vn)
187187
# `accumulate_assume!!` wants untransformed values as the second argument.
188188
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
189189
# We always return the untransformed value here, as that will determine

src/contexts/transformation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function tilde_assume!!(
2121
# vi[vn, right] always provides the value in unlinked space.
2222
x = vi[vn, right]
2323

24-
if istrans(vi, vn)
24+
if is_transformed(vi, vn)
2525
isinverse || @warn "Trying to link an already transformed variable ($vn)"
2626
else
2727
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"

src/simple_varinfo.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,23 @@ julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo());
9696
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
9797
1.8632965762164932
9898
99-
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99+
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true));
100100
101101
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102102
-0.21080155351918753
103103
104-
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105105
106106
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110+
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
114114
115-
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116116
117117
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118118
true
@@ -121,15 +121,15 @@ true
121121
Evaluation in transformed space of course also works:
122122
123123
```jldoctest simplevarinfo-general
124-
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
124+
julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true)
125125
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
126126
127127
julia> # (✓) Positive probability mass on negative numbers!
128128
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
129129
-1.3678794411714423
130130
131131
julia> # While if we forget to indicate that it's transformed:
132-
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
132+
vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false)
133133
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
134134
135135
julia> # (✓) No probability mass on negative numbers!
@@ -466,32 +466,34 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
466466
return SimpleVarInfo(values, accs, transformation)
467467
end
468468

469-
function settrans!!(vi::SimpleVarInfo, trans)
470-
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
469+
function set_transformed!!(vi::SimpleVarInfo, trans)
470+
return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation())
471471
end
472-
function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
472+
function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
473473
return Accessors.@set vi.transformation = transformation
474474
end
475-
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
476-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
475+
function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
476+
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans)
477477
end
478-
function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
478+
function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
479479
# We keep this method around just to obey the AbstractVarInfo interface.
480480
# However, note that this would only be a valid operation if it would be a
481481
# no-op, which we check here.
482-
if trans != istrans(vi)
482+
if trans != is_transformed(vi)
483483
error(
484-
"Individual variables in SimpleVarInfo cannot have different `settrans` statuses.",
484+
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
485485
)
486486
end
487487
end
488488

489-
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
490-
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
491-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
492-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo)
489+
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
490+
is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi)
491+
function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName)
492+
return is_transformed(vi.varinfo, vn)
493+
end
494+
is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo)
493495

494-
islinked(vi::SimpleVarInfo) = istrans(vi)
496+
islinked(vi::SimpleVarInfo) = is_transformed(vi)
495497

496498
values_as(vi::SimpleVarInfo) = vi.values
497499
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
@@ -618,7 +620,7 @@ function link!!(
618620
if hasacc(vi_new, Val(:LogJacobian))
619621
vi_new = acclogjac!!(vi_new, logjac)
620622
end
621-
return settrans!!(vi_new, t)
623+
return set_transformed!!(vi_new, t)
622624
end
623625

624626
function invlink!!(
@@ -636,7 +638,7 @@ function invlink!!(
636638
if hasacc(vi_new, Val(:LogJacobian))
637639
vi_new = acclogjac!!(vi_new, inv_logjac)
638640
end
639-
return settrans!!(vi_new, NoTransformation())
641+
return set_transformed!!(vi_new, NoTransformation())
640642
end
641643

642644
# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.

src/threadsafe.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,14 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
106106
model = contextualize(
107107
model, setleafcontext(model.context, DynamicTransformationContext{false}())
108108
)
109-
return settrans!!(last(evaluate!!(model, vi)), t)
109+
return set_transformed!!(last(evaluate!!(model, vi)), t)
110110
end
111111

112112
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
113113
model = contextualize(
114114
model, setleafcontext(model.context, DynamicTransformationContext{true}())
115115
)
116-
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
116+
return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation())
117117
end
118118

119119
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
@@ -185,12 +185,14 @@ end
185185
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
186186
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
187187

188-
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
189-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
188+
function set_transformed!!(vi::ThreadSafeVarInfo, val::Bool, vn::VarName)
189+
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, val, vn)
190190
end
191191

192-
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
193-
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
192+
is_transformed(vi::ThreadSafeVarInfo, vn::VarName) = is_transformed(vi.varinfo, vn)
193+
function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
194+
return is_transformed(vi.varinfo, vns)
195+
end
194196

195197
getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn)
196198

0 commit comments

Comments
 (0)