diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 32d5e024..41aca6ef 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,6 +16,7 @@ jobs: matrix: version: - '1.11' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'. + - '1' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 317e662e..50f1fa3f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,10 +8,7 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FunctionZeros = "b21f74c0-b399-568f-9643-d20f4fa2c814" HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -20,14 +17,18 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [extensions] -PlotsExt = "Plots" +PlotsExt = ["Plots", "Interpolations", "KernelDensity"] [compat] ArgCheck = "2.5.0" Distributions = "v0.24.6, 0.25" +DynamicPPL = "0.25 - 0.39" FunctionZeros = "0.2.0,0.3.0, 1" HCubature = "1" Interpolations = "0.14.0,0.15.0,0.16.0" diff --git a/src/SequentialSamplingModels.jl b/src/SequentialSamplingModels.jl index a723ed6c..1a3d19ce 100644 --- a/src/SequentialSamplingModels.jl +++ b/src/SequentialSamplingModels.jl @@ -64,6 +64,7 @@ export PoissonRace export ShiftedLogNormal export SSM1D export SSM2D +export SSMProductDistribution export stDDM export ContinuousMultivariateSSM export Wald @@ -85,7 +86,7 @@ export plot_model export plot_model! export plot_quantiles export plot_quantiles! -export predict_distribution +export product_distribution export rand export simulate export std diff --git a/src/multi_choice_models/MDFT.jl b/src/multi_choice_models/MDFT.jl index 9e22af95..f346ab7a 100644 --- a/src/multi_choice_models/MDFT.jl +++ b/src/multi_choice_models/MDFT.jl @@ -272,7 +272,7 @@ make_default_contrast(3) -0.5 -0.5 1.0 ``` """ -function make_default_contrast(n) +function make_default_contrast(n::Integer) C = fill(0.0, n, n) C .= -1 / (n - 1) for r ∈ 1:n diff --git a/src/product_distribution.jl b/src/product_distribution.jl index d19cbd02..73761f63 100644 --- a/src/product_distribution.jl +++ b/src/product_distribution.jl @@ -1,7 +1,38 @@ +""" + SSMProductDistribution + +Wrapper around `ProductDistribution` for sequential sampling models. +This type allows us to define `logpdf` methods for `NamedTuple` data +without type piracy. +""" +struct SSMProductDistribution{D <: ProductDistribution} + dist::D +end + +""" + product_distribution(dists) + +Create a product distribution from a vector of distributions. +Returns an `SSMProductDistribution` for SSM types, or a standard +`ProductDistribution` for other types. +""" +function product_distribution(dists::AbstractVector) + pd = ProductDistribution(dists) + # Check if this is an SSM that produces NamedTuple data + if eltype(dists) <: SSM2D + return SSMProductDistribution(pd) + else + return pd + end +end + +Base.size(s::SSMProductDistribution, dims...) = size(s.dist, dims...) +Base.length(s::SSMProductDistribution) = length(s.dist) + function rand( rng::AbstractRNG, - s::Sampleable{T, R} -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} + s::SSMProductDistribution +) n = size(s, 2) data = (; choice = fill(0, n), rt = fill(0.0, n)) return rand!(rng, s, data) @@ -9,9 +40,9 @@ end function rand( rng::AbstractRNG, - s::Sampleable{T, R}, + s::SSMProductDistribution, dims::Dims -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} +) n = size(s, 2) ax = map(Base.OneTo, dims) data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)] @@ -20,23 +51,23 @@ end function rand!( rng::AbstractRNG, - s::Sampleable{T, R}, + s::SSMProductDistribution, data::NamedTuple -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} +) for i ∈ 1:size(s, 2) - data.choice[i], data.rt[i] = rand(rng, s.dists[i]) + data.choice[i], data.rt[i] = rand(rng, s.dist.dists[i]) end return data end -function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} +function logpdf(d::SSMProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} return [logpdf(d, data) for data ∈ data_array] end -function logpdf(d::ProductDistribution, data::NamedTuple) +function logpdf(d::SSMProductDistribution, data::NamedTuple) LL = 0.0 - for i ∈ 1:length(d.dists) - LL += logpdf(d.dists[i], data.choice[i], data.rt[i]) + for i ∈ 1:length(d.dist.dists) + LL += logpdf(d.dist.dists[i], data.choice[i], data.rt[i]) end return LL end diff --git a/test/Project.toml b/test/Project.toml index 85d21439..efaa9327 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" @@ -14,5 +16,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc" -[sources.TuringUtilities] -url = "https://github.com/itsdfish/TuringUtilities.jl" \ No newline at end of file +[sources] +TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"} \ No newline at end of file diff --git a/test/codequality.jl b/test/codequality.jl new file mode 100644 index 00000000..50c37775 --- /dev/null +++ b/test/codequality.jl @@ -0,0 +1,21 @@ +@safetestset "Code Quality" begin + + # check code quality via Aqua + @safetestset "Aqua" begin + using Aqua + using SequentialSamplingModels + Aqua.test_all( + SequentialSamplingModels; + ambiguities = false, + deps_compat = (check_extras = false,), + project_extras = false + ) + end + + # test JET + @safetestset "JET" begin + using JET + using SequentialSamplingModels + JET.test_package(SequentialSamplingModels; target_modules = (SequentialSamplingModels,)) + end +end diff --git a/test/product_distribution_tests.jl b/test/product_distribution_tests.jl index 876dd8e6..72dacf9b 100644 --- a/test/product_distribution_tests.jl +++ b/test/product_distribution_tests.jl @@ -2,6 +2,7 @@ @safetestset "rand SSM1D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -15,6 +16,7 @@ @safetestset "rand SSM1D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -28,6 +30,7 @@ @safetestset "rand logpdf 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -41,6 +44,7 @@ @safetestset "logpdf SSM1D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)] @@ -54,6 +58,7 @@ @safetestset "rand SSM2D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -70,6 +75,7 @@ @safetestset "rand SSM2D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -86,6 +92,7 @@ @safetestset "logpdf SSM2D 1" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [ @@ -103,6 +110,7 @@ @safetestset "logpdf SSM2D 2" begin using Distributions using SequentialSamplingModels + using SequentialSamplingModels: product_distribution using Test lbas = [