Skip to content

Commit a0b8999

Browse files
authored
fixed names_values and added tests (#2021)
1 parent 970abac commit a0b8999

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/inference/Inference.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,23 +297,23 @@ function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where nam
297297
return collect(names), values
298298
end
299299

300-
function names_values(extra_data::AbstractVector{<:NamedTuple})
300+
function names_values(xs::AbstractVector{<:NamedTuple})
301301
# Obtain all parameter names.
302-
names_set = Set(Symbol[])
303-
for data in extra_data
304-
for name in names(data)
305-
push!(extra_names_set, name)
302+
names_set = Set{Symbol}()
303+
for x in xs
304+
for k in keys(x)
305+
push!(names_set, k)
306306
end
307307
end
308-
extra_names = collect(extra_names_set)
308+
names_unique = collect(names_set)
309309

310310
# Extract all values as matrix.
311311
values = [
312-
hasfield(data, name) ? missing : getfield(data, name)
313-
for data in extra_data, name in extra_names
312+
haskey(x, name) ? x[name] : missing
313+
for x in xs, name in names_unique
314314
]
315315

316-
return extra_names, values
316+
return names_unique, values
317317
end
318318

319319
getlogevidence(transitions, sampler, state) = missing

test/inference/Inference.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,4 +523,14 @@
523523
vdemo3kw(; T) = vdemo3(T)
524524
sample(vdemo3kw(; T=Vector{Float64}), alg, 250)
525525
end
526+
527+
@testset "names_values" begin
528+
ks, xs = Turing.Inference.names_values([
529+
(a=1,),
530+
(b=2,),
531+
(a=3, b=4)
532+
])
533+
@test all(xs[:, 1] .=== [1, missing, 3])
534+
@test all(xs[:, 2] .=== [missing, 2, 4])
535+
end
526536
end

0 commit comments

Comments
 (0)