Skip to content

Commit f9adfeb

Browse files
fix partial missing bug (#191)
This PR fixes TuringLang/Turing.jl#1464. Co-authored-by: David Widmann <[email protected]>
1 parent 89e481e commit f9adfeb

File tree

4 files changed

+55
-36
lines changed

4 files changed

+55
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.1"
3+
version = "0.10.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,16 +369,47 @@ Convert the `value` to the correct type for the `sampler` and the `vi` object.
369369
function matchingvalue(sampler, vi, value)
370370
T = typeof(value)
371371
if hasmissing(T)
372-
return convert(get_matching_type(sampler, vi, T), value)
372+
_value = convert(get_matching_type(sampler, vi, T), value)
373+
if _value === value
374+
return deepcopy(_value)
375+
else
376+
return _value
377+
end
373378
else
374379
return value
375380
end
376381
end
377382
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
378383

379384
"""
380-
get_matching_type(spl, vi, ::Type{T}) where {T}
381-
Get the specialized version of type `T` for sampler `spl`. For example,
382-
if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`.
385+
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
386+
387+
Get the specialized version of type `T` for sampler `spl`.
388+
389+
For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is
390+
`eltype(vi[spl])`.
383391
"""
384-
function get_matching_type end
392+
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T
393+
function get_matching_type(
394+
spl::AbstractSampler,
395+
vi,
396+
::Type{<:Union{Missing, AbstractFloat}},
397+
)
398+
return Union{Missing, floatof(eltype(vi, spl))}
399+
end
400+
function get_matching_type(
401+
spl::AbstractSampler,
402+
vi,
403+
::Type{<:AbstractFloat},
404+
)
405+
return floatof(eltype(vi, spl))
406+
end
407+
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
408+
return Array{get_matching_type(spl, vi, T), N}
409+
end
410+
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T
411+
return Array{get_matching_type(spl, vi, T)}
412+
end
413+
414+
floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T))
415+
floatof(::Type) = Real # fallback if type inference failed

test/Turing/inference/Inference.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -424,36 +424,6 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
424424
@eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space
425425
end
426426

427-
floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T))
428-
floatof(::Type) = Real # fallback if type inference failed
429-
430-
function get_matching_type(
431-
spl::AbstractSampler,
432-
vi,
433-
::Type{T},
434-
) where {T}
435-
return T
436-
end
437-
function get_matching_type(
438-
spl::AbstractSampler,
439-
vi,
440-
::Type{<:Union{Missing, AbstractFloat}},
441-
)
442-
return Union{Missing, floatof(eltype(vi, spl))}
443-
end
444-
function get_matching_type(
445-
spl::AbstractSampler,
446-
vi,
447-
::Type{<:AbstractFloat},
448-
)
449-
return floatof(eltype(vi, spl))
450-
end
451-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
452-
return Array{get_matching_type(spl, vi, T), N}
453-
end
454-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T
455-
return Array{get_matching_type(spl, vi, T)}
456-
end
457427
function get_matching_type(
458428
spl::Sampler{<:Union{PG, SMC}},
459429
vi,

test/compiler.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,24 @@ end
267267
end
268268
model = testmodel(rand(10))
269269
@test all(z -> isapprox(z, 0; atol = 0.2), mean(model() for _ in 1:1000))
270+
271+
# test Turing#1464
272+
@model function gdemo(x)
273+
s ~ InverseGamma(2, 3)
274+
m ~ Normal(0, sqrt(s))
275+
for i in eachindex(x)
276+
x[i] ~ Normal(m, sqrt(s))
277+
end
278+
end
279+
x = [1.0, missing]
280+
VarInfo(gdemo(x))
281+
@test ismissing(x[2])
282+
283+
# https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615
284+
vi = VarInfo(gdemo(x))
285+
@test haskey(vi.metadata, :x)
286+
vi = VarInfo(gdemo(x))
287+
@test haskey(vi.metadata, :x)
270288
end
271289
@testset "nested model" begin
272290
function makemodel(p)

0 commit comments

Comments
 (0)