Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ArviZ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@ export bfmi, ess, rhat, mcse

## Stats utils
export autocov, autocorr, make_ufunc, wrap_xarray_ufunc

## Dataset
export Dataset, convert_to_dataset, namedtuple_to_dataset

## InferenceData
export InferenceData

## Schema
export check_follows_schema, follows_schema

## Data
export convert_to_inference_data,
extract_dataset,
Expand Down Expand Up @@ -131,6 +135,7 @@ include("rcparams.jl")
include("xarray.jl")
include("dataset.jl")
include("inference_data.jl")
include("schema.jl")
include("data.jl")
include("diagnostics.jl")
include("plots.jl")
Expand Down
13 changes: 11 additions & 2 deletions src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ Convert `NamedTuple` mapping variable names to arrays to a [`Dataset`](@ref).
dimension. If indices for a dimension in `dims` are provided, they are used even if
the dimension contains its own indices. If a dimension is missing, its indices are
automatically generated.
- `default_dims`: a set of dimension names assumed to apply to the first dimensions of all
variables, used to assign sample dimensions.
"""
function namedtuple_to_dataset end
function namedtuple_to_dataset(
Expand Down Expand Up @@ -138,7 +140,7 @@ Generate default attributes metadata for a dataset generated by inference librar
"""
function default_attributes(library=nothing)
return (
created_at=Dates.format(Dates.now(), Dates.ISODateTimeFormat),
created_at=current_time_iso(),
arviz_version=string(package_version(ArviZ)),
arviz_language="julia",
library_attributes(library)...,
Expand Down Expand Up @@ -181,7 +183,14 @@ Generate `DimensionalData.Dimension` objects for each dimension of `array`.
function generate_dims end
function generate_dims(array, name; dims=(), coords=(;), default_dims=())
num_default_dims = length(default_dims)
length(dims) + num_default_dims > ndims(array) && @error "blah"
if length(dims) + num_default_dims > ndims(array)
dim_names = Dimensions.name(Dimensions.basedims((dims..., default_dims...)))
throw(
DimensionMismatch(
"Provided dimensions $dim_names more than dimensions of array: $(ndims(array))",
),
)
end
dims_named = ntuple(ndims(array) - length(default_dims)) do i
dim = get(dims, i, nothing)
dim === nothing && return Symbol("$(name)_dim_$(i)")
Expand Down
5 changes: 3 additions & 2 deletions src/inference_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Internally, groups are stored in a `NamedTuple`, which can be accessed using
InferenceData(groups::NamedTuple)
InferenceData(; groups...)

Construct an inference data from either a `NamedTuple` or keyword arguments of groups.
Construct inference data from either a `NamedTuple` or keyword arguments of groups.

Groups must be [`Dataset`](@ref) objects.

Expand Down Expand Up @@ -225,7 +225,8 @@ _index_to_indices(i::Int) = [i]
_index_to_indices(sel::Dimensions.Selector) = AsSlice(sel)

# Return `names` as a tuple in canonical group order. Names present in
# `SUPPORTED_GROUPS_DICT` are ranked by their entry there; any other name is ranked by its
# string form instead of raising a `KeyError`.
@generated function _reorder_group_names(::Val{names}) where {names}
    # Sort keys can be heterogeneous (dictionary values for known names, strings for
    # unknown ones): compare numerically when both are integers, otherwise as strings.
    lt = (a, b) -> (a isa Integer && b isa Integer) ? a < b : string(a) < string(b)
    by = k -> get(SUPPORTED_GROUPS_DICT, k, string(k))
    return Tuple(sort(collect(names); lt, by))
end

# Return the field names and the field-type tuple of a `NamedTuple`, read directly from
# its type parameters.
@generated function _keys_and_types(::NamedTuple{names,T}) where {names,T}
    return (names, T)
end
174 changes: 174 additions & 0 deletions src/schema.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@

"""
    InferenceDataSchemaError(msg::String)

Exception thrown when data does not follow the InferenceData schema.

`msg` describes which rule of the schema was violated.
"""
struct InferenceDataSchemaError <: Exception
    msg::String
end

# Print the error as `InferenceDataSchemaError: <msg>` rather than the default
# constructor-style representation.
function Base.showerror(io::IO, e::InferenceDataSchemaError)
    return print(io, "InferenceDataSchemaError: ", e.msg)
end

"""
check_follows_schema(data::InferenceData; indices=false) -> Nothing
check_follows_schema(data::Dataset; name=:dataset, required_dims=()) -> Nothing

Raise an [`InferenceDataSchemaError`](@ref) if `data` does not follow the
[InferenceData schema](https://python.arviz.org/en/v$(arviz_version())/schema/schema.html).

When applicable, if `indices=true`, values of indices are also checked for consistency
across groups.

Use [`follows_schema`](@ref) to return whether `data` follows the schema without raising an
error.
"""
function check_follows_schema end

# Check a single `Dataset` against the schema rules listed below. Throws an
# `InferenceDataSchemaError` mentioning `name` at the first rule that fails and returns
# `nothing` when every rule holds.
function check_follows_schema(data::Dataset; name=:dataset, required_dims=())
    ds_dims = Dimensions.dims(data)
    dim_names = Dimensions.name(ds_dims)
    # - dimension names must be unique
    if length(Dimensions.commondims(ds_dims, dim_names)) != length(ds_dims)
        throw(InferenceDataSchemaError("$name has non-unique dimension names: $dim_names"))
    end
    # - groups that contain samples must contain the sample dimensions
    if !all(Dimensions.hasdim(data, required_dims))
        throw(
            InferenceDataSchemaError(
                "$name does not have the required dimensions: $required_dims"
            ),
        )
    end
    # - variables must not share names with dimensions
    shared_var_dims = Dimensions.commondims(DimensionalData.keys(data), ds_dims)
    isempty(shared_var_dims) || throw(
        InferenceDataSchemaError(
            "$name has variables and dimensions with the same name: $(Dimensions.name(shared_var_dims))",
        ),
    )
    # - metadata keys must be `Symbol`s
    meta = DimensionalData.metadata(data)
    if !(eltype(keys(meta)) <: Symbol)
        throw(InferenceDataSchemaError("$name has metadata with non-Symbol keys."))
    end
    # - each group contains the attribute `:created_at`.
    if !haskey(meta, :created_at)
        throw(InferenceDataSchemaError("$name has no `:created_at` entry in its metadata."))
    end
    return nothing
end
# Validate every group of an `InferenceData` individually, then check that related groups
# agree on their shared dimensions (and on index values too when `indices=true`).
# Throws `InferenceDataSchemaError` on the first violation; returns `nothing` otherwise.
function check_follows_schema(data::InferenceData; indices=false)
    # each group separately should follow the schema
    foreach(
        zip(groupnames(data), groups(data), _maybe_sample_dims(data))
    ) do (name, group, sample_dims)
        # the `Dataset` method takes the expected dimensions as `required_dims`
        check_follows_schema(group; name, required_dims=sample_dims)
    end
    sample_dims = DEFAULT_SAMPLE_DIMS
    # posterior-derived groups should share same sample dims/indices
    check_dims_match(_posterior_related_groups(data), sample_dims; indices)
    # prior-derived groups should share same sample dims/indices
    check_dims_match(_prior_related_groups(data), sample_dims; indices)
    # any parameters shared by prior and posterior should have the same non-sample dims/indices
    if hasgroup(data, :posterior) && hasgroup(data, :prior)
        var_dims = Dimensions.otherdims(
            Dimensions.commondims(data.posterior, data.prior), sample_dims
        )
        check_dims_match(NamedTuple{(:posterior, :prior)}(data), var_dims; indices)
    end
    # any dim names shared by log_likelihood, prior_predictive, posterior_predictive and
    # observed_data must share the same indices
    data_groups = _data_related_groups(data)
    if length(data_groups) > 1
        data_dims = Dimensions.otherdims(Dimensions.commondims(data_groups...), sample_dims)
        check_dims_match(data_groups, data_dims; indices)
    end
    return nothing
end

"""
follows_schema(data; kwargs...) -> Bool

Return whether `data` follows the [InferenceData schema](https://python.arviz.org/en/v$(arviz_version())/schema/schema.html).

`kwargs` are passed to [`check_follows_schema`](@ref).

See [`check_follows_schema`](@ref) for informative error messages.
"""
function follows_schema(data; kwargs...)
    try
        check_follows_schema(data; kwargs...)
    catch e
        # Only a schema violation means "does not follow"; anything else is a genuine
        # failure and is propagated with its original backtrace.
        e isa InferenceDataSchemaError || rethrow()
        return false
    end
    return true
end

# For each group name of the `InferenceData` type (in order), give the dimensions that
# group is required to have: `DEFAULT_SAMPLE_DIMS` for the groups listed below, `()` for
# all others.
@generated function _maybe_sample_dims(::InferenceData{group_names}) where {group_names}
    groups_with_samples = (
        :posterior,
        :posterior_predictive,
        :predictive,
        :sample_stats,
        :log_likelihood,
        :prior,
        :prior_predictive,
        :sample_stats_prior,
    )
    return map(group_names) do name
        return name in groups_with_samples ? DEFAULT_SAMPLE_DIMS : ()
    end
end

# Build the `NamedTuple` type whose names are the group names of the `InferenceData`
# type that also appear in the `Val`-wrapped tuple.
@generated function _filter_groups_type(
    ::InferenceData{present}, ::Val{wanted}
) where {present,wanted}
    return NamedTuple{Tuple(intersect(present, wanted))}
end

# Collect the groups of `idata` whose names are in the posterior-related list below.
function _posterior_related_groups(idata::InferenceData)
    selector = Val((
        :posterior, :posterior_predictive, :sample_stats, :log_likelihood, :predictive
    ))
    group_type = _filter_groups_type(idata, selector)
    return group_type(idata)
end

# Collect the groups of `idata` whose names are in the prior-related list below.
function _prior_related_groups(idata::InferenceData)
    selector = Val((:prior, :prior_predictive, :sample_stats_prior))
    group_type = _filter_groups_type(idata, selector)
    return group_type(idata)
end

# Collect the groups of `idata` whose names are in the data-related list below.
function _data_related_groups(idata::InferenceData)
    selector = Val((
        :observed_data, :log_likelihood, :prior_predictive, :posterior_predictive
    ))
    group_type = _filter_groups_type(idata, selector)
    return group_type(idata)
end

# Check that all `groups` agree on the dimensions `dims` (and on their index values when
# `indices=true`) by delegating to `comparesomedims`. A `DimensionMismatch` is converted
# into an `InferenceDataSchemaError` naming the groups; any other exception propagates
# unchanged. Returns `nothing`, immediately so when `groups` is empty.
function check_dims_match(groups, dims; indices::Bool=true)
    isempty(groups) && return nothing
    try
        comparesomedims(groups...; dims, val=indices)
    catch e
        # keep the original backtrace for anything that is not a dimension mismatch
        e isa DimensionMismatch || rethrow()
        throw(
            InferenceDataSchemaError(
                "Dimension mismatch in groups $(keys(groups)): $(e.msg)"
            ),
        )
    end
    return nothing
end

# Compare the subset `dims` of dimensions across `datasets`: each dataset's matching
# dimensions are put in the order of `dims`, and every dataset after the first is
# compared against the first with `Dimensions.comparedims`. `val` is forwarded to it.
function comparesomedims(datasets::Dataset...; dims, val=false)
    selected = map(datasets) do ds
        shared = Dimensions.commondims(ds, dims)
        return Dimensions.sortdims(shared, dims)
    end
    reference = first(selected)
    foreach(Iterators.drop(selected, 1)) do other
        Dimensions.comparedims(reference, other; val)
    end
    return nothing
end
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,5 @@ function Dimensions.selectindices(l::Dimensions.LookupArray, sel::AsSlice; kw...
inds = i isa AbstractVector ? i : [i]
return inds
end

# Return the time reported by `Dates.now()` as a string formatted with
# `Dates.ISODateTimeFormat`.
function current_time_iso()
    timestamp = Dates.now()
    return Dates.format(timestamp, Dates.ISODateTimeFormat)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down
3 changes: 2 additions & 1 deletion test/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Random
using PyCall
using ArviZ: attributes
using Dates

try
ArviZ.initialize_bokeh()
Expand Down Expand Up @@ -38,7 +39,7 @@ function random_data()
chain=1:4, draw=1:100, shared=["s1", "s2", "s3"], dima=1:4, dimb=2:6, dimy=1:5
)
dims = (a=(:shared, :dima), b=(:shared, :dimb), y=(:shared, :dimy))
metadata = (inference_library="PPL",)
metadata = (created_at=ArviZ.current_time_iso(), inference_library="PPL")
posterior = random_dataset(var_names, dims, coords, metadata)
posterior_predictive = random_dataset(data_names, dims, coords, metadata)
prior = random_dataset(var_names, dims, coords, metadata)
Expand Down
10 changes: 9 additions & 1 deletion test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DataFrames: DataFrames
using PyCall, PyPlot
using Dates, PyCall, PyPlot

pandas = ArviZ.pandas

Expand Down Expand Up @@ -148,4 +148,12 @@ pandas = ArviZ.pandas
@test ArviZ.package_version(ArviZ) isa VersionNumber
@test ArviZ.package_version(PyCall) isa VersionNumber
end

@testset "current_time_iso" begin
    # Bracket the calls with reference timestamps instead of comparing against a later
    # `Dates.now()`, so the test cannot fail when the date rolls over mid-test.
    before = Dates.now()
    iso = ArviZ.current_time_iso()
    iso2 = ArviZ.current_time_iso()
    after = Dates.now()
    @test iso isa String
    # the string must round-trip through the ISO date-time format
    @test Dates.format(Dates.DateTime(iso), Dates.ISODateTimeFormat) == iso
    # `Dates.now()` has millisecond resolution, so consecutive calls may return the same
    # instant; only non-decreasing order can be required.
    @test before <= Dates.DateTime(iso) <= Dates.DateTime(iso2) <= after
end
end