Skip to content

Commit 7a8ba7e

Browse files
committed
Sibling PR of introduction of Setfield.jl in AbstractPPL.jl (#295)
This is a sibling PR to TuringLang/AbstractPPL.jl#26 fixing some issues + allowing us to do neat stuff. We also finally drop the passing of the `inds` around in the tilde-pipeline, which is not very useful now that we have the more general lenses in `VarName`. TODOs: - [X] ~Deprecate `*tilde_*` with `inds` argument appropriately.~ EDIT: On second thought, let's not. See comment for reason. - [x] It seems like the prob macro is now somehow broken 😕 - [X] ~(Maybe) Rewrite `@model` to not escape the entire expression.~ Deferred to #311 - [X] Figure out performance degradation. - Answer: `hash` for `Tuple` vs. `hash` for immutable struct 😕 ## Sample fields of structs ```julia julia> @model function demo(x, y) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) for i in 2:length(x.a) - 1 x.a[i] ~ Normal(m, √s) end # Dynamic indexing x.a[begin] ~ Normal(-100.0, 1.0) x.a[end] ~ Normal(100.0, 1.0) # Immutable set y.a ~ Normal() # Dotted z = Vector{Float64}(undef, 3) z[1:2] .~ Normal() z[end:end] .~ Normal() return (; s, m, x, y, z) end julia> struct MyCoolStruct{T} a::T end julia> m = demo(MyCoolStruct([missing, missing]), MyCoolStruct(missing)); julia> m() (s = 3.483799020996254, m = -0.35566330762328, x = MyCoolStruct{Vector{Union{Missing, Float64}}}(Union{Missing, Float64}[-100.75592540694562, 98.61295291877542]), y = MyCoolStruct{Float64}(-2.1107980419121546), z = [-2.2868359094832584, -1.1378866583607443, 1.172250491861777]) ``` ## Sample fields of `DataFrame` ```julia julia> using DataFrames julia> using Setfield: ConstructionBase julia> function ConstructionBase.setproperties(df::DataFrame, patch::NamedTuple) # Only need `copy` because we'll replace entire columns columns = copy(DataFrames._columns(df)) colindex = DataFrames.index(df) for k in keys(patch) columns[colindex[k]] = patch[k] end return DataFrame(columns, colindex) end julia> @model function demo(x) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) for i in 1:length(x.a) - 1 x.a[i] ~ Normal(m, √s) end x.a[end] ~ Normal(100.0, 1.0) return x end demo (generic function with 1 method) julia> m = demo(df, (a = missing, )); julia> m() 3×1 DataFrame Row │ a │ Float64? ─────┼────────── 1 │ 1.0 2 │ 2.0 3 │ 99.8838 julia> df 3×1 DataFrame Row │ a │ Float64? ─────┼─────────── 1 │ 1.0 2 │ 2.0 3 │ missing ``` # Benchmarks Unfortunately there does seem to be performance regression when using a very large number of varnames in a loop in the model (for broadcasting which uses the same number of varnames but does so "internally", there is no difference): ![image](https://user-images.githubusercontent.com/11074788/127791298-da3d0fb2-baab-428b-a555-3f4d2c63bd3b.png) The weird thing is that we're using less memory, indicating that type-inference might better? <details> <summary>0.31.1</summary> ## 0.31.1 ## ### Setup ### ```julia using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` ### Models ### #### `demo1` #### ```julia @model function demo1(x) m ~ Normal() x ~ Normal(m, 1) return (m = m, x = x) end model_def = demo1; data = 1.0; ``` ```julia @time model_def(data)(); ``` ``` 0.059594 seconds (115.76 k allocations: 6.982 MiB, 99.91% compilation tim e) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 48 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 619.000 ns … 19.678 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 654.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 677.650 ns ± 333.145 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅▆▇█▅▄▃ ▃▅███████▇▆▅▄▃▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂ ▃ 619 ns Histogram: frequency by time 945 ns < Memory estimate: 480 bytes, allocs estimate: 13. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 249.000 ns … 11.048 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 264.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 267.650 ns ± 137.452 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▂▄ ▆▇ █▇ ▇▄ ▂▂ ▂▂▂▁▂▂▁▃▃▁▅▅▁███▁██▁██▁██▁██▁▇▇▅▁▄▄▁▃▃▁▃▃▁▃▂▁▂▂▂▁▂▂▁▂▂▁▂▂▁▂▂▂ ▃ 249 ns Histogram: frequency by time 291 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo2` #### ```julia @model function demo2(y) # Our prior belief about the probability of heads in a coin. p ~ Beta(1, 1) # The number of observations. N = length(y) for n in 1:N # Heads or tails of a coin are drawn from a Bernoulli distribution. y[n] ~ Bernoulli(p) end end model_def = demo2; data = rand(0:1, 10); ``` ```julia @time model_def(data)(); ``` ``` 0.067078 seconds (143.91 k allocations: 8.544 MiB, 99.91% compilation tim e) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 1.637 μs … 48.917 μs ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.694 μs ┊ GC (median): 0.00% Time (mean ± σ): 1.746 μs ± 550.372 ns ┊ GC (mean ± σ): 0.00% ± 0.00% ▂█▇▃ ▁▄████▇▄▄▅▅▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 1.64 μs Histogram: frequency by time 2.23 μs < Memory estimate: 1.66 KiB, allocs estimate: 47. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 506.000 ns … 10.733 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 546.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 553.478 ns ± 118.542 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▃█ ▆▅ ▂▃██▇▇██▅▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▂▂▁▁▁▁▁▂▂▁▂▂▁▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃ 506 ns Histogram: frequency by time 933 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo3` #### ```julia @model function demo3(x) D, N = size(x) # Draw the parameters for cluster 1. μ1 ~ Normal() # Draw the parameters for cluster 2. μ2 ~ Normal() μ = [μ1, μ2] # Comment out this line if you instead want to draw the weights. w = [0.5, 0.5] # Draw assignments for each datum and generate it from a multivariate normal. k = Vector{Int}(undef, N) for i in 1:N k[i] ~ Categorical(w) x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.) end return k end model_def = demo3 # Construct 30 data points for each cluster. N = 30 # Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example. μs = [-3.5, 0.0] # Construct the data points. data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ``` ```julia @time model_def(data)(); ``` ``` 0.097628 seconds (224.06 k allocations: 13.410 MiB, 99.79% compilation ti me) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 48.200 μs … 16.129 ms ┊ GC (min … max): 0.00% … 99.5 3% Time (median): 51.017 μs ┊ GC (median): 0.00% Time (mean ± σ): 60.128 μs ± 265.008 μs ┊ GC (mean ± σ): 7.61% ± 1.7 2% ▂▆█ ████▂▂▂▁▂▃▄▅▇▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 48.2 μs Histogram: frequency by time 101 μs < Memory estimate: 48.20 KiB, allocs estimate: 1042. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 22.210 μs … 13.796 ms ┊ GC (min … max): 0.00% … 99.7 0% Time (median): 25.882 μs ┊ GC (median): 0.00% Time (mean ± σ): 27.536 μs ± 137.815 μs ┊ GC (mean ± σ): 5.00% ± 1.0 0% █▇▆▄▂ ▁▇▆▇▆▅▄▂ ▂▂▂▁ ▂ ████████████████████████▆▆▃▅▅▅▅▅▅▁▆▇▆▅▅▅▆▆▅▆▇▇▇▇▆▆▅▆▆▅▅▇█▇█▇ █ 22.2 μs Histogram: log(frequency) by time 51 μs < Memory estimate: 17.62 KiB, allocs estimate: 183. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo4`: loads of indexing #### ```julia @model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) for i in eachindex(x) x[i] ~ Normal(m, 1.0) end end model_def = demo4 data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 0.435154 seconds (3.12 M allocations: 192.275 MiB, 8.73% gc time, 1.84% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 62 samples with 1 evaluation. Range (min … max): 61.601 ms … 101.432 ms ┊ GC (min … max): 0.00% … 25.0 2% Time (median): 76.902 ms ┊ GC (median): 0.00% Time (mean ± σ): 77.276 ms ± 11.445 ms ┊ GC (mean ± σ): 6.48% ± 10.7 7% ▂ ▂ █ ▆ ▆▆██▄▄▁▄█▄▁▁▁▁▁▁▆▁█▁█▁▄████▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▄▁▁▆▁▁▄▆▄▁▄▁▆▄▁▄ ▁ 61.6 ms Histogram: frequency by time 101 ms < Memory estimate: 44.37 MiB, allocs estimate: 1357727. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 189 samples with 1 evaluation. Range (min … max): 23.796 ms … 40.845 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 24.838 ms ┊ GC (median): 0.00% Time (mean ± σ): 25.162 ms ± 1.434 ms ┊ GC (mean ± σ): 0.00% ± 0.00% ▁ ▂▂▃█▂ ▃▂ ▁ ▃▅█▃▇▅█▇██████▇█████▇▄▅▃▅▆▇█▃▆▃▃▄▅▁▄▁▃▁▆▅▄▁▁▁▃▁▃▄▁▁▃▃▁▁▁▁▃▄ ▃ 23.8 ms Histogram: frequency by time 27.8 ms < Memory estimate: 781.70 KiB, allocs estimate: 6. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` ```julia @model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) x .~ Normal(m, 1.0) end model_def = demo4_dotted data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 1.476057 seconds (5.08 M allocations: 375.205 MiB, 5.02% gc time, 0.62% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 39 samples with 1 evaluation. Range (min … max): 112.078 ms … 350.311 ms ┊ GC (min … max): 11.20% … 4. 74% Time (median): 115.686 ms ┊ GC (median): 12.93% Time (mean ± σ): 122.722 ms ± 37.638 ms ┊ GC (mean ± σ): 12.96% ± 2. 85% █▅ ▁ ██▅█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁ 112 ms Histogram: log(frequency) by time 350 ms < Memory estimate: 347.71 MiB, allocs estimate: 964550. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 59 samples with 1 evaluation. Range (min … max): 69.420 ms … 407.970 ms ┊ GC (min … max): 12.25% … 6.3 0% Time (median): 71.514 ms ┊ GC (median): 12.41% Time (mean ± σ): 78.481 ms ± 43.867 ms ┊ GC (mean ± σ): 12.80% ± 2.8 4% ▅▂█ █▅ ▇██████▅▅▄▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▄▄▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁ 69.4 ms Histogram: frequency by time 94.2 ms < Memory estimate: 337.55 MiB, allocs estimate: 399306. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` </details> <details> <summary>This PR</summary> ## This PR ## ### Setup ### ```julia using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` ### Models ### #### `demo1` #### ```julia @model function demo1(x) m ~ Normal() x ~ Normal(m, 1) return (m = m, x = x) end model_def = demo1; data = 1.0; ``` ```julia @time model_def(data)(); ``` ``` 1.063017 seconds (2.88 M allocations: 180.745 MiB, 4.19% gc time, 99.90% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 48 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 615.000 ns … 13.280 ms ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 650.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 2.037 μs ± 132.793 μs ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅█▇▅▄▄▃▂▁▁ ▁ ███████████▇▇▇▆▆▆▆▃▄▆▆▅▆▇▆▆▇▆▆▇▆▆▆▆▅▆▆▅▅▅▅▄▄▅▅▃▅▅▃▅▄▅▅▅▅▅▄▅▆▅ █ 615 ns Histogram: log(frequency) by time 1.7 μs < Memory estimate: 480 bytes, allocs estimate: 13. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 272.000 ns … 9.093 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 284.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 310.535 ns ± 156.251 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅█▆▄▃▃▂▁▁ ▁ ███████████▇▇▆▄▄▃▃▄▅▆▅▆▅▆▆▆▆▆▆▆▇▇▆▆▆▆▆▇▆▆▆▆▇▆▇▇▇▇▆▆▆▆▆▅▆▆▅▄▅▅ █ 272 ns Histogram: log(frequency) by time 643 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo2` #### ```julia @model function demo2(y) # Our prior belief about the probability of heads in a coin. p ~ Beta(1, 1) # The number of observations. N = length(y) for n in 1:N # Heads or tails of a coin are drawn from a Bernoulli distribution. y[n] ~ Bernoulli(p) end end model_def = demo2; data = rand(0:1, 10); ``` ```julia @time model_def(data)(); ``` ``` 0.401535 seconds (863.20 k allocations: 51.771 MiB, 2.88% gc time, 99.90% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000003 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 1.672 μs … 9.849 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.754 μs ┊ GC (median): 0.00% Time (mean ± σ): 2.835 μs ± 98.472 μs ┊ GC (mean ± σ): 0.00% ± 0.00% ▅██▇▆▆▅▄▄▃▂▂▁▁ ▁▁▁ ▁ ▂ ██████████████████▇▇▇▇▇▆▇▆▅▆▄▄▁▄▄▄▆▇██████████▆▆▇▇▇▇▆▇▆▆▆▆ █ 1.67 μs Histogram: log(frequency) by time 3.19 μs < Memory estimate: 1.50 KiB, allocs estimate: 37. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 544.000 ns … 19.704 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 567.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 578.671 ns ± 222.201 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▄█▇▅▂▃ ▃███████▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▁▁▁▂▁▂▁▂▁▁▁▁▁▂▂▂▂▂▂▂▁▂▂▂▂▂ ▃ 544 ns Histogram: frequency by time 888 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo3` #### ```julia @model function demo3(x) D, N = size(x) # Draw the parameters for cluster 1. μ1 ~ Normal() # Draw the parameters for cluster 2. μ2 ~ Normal() μ = [μ1, μ2] # Comment out this line if you instead want to draw the weights. w = [0.5, 0.5] # Draw assignments for each datum and generate it from a multivariate normal. k = Vector{Int}(undef, N) for i in 1:N k[i] ~ Categorical(w) x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.) end return k end model_def = demo3 # Construct 30 data points for each cluster. N = 30 # Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example. μs = [-3.5, 0.0] # Construct the data points. data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ``` ```julia @time model_def(data)(); ``` ``` 1.031824 seconds (2.34 M allocations: 139.934 MiB, 3.16% gc time, 99.96% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 52.509 μs … 9.913 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 53.706 μs ┊ GC (median): 0.00% Time (mean ± σ): 61.948 μs ± 210.490 μs ┊ GC (mean ± σ): 9.84% ± 3.27 % ▂▆██▇▆▅▄▄▃▃▂▂▁▁▁▁ ▁ ▂ █████████████████████▇█▇▇▇█████████▇▇▇▇▅▆▆▅▆▅▅▆▇▅▅▄▅▅▄▄▄▄▂▄▃ █ 52.5 μs Histogram: log(frequency) by time 71.3 μs < Memory estimate: 47.66 KiB, allocs estimate: 1007. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 25.046 μs … 7.474 ms ┊ GC (min … max): 0.00% … 99.4 0% Time (median): 25.591 μs ┊ GC (median): 0.00% Time (mean ± σ): 29.101 μs ± 105.160 μs ┊ GC (mean ± σ): 6.84% ± 1.9 8% ▇█▆▄▃▂▂▁ ▃▄▂▃▃▂▂▁ ▂ █████████▇█████████▇▆▅▆▆▇▇▇▇▅▆▆▅▃▅▄▃▂▂▃▂▄▄▅▄▄▅▄▅▅▅▆▆▅▅▅▅▆▇▇█ █ 25 μs Histogram: log(frequency) by time 46 μs < Memory estimate: 17.62 KiB, allocs estimate: 183. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo4`: lots of univariate random variables #### ```julia @model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) for i in eachindex(x) x[i] ~ Normal(m, 1.0) end end model_def = demo4 data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 0.835503 seconds (3.93 M allocations: 244.654 MiB, 10.38% gc time, 9.43% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 60 samples with 1 evaluation. Range (min … max): 68.149 ms … 104.358 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 77.456 ms ┊ GC (median): 0.00% Time (mean ± σ): 80.173 ms ± 9.858 ms ┊ GC (mean ± σ): 6.67% ± 8.31 % ▆█ █▄ ▂▄ █▆██▁▁▄▁▁▁▁▁▁▁▁▁▁▆▆▁██▆▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆████▄▁▄▁▄ ▁ 68.1 ms Histogram: frequency by time 94.8 ms < Memory estimate: 42.78 MiB, allocs estimate: 1253404. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 145 samples with 1 evaluation. Range (min … max): 29.232 ms … 139.283 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 30.997 ms ┊ GC (median): 0.00% Time (mean ± σ): 32.506 ms ± 9.228 ms ┊ GC (mean ± σ): 0.23% ± 1.93 % ▁▆█▇▆▃▁ ▃▆███████▅▄▃▃▃▃▃▃▁▅▅▄▃▁▁▃▁▃▃▃▁▁▁▃▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃ 29.2 ms Histogram: frequency by time 46.4 ms < Memory estimate: 781.86 KiB, allocs estimate: 7. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` ```julia @model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) x .~ Normal(m, 1.0) end model_def = demo4_dotted data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 1.421197 seconds (5.08 M allocations: 375.131 MiB, 6.23% gc time, 0.62% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 39 samples with 1 evaluation. Range (min … max): 108.605 ms … 348.289 ms ┊ GC (min … max): 9.70% … 9. 23% Time (median): 118.470 ms ┊ GC (median): 15.38% Time (mean ± σ): 121.407 ms ± 37.585 ms ┊ GC (mean ± σ): 13.35% ± 3. 15% ▆ █ █▁█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁ 109 ms Histogram: frequency by time 348 ms < Memory estimate: 347.69 MiB, allocs estimate: 963583. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 61 samples with 1 evaluation. Range (min … max): 66.380 ms … 350.632 ms ┊ GC (min … max): 9.01% … 4.7 7% Time (median): 73.635 ms ┊ GC (median): 16.29% Time (mean ± σ): 75.751 ms ± 35.996 ms ┊ GC (mean ± σ): 12.78% ± 3.8 9% █ ▄ ▃ ▇█▆▆▄▄▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▆████▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁ 66.4 ms Histogram: frequency by time 84.5 ms < Memory estimate: 337.55 MiB, allocs estimate: 399306. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` </details>
1 parent 060a4d1 commit 7a8ba7e

13 files changed

+252
-189
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.15.2"
3+
version = "0.16.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -12,16 +12,18 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1718

1819
[compat]
1920
AbstractMCMC = "2, 3.0"
20-
AbstractPPL = "0.2"
21+
AbstractPPL = "0.3"
2122
BangBang = "0.3"
2223
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
2324
ChainRulesCore = "0.9.7, 0.10, 1"
2425
Distributions = "0.23.8, 0.24, 0.25"
2526
MacroTools = "0.5.6"
27+
Setfield = "0.7.1"
2628
ZygoteRules = "0.2"
2729
julia = "1.3"

benchmarks/benchmark_body.jmd

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@ m = time_model_def(model_def, data);
88

99
```julia
1010
suite = make_suite(m);
11-
results = run(suite)
12-
results
11+
results = run(suite);
12+
```
13+
14+
```julia
15+
results["evaluation_untyped"]
16+
```
17+
18+
```julia
19+
results["evaluation_typed"]
1320
```
1421

1522
```julia; echo=false; results="hidden";

benchmarks/benchmarks.jmd

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,37 @@ data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2);
9494
```julia; echo=false
9595
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
9696
```
97+
98+
### `demo4`: loads of indexing
99+
100+
```julia
101+
@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV}
102+
m ~ Normal()
103+
x = TV(undef, n)
104+
for i in eachindex(x)
105+
x[i] ~ Normal(m, 1.0)
106+
end
107+
end
108+
109+
model_def = demo4
110+
data = (100_000, );
111+
```
112+
113+
```julia; echo=false
114+
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
115+
```
116+
117+
```julia
118+
@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV}
119+
m ~ Normal()
120+
x = TV(undef, n)
121+
x .~ Normal(m, 1.0)
122+
end
123+
124+
model_def = demo4_dotted
125+
data = (100_000, );
126+
```
127+
128+
```julia; echo=false
129+
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
130+
```

src/DynamicPPL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using MacroTools: MacroTools
1111
using ZygoteRules: ZygoteRules
1212
using BangBang: BangBang
1313

14+
using Setfield: Setfield
15+
using BangBang: BangBang
16+
1417
using Random: Random
1518

1619
import Base:

src/compiler.jl

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr})
1818
vn = gensym(:vn)
1919

2020
return quote
21-
let $vn = $(varname(expr))
21+
let $vn = $(AbstractPPL.drop_escape(varname(expr)))
2222
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
2323
# Considered an assumption by `__context__` which means either:
2424
# 1. We hit the default implementation, e.g. using `DefaultContext`,
@@ -133,17 +133,17 @@ variables.
133133
134134
# Example
135135
```jldoctest; setup=:(using Distributions, LinearAlgebra)
136-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string(vns[end])
137-
"x[:,2]"
136+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
137+
x[:,2]
138138
139-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end])
140-
"x[:][1,2]"
139+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
140+
x[1,2]
141141
142-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); string(vns[end])
143-
"x[1][3]"
142+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end]
143+
x[:][1,2]
144144
145-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string(vns[end])
146-
"x[1,2,3]"
145+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end]
146+
x[1][3]
147147
```
148148
"""
149149
unwrap_right_left_vns(right, left, vns) = right, left, vns
@@ -158,7 +158,7 @@ function unwrap_right_left_vns(
158158
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
159159
# and we therefore add the `Colon()` below.
160160
vns = map(axes(left, 2)) do i
161-
return VarName(vn, (vn.indexing..., (Colon(), i)))
161+
return vn Setfield.IndexLens((Colon(), i))
162162
end
163163
return unwrap_right_left_vns(right, left, vns)
164164
end
@@ -168,7 +168,7 @@ function unwrap_right_left_vns(
168168
vn::VarName,
169169
)
170170
vns = map(CartesianIndices(left)) do i
171-
return VarName(vn, (vn.indexing..., Tuple(i)))
171+
return vn Setfield.IndexLens(Tuple(i))
172172
end
173173
return unwrap_right_left_vns(right, left, vns)
174174
end
@@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
317317
# Do not touch interpolated expressions
318318
expr.head === :$ && return expr.args[1]
319319

320+
# Do we don't want escaped expressions because we unfortunately
321+
# escape the entire body afterwards.
322+
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
323+
320324
# If it's a macro, we expand it
321325
if Meta.isexpr(expr, :macrocall)
322326
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
@@ -349,38 +353,36 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
349353
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
350354
end
351355

356+
function generate_tilde_literal(left, right)
357+
# If the LHS is a literal, it is always an observation
358+
return quote
359+
$(DynamicPPL.tilde_observe!)(
360+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
361+
)
362+
end
363+
end
364+
352365
"""
353366
generate_tilde(left, right)
354367
355368
Generate an `observe` expression for data variables and `assume` expression for parameter
356369
variables.
357370
"""
358371
function generate_tilde(left, right)
359-
# If the LHS is a literal, it is always an observation
360-
if isliteral(left)
361-
return quote
362-
$(DynamicPPL.tilde_observe!)(
363-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
364-
)
365-
end
366-
end
372+
isliteral(left) && return generate_tilde_literal(left, right)
367373

368374
# Otherwise it is determined by the model or its value,
369375
# if the LHS represents an observation
370-
@gensym vn inds isassumption
376+
@gensym vn isassumption
377+
378+
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379+
# that in DynamicPPL we the entire function body. Instead we should be
380+
# more selective with our escape. Until that's the case, we remove them all.
371381
return quote
372-
$vn = $(varname(left))
373-
$inds = $(vinds(left))
382+
$vn = $(AbstractPPL.drop_escape(varname(left)))
374383
$isassumption = $(DynamicPPL.isassumption(left))
375384
if $isassumption
376-
$left = $(DynamicPPL.tilde_assume!)(
377-
__context__,
378-
$(DynamicPPL.unwrap_right_vn)(
379-
$(DynamicPPL.check_tilde_rhs)($right), $vn
380-
)...,
381-
$inds,
382-
__varinfo__,
383-
)
385+
$(generate_tilde_assume(left, right, vn))
384386
else
385387
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
386388
if !$(DynamicPPL.inargnames)($vn, __model__)
@@ -392,44 +394,46 @@ function generate_tilde(left, right)
392394
$(DynamicPPL.check_tilde_rhs)($right),
393395
$(maybe_view(left)),
394396
$vn,
395-
$inds,
396397
__varinfo__,
397398
)
398399
end
399400
end
400401
end
401402

403+
function generate_tilde_assume(left, right, vn)
404+
expr = :(
405+
$left = $(DynamicPPL.tilde_assume!)(
406+
__context__,
407+
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
408+
__varinfo__,
409+
)
410+
)
411+
412+
return if left isa Expr
413+
AbstractPPL.drop_escape(
414+
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
415+
)
416+
else
417+
return expr
418+
end
419+
end
420+
402421
"""
403422
generate_dot_tilde(left, right)
404423
405424
Generate the expression that replaces `left .~ right` in the model body.
406425
"""
407426
function generate_dot_tilde(left, right)
408-
# If the LHS is a literal, it is always an observation
409-
if isliteral(left)
410-
return quote
411-
$(DynamicPPL.dot_tilde_observe!)(
412-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
413-
)
414-
end
415-
end
427+
isliteral(left) && return generate_tilde_literal(left, right)
416428

417429
# Otherwise it is determined by the model or its value,
418430
# if the LHS represents an observation
419-
@gensym vn inds isassumption
431+
@gensym vn isassumption
420432
return quote
421-
$vn = $(varname(left))
422-
$inds = $(vinds(left))
433+
$vn = $(AbstractPPL.drop_escape(varname(left)))
423434
$isassumption = $(DynamicPPL.isassumption(left))
424435
if $isassumption
425-
$left .= $(DynamicPPL.dot_tilde_assume!)(
426-
__context__,
427-
$(DynamicPPL.unwrap_right_left_vns)(
428-
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
429-
)...,
430-
$inds,
431-
__varinfo__,
432-
)
436+
$(generate_dot_tilde_assume(left, right, vn))
433437
else
434438
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
435439
if !$(DynamicPPL.inargnames)($vn, __model__)
@@ -441,13 +445,27 @@ function generate_dot_tilde(left, right)
441445
$(DynamicPPL.check_tilde_rhs)($right),
442446
$(maybe_view(left)),
443447
$vn,
444-
$inds,
445448
__varinfo__,
446449
)
447450
end
448451
end
449452
end
450453

454+
function generate_dot_tilde_assume(left, right, vn)
455+
# We don't need to use `Setfield.@set` here since
456+
# `.=` is always going to be inplace + needs `left` to
457+
# be something that supports `.=`.
458+
return :(
459+
$left .= $(DynamicPPL.dot_tilde_assume!)(
460+
__context__,
461+
$(DynamicPPL.unwrap_right_left_vns)(
462+
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
463+
)...,
464+
__varinfo__,
465+
)
466+
)
467+
end
468+
451469
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
452470
hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA)
453471
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true

0 commit comments

Comments
 (0)