Remove eltype, matchingvalue, get_matching_type #1015


Draft · wants to merge 5 commits into `main`
Conversation

@penelopeysm (Member) commented Aug 7, 2025

Half-hearted attempt. I'd like to see what breaks and why.

github-actions bot (Contributor) commented Aug 7, 2025

Benchmark Report for Commit 4a29a2a

Computer Information

```
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
```

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.6 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                683.7 |                40.8 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                485.1 |                46.8 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1298.1 |                27.8 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6582.8 |                29.8 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1486.1 |                28.5 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1041.8 |                 4.1 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5677.6 |                 4.2 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                977.6 |                 9.1 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              64636.8 |                 3.6 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8596.1 |                10.0 |
|               Dynamic |        10 |    mooncake |             typed |   true |                137.5 |                11.9 |
|              Submodel |         1 |    mooncake |             typed |   true |                 13.1 |                 5.6 |
|                   LDA |        12 | reversediff |             typed |   true |               1191.3 |                 3.8 |

```
@@ -707,53 +707,3 @@ function warn_empty(body)
end
return nothing
end

# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
```
@penelopeysm (Member Author) commented Aug 7, 2025

It's because `matchingvalue` gets called on all the model function's arguments, and types can be arguments to the model as well, e.g.

```julia
@model function f(x, T) ... end
model = f(1.0, Float64)
```

Comment on lines -722 to -723:

```julia
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
# are happy to return `value` as-is?
```
@penelopeysm (Member Author) commented Aug 7, 2025

This change was made in #191; the motivation is in TuringLang/Turing.jl#1464 (comment).

This has to do with some subtle mutation behaviour. For example:

```julia
@model function f(x)
    x[1] ~ Normal()
end
```

If `model = f([1.0])`, the tilde statement is an observe, so even if you reassign to `x[1]` it doesn't change the value of `x`. This is the `!hasmissing` branch, and since overwriting is a no-op, we don't need to deepcopy.

If `model = f([missing])`, the tilde statement is now an assume: running the model samples a new value for `x[1]` and sets it in `x`. If you then rerun the model, `x[1]` is no longer missing. This is the case where the deepcopy is triggered.
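A minimal plain-Julia sketch of the aliasing problem the deepcopy guards against (no DynamicPPL here; `evaluate!` is a hypothetical stand-in for model evaluation):

```julia
# Hypothetical stand-in for model evaluation: an assume-statement fills in
# any `missing` entry of the argument with a freshly "sampled" value.
function evaluate!(x)
    if ismissing(x[1])
        x[1] = 0.42   # stand-in for rand(Normal())
    end
    return x
end

x = Union{Missing,Float64}[missing]
evaluate!(x)
ismissing(x[1])           # false: the caller's array was mutated, so a
                          # second run would no longer see `missing`

y = Union{Missing,Float64}[missing]
evaluate!(deepcopy(y))    # deepcopy first, as matchingvalue used to do
ismissing(y[1])           # true: the caller's array is untouched
```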

@penelopeysm (Member Author) commented Aug 7, 2025

So apart from the deepcopy to avoid aliasing, the other place where `matchingvalue` does something meaningful is:

```julia
@model function f(y, ::Type{T}=Float64) where {T}
    x = Vector{T}(undef, length(y))
    for i in eachindex(y)
        x[i] ~ Normal()
        y[i] ~ Normal(x[i])
    end
end
model = f([1.0])
```

If you just evaluate this normally with floats, it's all good; nothing special needs to happen.

If you evaluate this with ReverseDiff, then things need to change. Specifically:

  1. `x` needs to become a vector of `TrackedReal`s rather than a vector of `Float64`s.
  2. To accomplish this, the *argument* to the model needs to change: even though `T` *seems* to be specified as `Float64`, `matchingvalue` in fact hijacks it and turns it into `TrackedReal` when calling `model()`.
  3. How does `matchingvalue` know that it needs to become a `TrackedReal`? Simple: when you call `logdensity_and_gradient`, it calls `unflatten` to set the parameters (which will be `TrackedReal`s) in the varinfo. `matchingvalue` then looks inside the varinfo to see if it contains `TrackedReal`s! Hence `eltype(vi)` 🙃

It actually gets a bit more complicated. When you define the model, the `@model` macro already hijacks `T` to turn it into `TypeWrap{Float64}()`, and when you actually evaluate the model, `matchingvalue` hijacks it even further to turn it into `TypeWrap{TrackedReal}()`. I'm not sure why `TypeWrap` is needed, but apparently it's something to do with avoiding `DataType`.
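The `eltype`-based hijacking in step 3 above can be sketched in plain Julia. This is a simplified illustration, not DynamicPPL's actual implementation; `FakeVarInfo` and `matchingvalue_sketch` are made-up names:

```julia
# A toy varinfo that just records its element type.
struct FakeVarInfo{T}
    vals::Vector{T}
end
Base.eltype(::FakeVarInfo{T}) where {T} = T

# Value arguments pass through unchanged; ::Type arguments are swapped for
# the element type currently stored in the varinfo, so that
# `Vector{T}(undef, n)` allocates a container AD can write into.
matchingvalue_sketch(vi, value) = value
matchingvalue_sketch(vi, ::Type{T}) where {T<:Real} = eltype(vi)

vi_plain = FakeVarInfo([1.0, 2.0])
matchingvalue_sketch(vi_plain, Float64)    # Float64: nothing to do

# During a ReverseDiff gradient call the varinfo holds TrackedReals; an
# Int-valued varinfo stands in for that here, and the Float64 type argument
# comes back as the varinfo's element type instead.
vi_tracked = FakeVarInfo([1, 2])
matchingvalue_sketch(vi_tracked, Float64)  # Int
```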

ForwardDiff actually works just fine on this PR. I don't know why, but I also remember a talk I gave where we were surprised that ForwardDiff NUTS worked fine without the special `::Type{T}=Float64` machinery, so this observation is consistent with that.

So this whole thing pretty much only exists to make ReverseDiff happy.

To get around this, I propose that we drop compatibility with ReverseDiff.

@penelopeysm (Member Author) commented Aug 7, 2025

Actually, for most models, ForwardDiff and ReverseDiff still work because of this special nice behaviour:

```julia
julia> x = Float64[1.0, 2.0]
2-element Vector{Float64}:
 1.0
 2.0

julia> x[1] = ForwardDiff.Dual(3.0) # x[1] ~ dist doesn't do this
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 0})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

julia> x = Accessors.set(x, (@optic _[1]), ForwardDiff.Dual(3.0)) # x[1] ~ dist actually does this!
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 0}}:
 Dual{Nothing}(3.0)
 Dual{Nothing}(2.0)
```

There is only one erroring test in CI, which happens because the model explicitly includes the assignment `x[i] = ...` rather than a tilde statement `x[i] ~ ...`. Changing the assignment to use `Accessors.set` makes it work just fine.
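The non-mutating, widening update that `Accessors.set` provides can be sketched without Accessors at all. `widening_set` below is a hypothetical helper, not DynamicPPL or Accessors API, and `Complex` stands in for a `Dual`/`TrackedReal` element type:

```julia
# Instead of `x[i] = v`, which errors when eltype(x) cannot hold v, build a
# new array whose element type is widened to accommodate the new value.
function widening_set(x::AbstractVector, i::Integer, v)
    T = promote_type(eltype(x), typeof(v))
    y = Vector{T}(x)   # copy with the widened element type
    y[i] = v
    return y
end

x = Float64[1.0, 2.0]
# Complex stands in for ForwardDiff.Dual / ReverseDiff.TrackedReal here
y = widening_set(x, 1, 3.0 + 0im)
eltype(y)   # ComplexF64: the whole container was widened, as in the REPL above
x           # the original Float64 array is untouched
```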

BUT there are correctness issues with ReverseDiff (not errors), and I have no clue where those stem from. Really interestingly, it's only a problem for one of the demo models, not any of the others, even though many of them use the `Type{T}` syntax.

src/model.jl (Outdated) · Comment on lines 936 to 938:

```diff
-:($matchingvalue(varinfo, model.args.$var)...)
+:(deepcopy(model.args.$var)...)
 else
-:($matchingvalue(varinfo, model.args.$var))
+:(deepcopy(model.args.$var))
```
@penelopeysm (Member Author) commented Aug 7, 2025

So `matchingvalue` used to deepcopy things sometimes. Right now I work around this by indiscriminately deepcopying. This is a Bad Thing, and we should definitely have more careful rules about when something needs to be deepcopied. However, I don't believe that such rules need to use the whole `matching_type` machinery.

@penelopeysm (Member Author):

Indiscriminately deepcopying here breaks ReverseDiff. See comment below: #1015 (comment)

codecov bot commented Aug 7, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 80.82%. Comparing base (ea6b6de) to head (4a29a2a).

Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1015      +/-   ##
==========================================
- Coverage   82.16%   80.82%   -1.34%
==========================================
  Files          38       38
  Lines        3935     3891      -44
==========================================
- Hits         3233     3145      -88
- Misses        702      746      +44
```


coveralls commented Aug 7, 2025

Pull Request Test Coverage Report for Build 16818637635

Details

  • 3 of 3 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-1.6%) to 80.828%

Totals (Coverage Status):
  • Change from base Build 16810321891: -1.6%
  • Covered Lines: 3145
  • Relevant Lines: 3891

💛 - Coveralls

github-actions bot commented Aug 7, 2025

DynamicPPL.jl documentation for PR #1015 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1015/


penelopeysm commented Aug 7, 2025

ReverseDiff correctness issue with this PR:

```julia
using DynamicPPL, Distributions, FiniteDifferences, ReverseDiff, ADTypes, LinearAlgebra, Random
using DynamicPPL.TestUtils.AD: run_ad, WithBackend

@model function inner(m, x)
    @show m
    return x ~ Normal(m[1])
end
@model function outer(x)
    # m has to be vector-valued for it to fail
    m ~ MvNormal(zeros(1), I)
    # If you use this line it works
    # x ~ Normal(m[1])
    # This line is seemingly equivalent but fails
    t ~ to_submodel(inner(m, x))
end
model = outer(1.5)

run_ad(
    model,
    AutoReverseDiff();
    test=WithBackend(AutoFiniteDifferences(fdm=central_fdm(5, 1))),
    rng=Xoshiro(468),
);
```


penelopeysm commented Aug 7, 2025

Ironically, removing the deepcopy makes ReverseDiff work correctly on the above. The reason, I think, is that ReverseDiff expects to mutate the arguments to the function and then pick up the derivatives by inspecting how the arguments were mutated; deepcopying the argument broke that promise. So we have the slightly weird scenario where `ReverseDiff.TrackedArray` arguments must not be deepcopied, but alias-able arguments with `missing` in them must be deepcopied. With all of that implemented, it seems like we should be able to get rid of all of this, although we would also need a fix for #823.
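Why deepcopying breaks a tape-based reverse-mode AD can be illustrated with a plain-Julia sketch of identity-based tracking. The `Tracked` type and `track!` helper below are toys, not ReverseDiff's actual types:

```julia
# Reverse-mode AD records tracked values on a shared tape and, after the
# backward pass, reads gradients back from the *same* objects the caller
# passed in. A deepcopy creates fresh objects the tape has never seen.
mutable struct Tracked
    value::Float64
    grad::Float64
end

tape = Tracked[]
track!(v) = (t = Tracked(v, 0.0); push!(tape, t); t)

x = [track!(1.0), track!(2.0)]
tape[1] === x[1]   # true: gradients written via the tape are visible
                   # through the caller's argument

y = deepcopy(x)
tape[1] === y[1]   # false: a deepcopied argument is disconnected from the
                   # tape, so gradients written to it would never be seen
```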

@mhauru (Member) commented Aug 8, 2025

If this doesn't get merged, then I would at least like to have the lessons you learned here recorded somewhere if possible.

@penelopeysm (Member Author):
Yupp, definitely. This PR is basically me liveblogging as I find things out 😄 but at the very least, I'm sure we could improve those docstrings / add comments.

3 participants