From d5ac784d21f79e40d3b202edeab3ba3b5972c43f Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Thu, 4 Aug 2022 15:04:01 +0200
Subject: [PATCH 1/6] Add and test setting of current time string

---
 src/dataset.jl     |  2 +-
 src/utils.jl       |  3 +++
 test/Project.toml  |  1 +
 test/helpers.jl    |  3 ++-
 test/test_utils.jl | 10 +++++++++-
 5 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/dataset.jl b/src/dataset.jl
index 6e6d0b51d..1739844cb 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -138,7 +138,7 @@ Generate default attributes metadata for a dataset generated by inference librar
 """
 function default_attributes(library=nothing)
     return (
-        created_at=Dates.format(Dates.now(), Dates.ISODateTimeFormat),
+        created_at=current_time_iso(),
         arviz_version=string(package_version(ArviZ)),
         arviz_language="julia",
         library_attributes(library)...,
diff --git a/src/utils.jl b/src/utils.jl
index 92ccb256e..721a55734 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -310,3 +310,6 @@ function Dimensions.selectindices(l::Dimensions.LookupArray, sel::AsSlice; kw...
     inds = i isa AbstractVector ? i : [i]
     return inds
 end
+
+# Current system time (`Dates.now()`) formatted with `Dates.ISODateTimeFormat`.
+current_time_iso() = Dates.format(Dates.now(), Dates.ISODateTimeFormat)
diff --git a/test/Project.toml b/test/Project.toml
index e9bc6620a..3e73bc663 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
diff --git a/test/helpers.jl b/test/helpers.jl
index c74ced771..dbe7bccf9 100644
--- a/test/helpers.jl
+++ b/test/helpers.jl
@@ -1,6 +1,7 @@
 using Random
 using PyCall
 using ArviZ: attributes
+using Dates
 
 try
     ArviZ.initialize_bokeh()
@@ -38,7 +39,7 @@ function random_data()
         chain=1:4, draw=1:100, shared=["s1", "s2", "s3"], dima=1:4, dimb=2:6, dimy=1:5
     )
     dims = (a=(:shared, :dima), b=(:shared, :dimb), y=(:shared, :dimy))
-    metadata = (inference_library="PPL",)
+    metadata = (created_at=ArviZ.current_time_iso(), inference_library="PPL")
     posterior = random_dataset(var_names, dims, coords, metadata)
     posterior_predictive = random_dataset(data_names, dims, coords, metadata)
     prior = random_dataset(var_names, dims, coords, metadata)
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 41a01edd3..52ad5a19b 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,5 +1,5 @@
 using DataFrames: DataFrames
-using PyCall, PyPlot
+using Dates, PyCall, PyPlot
 
 pandas = ArviZ.pandas
 
@@ -148,4 +148,12 @@ pandas = ArviZ.pandas
         @test ArviZ.package_version(ArviZ) isa VersionNumber
         @test ArviZ.package_version(PyCall) isa VersionNumber
     end
+
+    @testset "current_time_iso" begin
+        iso = ArviZ.current_time_iso()
+        @test iso isa String
+        @test startswith(iso, Dates.format(Dates.now(), Dates.ISODateFormat))
+        iso2 = ArviZ.current_time_iso()
+        @test Dates.DateTime(iso2) >= Dates.DateTime(iso)  # equal within the same millisecond
+    end
 end
From 4e792c539676627468828d8a97a36508fd1d464f Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Thu, 4 Aug 2022 21:45:34 +0200
Subject: [PATCH 2/6] Allow other groups

---
 src/inference_data.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/inference_data.jl b/src/inference_data.jl
index 7d9c4f018..73e4fd7bb 100644
--- a/src/inference_data.jl
+++ b/src/inference_data.jl
@@ -225,7 +225,9 @@ _index_to_indices(i::Int) = [i]
 _index_to_indices(sel::Dimensions.Selector) = AsSlice(sel)
 
 @generated function _reorder_group_names(::Val{names}) where {names}
-    return Tuple(sort(collect(names); by=k -> SUPPORTED_GROUPS_DICT[k]))
+    # groups in `SUPPORTED_GROUPS_DICT` sort by their value there; all others follow by name
+    by = k -> (get(SUPPORTED_GROUPS_DICT, k, typemax(Int)), string(k))
+    return Tuple(sort(collect(names); by))
 end
 
 @generated _keys_and_types(::NamedTuple{keys,types}) where {keys,types} = (keys, types)
From e5978c70bbf52eedba5039b743fb5dff84642695 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Thu, 4 Aug 2022 21:45:42 +0200
Subject: [PATCH 3/6] Fix typo

---
 src/inference_data.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/inference_data.jl b/src/inference_data.jl
index 73e4fd7bb..2e3064f0b 100644
--- a/src/inference_data.jl
+++ b/src/inference_data.jl
@@ -13,7 +13,7 @@ Internally, groups are stored in a `NamedTuple`, which can be accessed using
     InferenceData(groups::NamedTuple)
     InferenceData(; groups...)
 
-Construct an inference data from either a `NamedTuple` or keyword arguments of groups.
+Construct inference data from either a `NamedTuple` or keyword arguments of groups.
 
 Groups must be [`Dataset`](@ref) objects.
 
From a336a0bd68930ac8fa46a5b1e6f3dd6bf2a8a45a Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Thu, 4 Aug 2022 21:46:03 +0200
Subject: [PATCH 4/6] Throw more useful error

---
 src/dataset.jl | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/dataset.jl b/src/dataset.jl
index 1739844cb..dac7febcf 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -181,7 +181,14 @@ Generate `DimensionsionalData.Dimension` objects for each dimension of `array`.
 function generate_dims end
 function generate_dims(array, name; dims=(), coords=(;), default_dims=())
     num_default_dims = length(default_dims)
-    length(dims) + num_default_dims > ndims(array) && @error "blah"
+    if length(dims) + num_default_dims > ndims(array)
+        dim_names = Dimensions.name(Dimensions.basedims((dims..., default_dims...)))
+        throw(
+            DimensionMismatch(
+                "Provided dimensions $dim_names more than dimensions of array: $(ndims(array))",
+            ),
+        )
+    end
     dims_named = ntuple(ndims(array) - length(default_dims)) do i
         dim = get(dims, i, nothing)
         dim === nothing && return Symbol("$(name)_dim_$(i)")
From 9ba019cc5416fb6a7107874df398436066af8530 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Thu, 4 Aug 2022 21:46:11 +0200
Subject: [PATCH 5/6] Document default_dims

---
 src/dataset.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/dataset.jl b/src/dataset.jl
index dac7febcf..3685dac91 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -90,6 +90,8 @@ Convert `NamedTuple` mapping variable names to arrays to a [`Dataset`](@ref).
     dimension. If indices for a dimension in `dims` are provided, they are used even if
     the dimension contains its own indices. If a dimension is missing, its indices are
     automatically generated.
+  - `default_dims`: a set of dimension names assumed to apply to the first dimensions of all
+    variables, used to assign sample dimensions.
 """
 function namedtuple_to_dataset end
 function namedtuple_to_dataset(
From 3828c3166f7984e97e02b8596da3dac25bc37c38 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Mon, 8 Aug 2022 10:04:27 +0200
Subject: [PATCH 6/6] Add functions for checking that the schema is followed

---
 src/ArviZ.jl  |   5 ++
 src/schema.jl | 187 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 192 insertions(+)
 create mode 100644 src/schema.jl

diff --git a/src/ArviZ.jl b/src/ArviZ.jl
index 63796796e..637ad454c 100644
--- a/src/ArviZ.jl
+++ b/src/ArviZ.jl
@@ -71,12 +71,16 @@ export bfmi, ess, rhat, mcse
 ## Stats utils
 export autocov, autocorr, make_ufunc, wrap_xarray_ufunc
 
+
 ## Dataset
 export Dataset, convert_to_dataset, namedtuple_to_dataset
 
 ## InferenceData
 export InferenceData
 
+## Schema
+export check_follows_schema, follows_schema
+
 ## Data
 export convert_to_inference_data,
     extract_dataset,
@@ -131,6 +135,7 @@ include("rcparams.jl")
 include("xarray.jl")
 include("dataset.jl")
 include("inference_data.jl")
+include("schema.jl")
 include("data.jl")
 include("diagnostics.jl")
 include("plots.jl")
diff --git a/src/schema.jl b/src/schema.jl
new file mode 100644
index 000000000..26a394e4b
--- /dev/null
+++ b/src/schema.jl
@@ -0,0 +1,187 @@
+"""
+    InferenceDataSchemaError(msg::String)
+
+Exception thrown by [`check_follows_schema`](@ref) when data does not follow the
+InferenceData schema. `msg` describes the violated requirement.
+"""
+struct InferenceDataSchemaError <: Exception
+    msg::String
+end
+
+"""
+    check_follows_schema(data::InferenceData; indices=false) -> Nothing
+    check_follows_schema(data::Dataset; name=:dataset, required_dims=()) -> Nothing
+
+Raise an [`InferenceDataSchemaError`](@ref) if `data` does not follow the
+[InferenceData schema](https://python.arviz.org/en/v$(arviz_version())/schema/schema.html).
+
+When applicable, if `indices=true`, values of indices are also checked for consistency
+across groups.
+
+Use [`follows_schema`](@ref) to return whether `data` follows the schema without raising an
+error.
+"""
+check_follows_schema
+
+function check_follows_schema(data::Dataset; name=:dataset, required_dims=())
+    dims = Dimensions.dims(data)
+    dim_names = Dimensions.name(dims)
+    # - dimension names must be unique
+    length(Dimensions.commondims(dims, dim_names)) == length(dims) ||
+        throw(InferenceDataSchemaError("$name has non-unique dimension names: $dim_names"))
+    # - groups that contain samples must contain the sample dimensions
+    all(Dimensions.hasdim(data, required_dims)) || throw(
+        InferenceDataSchemaError(
+            "$name does not have the required dimensions: $required_dims"
+        ),
+    )
+    # - variables must not share names with dimensions
+    shared_var_dims = Dimensions.commondims(DimensionalData.keys(data), dims)
+    if !isempty(shared_var_dims)
+        throw(
+            InferenceDataSchemaError(
+                "$name has variables and dimensions with the same name: $(Dimensions.name(shared_var_dims))",
+            ),
+        )
+    end
+    metadata = DimensionalData.metadata(data)
+    # - metadata keys must be `Symbol`s
+    eltype(keys(metadata)) <: Symbol ||
+        throw(InferenceDataSchemaError("$name has metadata with non-Symbol keys."))
+    # - each group contains the attribute `:created_at`.
+    haskey(metadata, :created_at) ||
+        throw(InferenceDataSchemaError("$name has no `:created_at` entry in its metadata."))
+    return nothing
+end
+
+function check_follows_schema(data::InferenceData; indices=false)
+    # each group separately should follow the schema; `_maybe_sample_dims` supplies the
+    # dimensions each group is required to have
+    foreach(
+        zip(groupnames(data), groups(data), _maybe_sample_dims(data))
+    ) do (name, group, sample_dims)
+        check_follows_schema(group; name, required_dims=sample_dims)
+    end
+    sample_dims = DEFAULT_SAMPLE_DIMS
+    # posterior-derived groups should share same sample dims/indices
+    check_dims_match(_posterior_related_groups(data), sample_dims; indices)
+    # prior-derived groups should share same sample dims/indices
+    check_dims_match(_prior_related_groups(data), sample_dims; indices)
+    # any parameters shared by prior and posterior should have the same non-sample dims/indices
+    if hasgroup(data, :posterior) && hasgroup(data, :prior)
+        var_dims = Dimensions.otherdims(
+            Dimensions.commondims(data.posterior, data.prior), sample_dims
+        )
+        check_dims_match(NamedTuple{(:posterior, :prior)}(data), var_dims; indices)
+    end
+    # any dim names shared by log_likelihood, prior_predictive, posterior_predictive and
+    # observed_data must share the same indices
+    data_groups = _data_related_groups(data)
+    if length(data_groups) > 1
+        data_dims = Dimensions.otherdims(Dimensions.commondims(data_groups...), sample_dims)
+        check_dims_match(data_groups, data_dims; indices)
+    end
+    return nothing
+end
+
+"""
+    follows_schema(data; kwargs...) -> Bool
+
+Return whether `data` follows the [InferenceData schema](https://python.arviz.org/en/v$(arviz_version())/schema/schema.html).
+
+`kwargs` are passed to [`check_follows_schema`](@ref).
+
+See [`check_follows_schema`](@ref) for informative error messages.
+"""
+function follows_schema(data; kwargs...)
+    try
+        check_follows_schema(data; kwargs...)
+        return true
+    catch e
+        e isa InferenceDataSchemaError && return false
+        rethrow()  # any other error propagates with its original backtrace
+    end
+end
+
+# Required sample dimensions for each group, in the order of `group_names`:
+# `DEFAULT_SAMPLE_DIMS` for the groups listed below and `()` for all other groups.
+# NOTE(review): the schema calls its out-of-sample group `predictions`; confirm that
+# `:predictive` is the intended name here and in `_posterior_related_groups`.
+@generated function _maybe_sample_dims(::InferenceData{group_names}) where {group_names}
+    return map(group_names) do name
+        if name ∈ (
+            :posterior,
+            :posterior_predictive,
+            :predictive,
+            :sample_stats,
+            :log_likelihood,
+            :prior,
+            :prior_predictive,
+            :sample_stats_prior,
+        )
+            return DEFAULT_SAMPLE_DIMS
+        else
+            return ()
+        end
+    end
+end
+
+# `NamedTuple` type whose field names are the elements of `groups` also in `other`.
+@generated function _filter_groups_type(
+    ::InferenceData{groups}, ::Val{other}
+) where {groups,other}
+    shared = Tuple(intersect(groups, other))
+    return NamedTuple{shared}
+end
+
+function _posterior_related_groups(idata::InferenceData)
+    posterior_groups = (
+        :posterior, :posterior_predictive, :sample_stats, :log_likelihood, :predictive
+    )
+    return _filter_groups_type(idata, Val(posterior_groups))(idata)
+end
+
+function _prior_related_groups(idata::InferenceData)
+    prior_groups = (:prior, :prior_predictive, :sample_stats_prior)
+    return _filter_groups_type(idata, Val(prior_groups))(idata)
+end
+
+function _data_related_groups(idata::InferenceData)
+    data_groups = (
+        :observed_data, :log_likelihood, :prior_predictive, :posterior_predictive
+    )
+    return _filter_groups_type(idata, Val(data_groups))(idata)
+end
+
+# Check that `dims` match across `groups` via `comparesomedims` (with `val=indices`).
+# A `DimensionMismatch` is rethrown as an `InferenceDataSchemaError` naming the groups;
+# any other error propagates unchanged.
+function check_dims_match(groups, dims; indices::Bool=true)
+    isempty(groups) && return nothing
+    try
+        comparesomedims(groups...; dims, val=indices)
+    catch e
+        if e isa DimensionMismatch
+            throw(
+                InferenceDataSchemaError(
+                    "Dimension mismatch in groups $(keys(groups)): $(e.msg)"
+                ),
+            )
+        else
+            rethrow()
+        end
+    end
+    return nothing
+end
+
+# Compare each dataset's dimensions in `dims` against those of the first dataset.
+function comparesomedims(datasets::Dataset...; dims, val=false)
+    sub_dims = map(datasets) do ds
+        Dimensions.sortdims(Dimensions.commondims(ds, dims), dims)
+    end
+    sub_dims_ref = first(sub_dims)
+    for _dims in Iterators.drop(sub_dims, 1)
+        Dimensions.comparedims(sub_dims_ref, _dims; val)
+    end
+    return nothing
+end