Skip to content

Commit 00604db

Browse files
sunxd3github-actions[bot]devmotionyebai
authored
Transition to Accessors.jl (#585)
* remove `BangBang.possible` * version bumps * remove dep `MLUtils` * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * finish sentence * Update docs/src/tutorials/prob-interface.md Co-authored-by: David Widmann <[email protected]> * Update docs/src/tutorials/prob-interface.md Co-authored-by: David Widmann <[email protected]> * make `kfolds` a function * Update docs/src/tutorials/prob-interface.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * transition to `Accessors` * more updates * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * use fixed AbstractPPL for tests * adjust some util code related to compositing varname and optic * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * update with recent AbstractPPL merge * test new APPL fix * remove the `Pkg.add`, causing issue with env resolution * use APPL pending fix for testing; fix more errors * fix more errors * use latest version of AbstractPPL; add Accessors to docmeta for testing * bump APPL version to test --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 9deef5e commit 00604db

19 files changed

+258
-237
lines changed

Project.toml

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.11"
3+
version = "0.25.0"
4+
45

56
[deps]
67
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
89
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
10+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
911
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1012
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1113
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -20,15 +22,31 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2022
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2123
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2224
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
23-
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2425
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2526
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2627

28+
[weakdeps]
29+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
30+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
31+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
32+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
35+
36+
[extensions]
37+
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
38+
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
39+
DynamicPPLForwardDiffExt = ["ForwardDiff"]
40+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
41+
DynamicPPLReverseDiffExt = ["ReverseDiff"]
42+
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
43+
2744
[compat]
2845
ADTypes = "0.2"
2946
AbstractMCMC = "5"
30-
AbstractPPL = "0.7"
31-
BangBang = "0.3"
47+
AbstractPPL = "0.8.4"
48+
Accessors = "0.1"
49+
BangBang = "0.4"
3250
Bijectors = "0.13.9"
3351
ChainRulesCore = "1"
3452
Compat = "4"
@@ -44,29 +62,12 @@ MacroTools = "0.5.6"
4462
OrderedCollections = "1"
4563
Random = "1.6"
4664
Requires = "1"
47-
Setfield = "1"
4865
Test = "1.6"
4966
ZygoteRules = "0.2"
5067
julia = "1.6"
5168

52-
[extensions]
53-
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
54-
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
55-
DynamicPPLForwardDiffExt = ["ForwardDiff"]
56-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
57-
DynamicPPLReverseDiffExt = ["ReverseDiff"]
58-
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
59-
6069
[extras]
6170
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6271
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
6372
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
6473
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
65-
66-
[weakdeps]
67-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
68-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
69-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
70-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
71-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
72-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

docs/Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
[deps]
2+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
23
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
34
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
45
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
56
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
67
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
78
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
8-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
9-
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
109
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1110

1211
[compat]
12+
Accessors = "0.1"
1313
DataStructures = "0.18"
1414
Distributions = "0.25"
1515
Documenter = "1"
1616
FillArrays = "0.13, 1"
1717
LogDensityProblems = "2"
1818
MCMCChains = "5, 6"
19-
MLUtils = "0.3, 0.4"
20-
Setfield = "0.7.1, 0.8, 1"
2119
StableRNGs = "1"

docs/src/tutorials/prob-interface.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,28 @@ To give an example of the probability interface in use, we can use it to estimat
107107
In cross-validation, we split the dataset into several equal parts.
108108
Then, we choose one of these sets to serve as the validation set.
109109
Here, we measure fit using the cross entropy (Bayes loss).[^1]
110+
(For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points. For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).)
110111

