Skip to content
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
matrix:
version:
- '1.11' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
- '1'
os:
- ubuntu-latest
arch:
Expand Down
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FunctionZeros = "b21f74c0-b399-568f-9643-d20f4fa2c814"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -20,14 +17,18 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[extensions]
PlotsExt = "Plots"
PlotsExt = ["Plots", "Interpolations", "KernelDensity"]

[compat]
ArgCheck = "2.5.0"
Distributions = "v0.24.6, 0.25"
DynamicPPL = "0.25 - 0.39"
FunctionZeros = "0.2.0,0.3.0, 1"
HCubature = "1"
Interpolations = "0.14.0,0.15.0,0.16.0"
Expand Down
3 changes: 2 additions & 1 deletion src/SequentialSamplingModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export PoissonRace
export ShiftedLogNormal
export SSM1D
export SSM2D
export SSMProductDistribution
export stDDM
export ContinuousMultivariateSSM
export Wald
Expand All @@ -85,7 +86,7 @@ export plot_model
export plot_model!
export plot_quantiles
export plot_quantiles!
export predict_distribution
export product_distribution
export rand
export simulate
export std
Expand Down
2 changes: 1 addition & 1 deletion src/multi_choice_models/MDFT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ make_default_contrast(3)
-0.5 -0.5 1.0
```
"""
function make_default_contrast(n)
function make_default_contrast(n::Integer)
C = fill(0.0, n, n)
C .= -1 / (n - 1)
for r ∈ 1:n
Expand Down
53 changes: 42 additions & 11 deletions src/product_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,48 @@
"""
SSMProductDistribution

Wrapper around `ProductDistribution` for sequential sampling models.
This type allows us to define `logpdf` methods for `NamedTuple` data
without type piracy.
"""
struct SSMProductDistribution{D <: ProductDistribution}
dist::D
end

"""
product_distribution(dists)

Create a product distribution from a vector of distributions.
Returns an `SSMProductDistribution` for SSM types, or a standard
`ProductDistribution` for other types.
"""
function product_distribution(dists::AbstractVector)
pd = ProductDistribution(dists)
# Check if this is an SSM that produces NamedTuple data
if eltype(dists) <: SSM2D
return SSMProductDistribution(pd)
else
return pd
end
end

Base.size(s::SSMProductDistribution, dims...) = size(s.dist, dims...)
Base.length(s::SSMProductDistribution) = length(s.dist)

function rand(
rng::AbstractRNG,
s::Sampleable{T, R}
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
s::SSMProductDistribution
)
n = size(s, 2)
data = (; choice = fill(0, n), rt = fill(0.0, n))
return rand!(rng, s, data)
end

function rand(
rng::AbstractRNG,
s::Sampleable{T, R},
s::SSMProductDistribution,
dims::Dims
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
)
n = size(s, 2)
ax = map(Base.OneTo, dims)
data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)]
Expand All @@ -20,23 +51,23 @@ end

function rand!(
rng::AbstractRNG,
s::Sampleable{T, R},
s::SSMProductDistribution,
data::NamedTuple
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
)
for i ∈ 1:size(s, 2)
data.choice[i], data.rt[i] = rand(rng, s.dists[i])
data.choice[i], data.rt[i] = rand(rng, s.dist.dists[i])
end
return data
end

function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N}
function logpdf(d::SSMProductDistribution, data_array::Array{<:NamedTuple, N}) where {N}
return [logpdf(d, data) for data ∈ data_array]
end

function logpdf(d::ProductDistribution, data::NamedTuple)
function logpdf(d::SSMProductDistribution, data::NamedTuple)
LL = 0.0
for i ∈ 1:length(d.dists)
LL += logpdf(d.dists[i], data.choice[i], data.rt[i])
for i ∈ 1:length(d.dist.dists)
LL += logpdf(d.dist.dists[i], data.choice[i], data.rt[i])
end
return LL
end
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -14,5 +16,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc"

[sources.TuringUtilities]
url = "https://github.com/itsdfish/TuringUtilities.jl"
[sources]
TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"}
21 changes: 21 additions & 0 deletions test/codequality.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
@safetestset "Code Quality" begin

# check code quality via Aqua
@safetestset "Aqua" begin
using Aqua
using SequentialSamplingModels
Aqua.test_all(
SequentialSamplingModels;
ambiguities = false,
deps_compat = (check_extras = false,),
project_extras = false
)
end

# test JET
@safetestset "JET" begin
using JET
using SequentialSamplingModels
JET.test_package(SequentialSamplingModels; target_modules = (SequentialSamplingModels,))
end
end
8 changes: 8 additions & 0 deletions test/product_distribution_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
@safetestset "rand SSM1D 1" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
Expand All @@ -15,6 +16,7 @@
@safetestset "rand SSM1D 2" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
Expand All @@ -28,6 +30,7 @@
@safetestset "rand logpdf 1" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
Expand All @@ -41,6 +44,7 @@
@safetestset "logpdf SSM1D 2" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
Expand All @@ -54,6 +58,7 @@
@safetestset "rand SSM2D 1" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

lbas = [
Expand All @@ -70,6 +75,7 @@
@safetestset "rand SSM2D 2" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

lbas = [
Expand All @@ -86,6 +92,7 @@
@safetestset "logpdf SSM2D 1" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

lbas = [
Expand All @@ -103,6 +110,7 @@
@safetestset "logpdf SSM2D 2" begin
using Distributions
using SequentialSamplingModels
using SequentialSamplingModels: product_distribution
using Test

lbas = [
Expand Down
Loading