From 3a423e609c411a3617db00fd603f7a14dcfae90f Mon Sep 17 00:00:00 2001 From: Carlos Parada <71727937+ParadaCarleton@users.noreply.github.com> Date: Thu, 9 Oct 2025 20:57:57 +0000 Subject: [PATCH 1/3] Correct log-weights implementation --- src/ImportanceSampling.jl | 10 +++++----- test/tests/BasicTests.jl | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index ecf61fb..5cfc8bc 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -202,7 +202,7 @@ end function psis(is_ratios::AbstractVector{<:Real}, args...; kwargs...) new_ratios = copy(is_ratios) - ξ = psis!(new_ratios, kwargs...) + ξ = psis!(new_ratios; kwargs...) return new_ratios, ξ end @@ -253,13 +253,13 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T); is_ratios .= first.(ratio_index) @views tail = is_ratios[tail_start:len] _check_tail(tail) + cutoff = is_ratios[tail_start - 1] if log_weights biggest = maximum(tail) - @. tail = exp(tail - biggest) + @. is_ratios = exp(is_ratios - biggest) + cutoff = exp(cutoff - biggest) end - # Get value just before the tail starts: - cutoff = is_ratios[tail_start - 1] ξ = _psis_smooth_tail!(tail, cutoff, r_eff) # truncate at max of raw weights (1 after scaling) @@ -268,7 +268,7 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T); invpermute!(is_ratios, last.(ratio_index)) if log_weights - @. tail = log(tail + biggest) + @. is_ratios = log(is_ratios) + biggest end return ξ diff --git a/test/tests/BasicTests.jl b/test/tests/BasicTests.jl index 440d775..3be31b0 100644 --- a/test/tests/BasicTests.jl +++ b/test/tests/BasicTests.jl @@ -2,6 +2,7 @@ using AxisKeys using NamedDims using Statistics import RData +using Random @testset "Basic Arrays" begin @@ -102,3 +103,23 @@ import RData @test ParetoSmooth.naive_lpd(log_lik_arr) ≈ jul_loo.estimates(:naive_lpd, :total) @test ParetoSmooth.naive_lpd(log_lik_arr) ≈ r_eff_loo.estimates(:naive_lpd, :total) end + +@testset "Log-weights vs. raw weights" begin + Random.seed!(123) + log_lik = randn(30) + r_eff = 1.01 + + # Test with log_weights = true + log_weights_true = copy(log_lik) + log_weights_true, xi_true = psis!(log_weights_true, r_eff; log_weights=true) + + # Test with log_weights = false + raw_weights_false = exp.(log_lik .- maximum(log_lik)) + xi_false = psis!(raw_weights_false, r_eff; log_weights=false) + + # The shape parameters should be the same + @test xi_true ≈ xi_false atol=1e-9 + + # The smoothed weights should also be equivalent + @test exp.(log_weights_true .- maximum(log_lik)) ≈ raw_weights_false rtol=1e-9 +end From 886a6b08fb711919aac8e97b2d6148f2f7c8be48 Mon Sep 17 00:00:00 2001 From: Carlos Parada <71727937+ParadaCarleton@users.noreply.github.com> Date: Fri, 10 Oct 2025 22:20:56 +0000 Subject: [PATCH 2/3] Fix instructions (there's 50 tests now) --- .github/copilot-instructions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 2aec4a9..9e0ea74 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -60,7 +60,7 @@ ParetoSmooth.jl is a Julia package for efficient approximate leave-one-out cross ### Testing Requirements - Run tests before committing: `julia --project=. -e "using Pkg; Pkg.test()"` - **CRITICAL**: NEVER CANCEL tests even if they appear to hang - they take 2 minutes -- All 48 tests must pass for a valid build +- All tests must pass for a valid build - Tests validate against R reference data in `test/data/` ## Common Tasks From b2886b45f03cc57ffc37922fd1a297825bd32c1e Mon Sep 17 00:00:00 2001 From: Carlos Parada <71727937+ParadaCarleton@users.noreply.github.com> Date: Fri, 10 Oct 2025 22:21:50 +0000 Subject: [PATCH 3/3] Fix error in tests --- test/tests/BasicTests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tests/BasicTests.jl b/test/tests/BasicTests.jl index 3be31b0..24bb25a 100644 --- a/test/tests/BasicTests.jl +++ b/test/tests/BasicTests.jl @@ -111,7 +111,7 @@ end # Test with log_weights = true log_weights_true = copy(log_lik) - log_weights_true, xi_true = psis!(log_weights_true, r_eff; log_weights=true) + xi_true = psis!(log_weights_true, r_eff; log_weights=true) # Test with log_weights = false raw_weights_false = exp.(log_lik .- maximum(log_lik))