Skip to content

Commit bcb52e0

Browse files
committed
Add Random.AbstractRNG type annotations (fixing dot_tilde_assume ambiguity)
1 parent 653c9c5 commit bcb52e0

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/context_implementations.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi)
195195
return left, acclogp_observe!!(context, vi, logp)
196196
end
197197

198-
function assume(rng, spl::Sampler, dist)
198+
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
199199
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
200200
end
201201

@@ -291,14 +291,18 @@ end
291291
function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi)
292292
return dot_assume(right, left, vns, vi)
293293
end
294-
function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi)
294+
function dot_tilde_assume(
295+
::IsLeaf, rng::Random.AbstractRNG, ::AbstractContext, sampler, right, left, vns, vi
296+
)
295297
return dot_assume(rng, sampler, right, vns, left, vi)
296298
end
297299

298300
function dot_tilde_assume(::IsParent, context::AbstractContext, args...)
299301
return dot_tilde_assume(childcontext(context), args...)
300302
end
301-
function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...)
303+
function dot_tilde_assume(
304+
::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args...
305+
)
302306
return dot_tilde_assume(rng, childcontext(context), args...)
303307
end
304308

@@ -371,7 +375,7 @@ function dot_assume(
371375
end
372376

373377
function dot_assume(
374-
rng,
378+
rng::Random.AbstractRNG,
375379
spl::Union{SampleFromPrior,SampleFromUniform},
376380
dist::MultivariateDistribution,
377381
vns::AbstractVector{<:VarName},
@@ -404,7 +408,7 @@ function dot_assume(
404408
end
405409

406410
function dot_assume(
407-
rng,
411+
rng::Random.AbstractRNG,
408412
spl::Union{SampleFromPrior,SampleFromUniform},
409413
dists::Union{Distribution,AbstractArray{<:Distribution}},
410414
vns::AbstractArray{<:VarName},
@@ -416,7 +420,9 @@ function dot_assume(
416420
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
417421
return r, lp, vi
418422
end
419-
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
423+
function dot_assume(
424+
rng::Random.AbstractRNG, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any
425+
)
420426
return error(
421427
"[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement"
422428
)
@@ -436,7 +442,7 @@ function _maybe_invlink_broadcast(vi, vn, dist)
436442
end
437443

438444
function get_and_set_val!(
439-
rng,
445+
rng::Random.AbstractRNG,
440446
vi::VarInfoOrThreadSafeVarInfo,
441447
vns::AbstractVector{<:VarName},
442448
dist::MultivariateDistribution,
@@ -478,7 +484,7 @@ function get_and_set_val!(
478484
end
479485

480486
function get_and_set_val!(
481-
rng,
487+
rng::Random.AbstractRNG,
482488
vi::VarInfoOrThreadSafeVarInfo,
483489
vns::AbstractArray{<:VarName},
484490
dists::Union{Distribution,AbstractArray{<:Distribution}},

0 commit comments

Comments
 (0)