Skip to content

Commit 6b3441f

Browse files
committed
More fixes
1 parent b9724bc commit 6b3441f

File tree

3 files changed

+9
-16
lines changed

3 files changed

+9
-16
lines changed

src/threadsafe.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ getmetadata(vi::ThreadSafeVarInfo, vn::VarName) = getmetadata(vi.varinfo, vn)
5656
getidx(vi::ThreadSafeVarInfo, vn::VarName) = getidx(vi.varinfo, vn)
5757
getrange(vi::ThreadSafeVarInfo, vn::VarName) = getrange(vi.varinfo, vn)
5858
getdist(vi::ThreadSafeVarInfo, vn::VarName) = getdist(vi.varinfo, vn)
59+
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
5960

6061
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
6162
setgid!(vi.varinfo, gid, vn)
6263
end
64+
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
6365

6466
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
6567
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)

test/Turing/inference/Inference.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,42 +473,42 @@ floatof(::Type) = Real # fallback if type inference failed
473473

474474
function get_matching_type(
475475
spl::AbstractSampler,
476-
vi::VarInfo,
476+
vi,
477477
::Type{T},
478478
) where {T}
479479
return T
480480
end
481481
function get_matching_type(
482482
spl::AbstractSampler,
483-
vi::VarInfo,
483+
vi,
484484
::Type{<:AbstractFloat},
485485
)
486486
return floatof(eltype(vi, spl))
487487
end
488488
function get_matching_type(
489489
spl::Sampler{<:Hamiltonian},
490-
vi::VarInfo,
490+
vi,
491491
::Type{<:Union{Missing, AbstractFloat}},
492492
)
493493
return Union{Missing, floatof(eltype(vi, spl))}
494494
end
495495
function get_matching_type(
496496
spl::Sampler{<:Hamiltonian},
497-
vi::VarInfo,
497+
vi,
498498
::Type{<:AbstractFloat},
499499
)
500500
return floatof(eltype(vi, spl))
501501
end
502502
function get_matching_type(
503503
spl::Sampler{<:Hamiltonian},
504-
vi::VarInfo,
504+
vi,
505505
::Type{TV},
506506
) where {T, N, TV <: Array{T, N}}
507507
return Array{get_matching_type(spl, vi, T), N}
508508
end
509509
function get_matching_type(
510510
spl::Sampler{<:Union{PG, SMC}},
511-
vi::VarInfo,
511+
vi,
512512
::Type{TV},
513513
) where {T, N, TV <: Array{T, N}}
514514
return TArray{T, N}

test/compiler.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,19 +222,10 @@ end
222222
varinfo = DynamicPPL.VarInfo(model)
223223
model(varinfo)
224224
@test getlogp(varinfo) == lp
225-
@test varinfo_ === varinfo
225+
@test varinfo_ isa AbstractVarInfo
226226
@test model_ === model
227227
@test sampler_ === SampleFromPrior()
228228
@test context_ === DefaultContext()
229-
@test length(logps_) == Threads.nthreads()
230-
@test sum(logps_) == lp
231-
for i in 1:length(logps_)
232-
if i == Threads.threadid()
233-
@test logps_[i] == lp
234-
else
235-
@test iszero(logps_[i])
236-
end
237-
end
238229

239230
# test DPPL#61
240231
@model testmodel(z) = begin

0 commit comments

Comments
 (0)