111112
```@example probinterface
112-
using MLUtils
113+
# Calculate the train/validation splits across `nfolds` partitions, assume `length(dataset)` divides `nfolds`
114+
function kfolds(dataset::Array{<:Real}, nfolds::Int)
115+
fold_size, remaining = divrem(length(dataset), nfolds)
116+
if remaining != 0
117+
error("The number of folds must divide the number of data points.")
118+
end
119+
first_idx = firstindex(dataset)
120+
last_idx = lastindex(dataset)
121+
splits = map(0:(nfolds - 1)) do i
122+
start_idx = first_idx + i * fold_size
123+
end_idx = start_idx + fold_size
124+
train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
125+
return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
126+
end
127+
return splits
128+
end
113129
114130
function cross_val(
115-
dataset::AbstractVector{<:Real};
131+
dataset::Vector{<:Real};
116132
nfolds::Int=5,
117133
nsamples::Int=1_000,
118134
rng::Random.AbstractRNG=Random.default_rng(),

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using ADTypes: ADTypes
1212
using BangBang: BangBang, push!!, empty!!, setindex!!
1313
using MacroTools: MacroTools
1414
using ConstructionBase: ConstructionBase
15-
using Setfield: Setfield
15+
using Accessors: Accessors
1616
using LogDensityProblems: LogDensityProblems
1717
using LogDensityProblemsAD: LogDensityProblemsAD
1818

src/abstract_varinfo.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple)
262262
(x = 1.0, m = [2.0])
263263
264264
julia> values_as(SimpleVarInfo(data), OrderedDict)
265-
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries:
265+
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
266266
x => 1.0
267267
m => [2.0]
268268
@@ -312,7 +312,7 @@ julia> values_as(vi, NamedTuple)
312312
(s = 1.0, m = 2.0)
313313
314314
julia> values_as(vi, OrderedDict)
315-
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
315+
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
316316
s => 1.0
317317
m => 2.0
318318
@@ -338,7 +338,7 @@ julia> values_as(vi, NamedTuple)
338338
(s = 1.0, m = 2.0)
339339
340340
julia> values_as(vi, OrderedDict)
341-
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
341+
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
342342
s => 1.0
343343
m => 2.0
344344
@@ -426,7 +426,7 @@ julia> # Extract one with only `m`.
426426
427427
428428
julia> keys(varinfo_subset1)
429-
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
429+
1-element Vector{VarName{:m, typeof(identity)}}:
430430
m
431431
432432
julia> varinfo_subset1[@varname(m)]

src/compiler.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
44
need_concretize(expr)
55
66
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
7-
requires a dynamic lens.
7+
requires a dynamic optic.
88
99
# Examples
1010
11-
```jldoctest; setup=:(using Setfield)
11+
```jldoctest; setup=:(using Accessors)
1212
julia> DynamicPPL.need_concretize(:(x[1, :]))
1313
true
1414
@@ -19,7 +19,7 @@ julia> DynamicPPL.need_concretize(:(x[1, 1]))
1919
false
2020
"""
2121
function need_concretize(expr)
22-
return Setfield.need_dynamic_lens(expr) || begin
22+
return Accessors.need_dynamic_optic(expr) || begin
2323
flag = false
2424
MacroTools.postwalk(expr) do ex
2525
# Concretise colon by default
@@ -202,13 +202,13 @@ variables.
202202
# Example
203203
```jldoctest; setup=:(using Distributions, LinearAlgebra)
204204
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
205-
x[:,2]
205+
x[:, 2]
206206
207207
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
208-
x[1,2]
208+
x[1, 2]
209209
210210
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end]
211-
x[:][1,2]
211+
x[:][1, 2]
212212
213213
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end]
214214
x[1][3]
@@ -226,7 +226,7 @@ function unwrap_right_left_vns(
226226
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
227227
# and we therefore add the `Colon()` below.
228228
vns = map(axes(left, 2)) do i
229-
return AbstractPPL.concretize(vn Setfield.IndexLens((Colon(), i)), left)
229+
return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) vn, left)
230230
end
231231
return unwrap_right_left_vns(right, left, vns)
232232
end
@@ -236,7 +236,7 @@ function unwrap_right_left_vns(
236236
vn::VarName,
237237
)
238238
vns = map(CartesianIndices(left)) do i
239-
return vn Setfield.IndexLens(Tuple(i))
239+
return Accessors.IndexLens(Tuple(i)) vn
240240
end
241241
return unwrap_right_left_vns(right, left, vns)
242242
end
@@ -437,7 +437,7 @@ function generate_tilde_assume(left, right, vn)
437437
expr = :($left = $value)
438438
if left isa Expr
439439
expr = AbstractPPL.drop_escape(
440-
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
440+
Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true)
441441
)
442442
end
443443

src/contexts.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ end
288288

289289
function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
290290
if @generated
291-
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn)))
291+
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn)))
292292
else
293-
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn))
293+
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
294294
end
295295
end
296296

src/model.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,11 @@ in their trace/`VarInfo`:
279279
280280
```jldoctest condition
281281
julia> keys(VarInfo(demo_outer()))
282-
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
282+
1-element Vector{VarName{:m, typeof(identity)}}:
283283
m
284284
285285
julia> keys(VarInfo(demo_outer_prefix()))
286-
1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}:
286+
1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}:
287287
inner.m
288288
```
289289
@@ -448,7 +448,7 @@ julia> conditioned(cm)
448448
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
449449
# `a.m` is treated as a random variable.
450450
keys(VarInfo(cm))
451-
1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}:
451+
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
452452
a.m
453453
454454
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
@@ -634,11 +634,11 @@ in their trace/`VarInfo`:
634634
635635
```jldoctest fix
636636
julia> keys(VarInfo(demo_outer()))
637-
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
637+
1-element Vector{VarName{:m, typeof(identity)}}:
638638
m
639639
640640
julia> keys(VarInfo(demo_outer_prefix()))
641-
1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}:
641+
1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}:
642642
inner.m
643643
```
644644
@@ -830,7 +830,7 @@ julia> fixed(cm)
830830
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
831831
# `a.m` is treated as a random variable.
832832
keys(VarInfo(cm))
833-
1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}:
833+
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
834834
a.m
835835
836836
julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation.

src/model_utils.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ end
7878
function varname_in_chain!(
7979
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out
8080
) where {sym}
81-
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
82-
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
81+
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
82+
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
8383
# to extract the value from the `chain`.
8484
for vn in varname_leaves(VarName{sym}(), x)
8585
# Update `out`, possibly in place, and return.
86-
l = AbstractPPL.getlens(vn)
87-
varname_in_chain!(x, vn_parent l, chain, chain_idx, iteration_idx, out)
86+
l = AbstractPPL.getoptic(vn)
87+
varname_in_chain!(x, l vn_parent, chain, chain_idx, iteration_idx, out)
8888
end
8989
return out
9090
end
@@ -103,17 +103,17 @@ end
103103
function values_from_chain(
104104
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
105105
) where {sym}
106-
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
107-
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
106+
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
107+
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
108108
# to extract the value from the `chain`.
109109
out = similar(x)
110110
for vn in varname_leaves(VarName{sym}(), x)
111111
# Update `out`, possibly in place, and return.
112-
l = AbstractPPL.getlens(vn)
113-
out = Setfield.set(
112+
l = AbstractPPL.getoptic(vn)
113+
out = Accessors.set(
114114
out,
115115
BangBang.prefermutation(l),
116-
chain[iteration_idx, Symbol(vn_parent l), chain_idx],
116+
chain[iteration_idx, Symbol(l vn_parent), chain_idx],
117117
)
118118
end
119119

0 commit comments

Comments
 (0)