Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ ParetoSmooth.jl is a Julia package for efficient approximate leave-one-out cross
### Testing Requirements
- Run tests before committing: `julia --project=. -e "using Pkg; Pkg.test()"`
- **CRITICAL**: NEVER CANCEL tests even if they appear to hang - they take 2 minutes
- All 48 tests must pass for a valid build
- All tests must pass for a valid build
- Tests validate against R reference data in `test/data/`

## Common Tasks
Expand Down
10 changes: 5 additions & 5 deletions src/ImportanceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ end

function psis(is_ratios::AbstractVector{<:Real}, args...; kwargs...)
new_ratios = copy(is_ratios)
ξ = psis!(new_ratios, kwargs...)
ξ = psis!(new_ratios; kwargs...)
return new_ratios, ξ
end

Expand Down Expand Up @@ -253,13 +253,13 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
is_ratios .= first.(ratio_index)
@views tail = is_ratios[tail_start:len]
_check_tail(tail)
cutoff = is_ratios[tail_start - 1]
if log_weights
biggest = maximum(tail)
@. tail = exp(tail - biggest)
@. is_ratios = exp(is_ratios - biggest)
cutoff = exp(cutoff - biggest)
end

# Get value just before the tail starts:
cutoff = is_ratios[tail_start - 1]
ξ = _psis_smooth_tail!(tail, cutoff, r_eff)

# truncate at max of raw weights (1 after scaling)
Expand All @@ -268,7 +268,7 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
invpermute!(is_ratios, last.(ratio_index))

if log_weights
@. tail = log(tail + biggest)
@. is_ratios = log(is_ratios) + biggest
end

return ξ
Expand Down
21 changes: 21 additions & 0 deletions test/tests/BasicTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using AxisKeys
using NamedDims
using Statistics
import RData
using Random

@testset "Basic Arrays" begin

Expand Down Expand Up @@ -102,3 +103,23 @@ import RData
@test ParetoSmooth.naive_lpd(log_lik_arr) ≈ jul_loo.estimates(:naive_lpd, :total)
@test ParetoSmooth.naive_lpd(log_lik_arr) ≈ r_eff_loo.estimates(:naive_lpd, :total)
end

@testset "Log-weights vs. raw weights" begin
Random.seed!(123)
log_lik = randn(30)
r_eff = 1.01

# Test with log_weights = true
log_weights_true = copy(log_lik)
xi_true = psis!(log_weights_true, r_eff; log_weights=true)

# Test with log_weights = false
raw_weights_false = exp.(log_lik .- maximum(log_lik))
xi_false = psis!(raw_weights_false, r_eff; log_weights=false)

# The shape parameters should be the same
@test xi_true ≈ xi_false atol=1e-9

# The smoothed weights should also be equivalent
@test exp.(log_weights_true .- maximum(log_lik)) ≈ raw_weights_false rtol=1e-9
end