
Conversation


@mhauru mhauru commented Nov 20, 2025

I decided that rather than take over VarInfo like in #1074, the first use case of VarNamedTuple should be replacing the NamedTuple/Dict combo in FastLDF. That's what this PR does.

This is still work in progress:

  • Documentation is lacking/out of date
  • There's dead code, and unnecessarily complex code
  • Performance on Julia v1.11 needs fixing
  • There's type piracy
  • This doesn't handle Colons in VarNames.

However, tests seem to pass, so I'm putting this up. I ran the familiar FastLDF benchmarks from #1132, adapted a bit. Source code:

module VNTBench

using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra
using ADTypes, ForwardDiff, ReverseDiff
@static if VERSION < v"1.12"
    using Enzyme, Mooncake
end

const adtypes = @static if VERSION < v"1.12"
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
        ("MC", AutoMooncake()),
        ("EN" => AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const))
    ]
else
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
    ]
end

function benchmark_ldfs(model; skip=Union{})
    vi = VarInfo(model)
    x = vi[:]
    ldf_no = DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
    fldf_no = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)
    @assert LogDensityProblems.logdensity(ldf_no, x) ≈ LogDensityProblems.logdensity(fldf_no, x)
    median_new = median(@be LogDensityProblems.logdensity(fldf_no, x))
    print("           FastLDF: eval      ----  ")
    display(median_new)
    for name_adtype in adtypes
        name, adtype = name_adtype
        adtype isa skip && continue
        ldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        ldf_grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, x)
        @assert ldf_grad[2] ≈ fldf_grad[2]
        median_new = median(@be LogDensityProblems.logdensity_and_gradient(fldf, x))
        print("           FastLDF: grad ($name) ----  ")
        display(median_new)
    end
end

println("Trivial model")
@model f() = x ~ Normal()
benchmark_ldfs(f())

println("Eight schools")
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
benchmark_ldfs(eight_schools(y, sigma))

println("IndexLenses, dim=1_000")
@model function badvarnames()
    N = 1_000
    x = Vector{Float64}(undef, N)
    for i in 1:N
        x[i] ~ Normal()
    end
end
benchmark_ldfs(badvarnames())

println("Submodel")
@model function inner()
    m ~ Normal(0, 1)
    s ~ Exponential()
    return (m=m, s=s)
end
@model function withsubmodel()
    params ~ to_submodel(inner())
    y ~ Normal(params.m, params.s)
    1.0 ~ Normal(y)
end
benchmark_ldfs(withsubmodel())

end

Results on Julia v1.12:

On base (breaking):
Trivial model
           FastLDF: eval      ----  18.047 ns
           FastLDF: grad (FD) ----  51.805 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.157 μs (45 allocs: 1.531 KiB)
Eight schools
           FastLDF: eval      ----  165.723 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  685.846 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.959 μs (562 allocs: 20.562 KiB)
IndexLenses, dim=1_000
           FastLDF: eval      ----  24.250 μs (14 allocs: 8.312 KiB)
           FastLDF: grad (FD) ----  6.296 ms (1516 allocs: 11.197 MiB)
           FastLDF: grad (RD) ----  2.577 ms (38029 allocs: 1.321 MiB)
Submodel
           FastLDF: eval      ----  57.568 ns
           FastLDF: grad (FD) ----  179.448 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.750 μs (145 allocs: 5.062 KiB)

On this branch:
Trivial model
           FastLDF: eval      ----  11.869 ns
           FastLDF: grad (FD) ----  53.264 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.273 μs (45 allocs: 1.531 KiB)
Eight schools
           FastLDF: eval      ----  203.159 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  718.750 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.792 μs (562 allocs: 20.562 KiB)
IndexLenses, dim=1_000
           FastLDF: eval      ----  9.181 μs (2 allocs: 8.031 KiB)
           FastLDF: grad (FD) ----  4.235 ms (508 allocs: 11.174 MiB)
           FastLDF: grad (RD) ----  2.560 ms (38017 allocs: 1.321 MiB)
Submodel
           FastLDF: eval      ----  49.660 ns
           FastLDF: grad (FD) ----  221.359 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.667 μs (148 allocs: 5.219 KiB)

Same thing but in Julia v1.11:

On base (breaking):
Trivial model
           FastLDF: eval      ----  11.082 ns
           FastLDF: grad (FD) ----  53.747 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.069 μs (46 allocs: 1.562 KiB)
           FastLDF: grad (MC) ----  221.910 ns (2 allocs: 64 bytes)
           FastLDF: grad (EN) ----  128.970 ns (2 allocs: 64 bytes)
Eight schools
           FastLDF: eval      ----  164.326 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  690.049 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.250 μs (562 allocs: 20.562 KiB)
           FastLDF: grad (MC) ----  1.082 μs (10 allocs: 656 bytes)
           FastLDF: grad (EN) ----  733.325 ns (13 allocs: 832 bytes)
IndexLenses, dim=1_000
           FastLDF: eval      ----  33.458 μs (15 allocs: 8.344 KiB)
           FastLDF: grad (FD) ----  6.652 ms (1516 allocs: 11.197 MiB)
           FastLDF: grad (RD) ----  2.488 ms (38028 allocs: 1.321 MiB)
           FastLDF: grad (MC) ----  89.583 μs (21 allocs: 24.469 KiB)
           FastLDF: grad (EN) ----  92.833 μs (20 allocs: 102.531 KiB)
Submodel
           FastLDF: eval      ----  70.884 ns
           FastLDF: grad (FD) ----  135.958 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.563 μs (148 allocs: 5.188 KiB)
           FastLDF: grad (MC) ----  481.250 ns (2 allocs: 80 bytes)
           FastLDF: grad (EN) ----  344.612 ns (2 allocs: 80 bytes)

On this branch:
Trivial model
           FastLDF: eval      ----  1.309 μs (27 allocs: 800 bytes)
           FastLDF: grad (FD) ----  1.522 μs (30 allocs: 960 bytes)
           FastLDF: grad (RD) ----  4.667 μs (71 allocs: 2.344 KiB)
           FastLDF: grad (MC) ----  358.143 ns (7 allocs: 224 bytes)
           FastLDF: grad (EN) ----  130.768 ns (2 allocs: 64 bytes)
Eight schools
           FastLDF: eval      ----  164.326 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  645.378 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.541 μs (562 allocs: 20.562 KiB)
           FastLDF: grad (MC) ----  1.043 μs (10 allocs: 656 bytes)
           FastLDF: grad (EN) ----  747.925 ns (13 allocs: 832 bytes)
IndexLenses, dim=1_000
           FastLDF: eval      ----  9.430 μs (3 allocs: 8.062 KiB)
           FastLDF: grad (FD) ----  4.616 ms (508 allocs: 11.174 MiB)
           FastLDF: grad (RD) ----  2.467 ms (38016 allocs: 1.321 MiB)
           FastLDF: grad (MC) ----  73.292 μs (9 allocs: 24.188 KiB)
           FastLDF: grad (EN) ----  72.875 μs (8 allocs: 102.250 KiB)
Submodel
           FastLDF: eval      ----  52.213 ns
           FastLDF: grad (FD) ----  107.166 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.521 μs (142 allocs: 5.078 KiB)
           FastLDF: grad (MC) ----  453.493 ns (2 allocs: 80 bytes)
           FastLDF: grad (EN) ----  320.367 ns (2 allocs: 80 bytes)

So on 1.12 all looks good: this branch is a bit faster than the old version, and substantially faster when there are a lot of IndexLenses, as it should be. On 1.11 performance is destroyed, probably because type inference fails or gives up, and I need to fix that.

The main point of this PR is not performance, but having a general data structure for storing information keyed by VarNames, so I'm happy as long as performance doesn't degrade. Next up would be using this same data structure for ConditionContext (hoping to fix #1148), ValuesAsInModelAcc, maybe some other Accumulators, InitFromParams, GibbsContext, and finally to implement an AbstractVarInfo type.

I'll update the docs page with more information about the current design I've implemented, but the one-sentence summary is: nested NamedTuples, and whenever we meet IndexLenses, an Array for the values together with a mask Array that marks which values are valid and which are just placeholders.
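To make the mask-array idea concrete, here's a minimal stand-in (an illustrative sketch only, not the actual VarNamedTuple implementation; all names below are made up):

```julia
# Values reached through an IndexLens live in an Array, alongside a Bool
# mask marking which slots hold real values vs. uninitialized placeholders.
struct MaskedValues{T}
    values::Vector{T}
    mask::Vector{Bool}   # true where a value has actually been set
end

MaskedValues{T}(n::Int) where {T} = MaskedValues{T}(Vector{T}(undef, n), fill(false, n))

function setvalue!(mv::MaskedValues, v, i::Int)
    mv.values[i] = v
    mv.mask[i] = true
    return mv
end

getvalue(mv::MaskedValues, i::Int) =
    mv.mask[i] ? mv.values[i] : error("index $i holds only a placeholder")

# Nested structure keyed by symbol, with masked arrays at the leaves:
vnt = (x = MaskedValues{Float64}(4),)
setvalue!(vnt.x, 1.0, 2)
getvalue(vnt.x, 2)  # 1.0
# getvalue(vnt.x, 3) would error: that slot was never set
```

The point is that reading `x[3]` can be rejected as invalid even though the backing Array has a slot for it.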

I think I know how to fix all the current shortcomings, except for Colons in VarNames. Setting a value in a VNT with a Colon could be done, but getting seems ill-defined, at least without providing further information about the size the value should be.

vnt = VarNamedTuple()
vnt = setindex!!(vnt, 1, @varname(x[2]))
vnt = setindex!!(vnt, 1, @varname(x[4]))
getindex(vnt, @varname(x[:]))  # What should this return?

cc @penelopeysm, though this isn't ready for reviews yet.


github-actions bot commented Nov 20, 2025

Benchmark Report

  • this PR's head: 871eb9fd1216f392460462d4c84d8a38ca89da05
  • base branch: 3cd8d3431e14ebc581266c1323d1db8a5bd4c0eb

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬─────────────────────────────────┬───────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │        t(eval) / t(ref)         │     t(grad) / t(eval)     │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ──────────┬───────────┬──────── │ ──────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │      base │   this PR │ speedup │  base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │    419.06 │    467.49 │    0.90 │ 10.38 │    9.56 │    1.09 │   4351.37 │   4471.40 │    0.97 │
│                   LDA │    12 │ reversediff │             typed │   true │   2788.41 │   2922.53 │    0.95 │  2.06 │    2.24 │    0.92 │   5750.78 │   6548.58 │    0.88 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 151999.15 │ 157531.68 │    0.96 │  5.69 │    5.73 │    0.99 │ 864397.02 │ 902772.36 │    0.96 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  13750.41 │  15822.22 │    0.87 │  6.21 │    5.13 │    1.21 │  85411.71 │  81201.51 │    1.05 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │  31962.62 │  32144.23 │    0.99 │  9.67 │    9.69 │    1.00 │ 309017.17 │ 311528.73 │    0.99 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │   3638.13 │   3976.84 │    0.91 │  8.60 │    8.05 │    1.07 │  31273.26 │  32003.52 │    0.98 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     17.12 │     17.62 │    0.97 │  1.86 │    1.88 │    0.99 │     31.84 │     33.13 │    0.96 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │   2488.23 │   2548.51 │    0.98 │ 47.31 │   47.29 │    1.00 │ 117730.52 │ 120507.17 │    0.98 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │  23606.95 │  24001.75 │    0.98 │ 25.39 │   26.03 │    0.98 │ 599406.03 │ 624744.27 │    0.96 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │   1033.53 │   1071.08 │    0.96 │ 78.19 │   79.13 │    0.99 │  80809.38 │  84757.50 │    0.95 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │   2543.38 │   2657.70 │    0.96 │  4.74 │    4.84 │    0.98 │  12065.44 │  12875.64 │    0.94 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │   2544.68 │   2675.62 │    0.95 │  5.95 │    5.52 │    1.08 │  15143.96 │  14761.11 │    1.03 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │   2614.44 │   2681.75 │    0.97 │ 55.04 │   55.60 │    0.99 │ 143900.24 │ 149094.20 │    0.97 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │   2569.43 │   2650.38 │    0.97 │ 44.93 │   45.36 │    0.99 │ 115446.57 │ 120227.33 │    0.96 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │   2299.79 │   2340.44 │    0.98 │ 44.97 │   47.35 │    0.95 │ 103421.01 │ 110828.14 │    0.93 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │   2283.87 │   2367.33 │    0.96 │ 45.88 │   45.63 │    1.01 │ 104779.53 │ 108020.70 │    0.97 │
│              Submodel │     1 │    mooncake │             typed │   true │     25.83 │     26.03 │    0.99 │  5.24 │    5.35 │    0.98 │    135.32 │    139.36 │    0.97 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴───────────┴───────────┴─────────┴───────┴─────────┴─────────┴───────────┴───────────┴─────────┘


codecov bot commented Nov 20, 2025

Codecov Report

❌ Patch coverage is 32.31939% with 178 lines in your changes missing coverage. Please review.
✅ Project coverage is 78.10%. Comparing base (3cd8d34) to head (871eb9f).

Files with missing lines Patch % Lines
src/varnamedtuple.jl 27.75% 177 Missing ⚠️
src/utils.jl 75.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1150      +/-   ##
============================================
- Coverage     81.67%   78.10%   -3.58%     
============================================
  Files            42       43       +1     
  Lines          3930     4165     +235     
============================================
+ Hits           3210     3253      +43     
- Misses          720      912     +192     



penelopeysm commented Nov 20, 2025

It looks to me like the 1.11 perf regression is confined to the trivial model. In my experience (I ran into this exact issue with Enzyme once; see also https://github.com/TuringLang/DynamicPPL.jl/pull/877/files), trivial models with a single variable can be quite susceptible to changes in inlining strategy. It may be that a judicious @inline or @noinline somewhere will fix this.
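For concreteness, a toy sketch of the kind of annotation placement being suggested (the function names are made up, not DynamicPPL internals):

```julia
# @noinline keeps the accumulation helper as its own compiled function,
# which can change how the compiler specializes the caller; @inline would
# push it the other way. Which one helps is model- and version-dependent.
@noinline accumulate_logp(acc, logp) = acc + logp

function toy_eval(xs)
    acc = 0.0
    for x in xs
        acc = accumulate_logp(acc, x)
    end
    return acc
end

toy_eval([1.0, 2.0, 3.0])  # 6.0
```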
