Skip to content

Commit d42362a

Browse files
Correct log-weights implementation (#125)
* Correct log-weights implementation
1 parent bc0aef0 commit d42362a

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ ParetoSmooth.jl is a Julia package for efficient approximate leave-one-out cross
6060
### Testing Requirements
6161
- Run tests before committing: `julia --project=. -e "using Pkg; Pkg.test()"`
6262
- **CRITICAL**: NEVER CANCEL tests even if they appear to hang - they take 2 minutes
63-
- All 48 tests must pass for a valid build
63+
- All tests must pass for a valid build
6464
- Tests validate against R reference data in `test/data/`
6565

6666
## Common Tasks

src/ImportanceSampling.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ end
202202

203203
function psis(is_ratios::AbstractVector{<:Real}, args...; kwargs...)
204204
new_ratios = copy(is_ratios)
205-
ξ = psis!(new_ratios, kwargs...)
205+
ξ = psis!(new_ratios; kwargs...)
206206
return new_ratios, ξ
207207
end
208208

@@ -253,13 +253,13 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
253253
is_ratios .= first.(ratio_index)
254254
@views tail = is_ratios[tail_start:len]
255255
_check_tail(tail)
256+
cutoff = is_ratios[tail_start - 1]
256257
if log_weights
257258
biggest = maximum(tail)
258-
@. tail = exp(tail - biggest)
259+
@. is_ratios = exp(is_ratios - biggest)
260+
cutoff = exp(cutoff - biggest)
259261
end
260262

261-
# Get value just before the tail starts:
262-
cutoff = is_ratios[tail_start - 1]
263263
ξ = _psis_smooth_tail!(tail, cutoff, r_eff)
264264

265265
# truncate at max of raw weights (1 after scaling)
@@ -268,7 +268,7 @@ function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
268268
invpermute!(is_ratios, last.(ratio_index))
269269

270270
if log_weights
271-
@. tail = log(tail + biggest)
271+
@. is_ratios = log(is_ratios) + biggest
272272
end
273273

274274
return ξ

test/tests/BasicTests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using AxisKeys
22
using NamedDims
33
using Statistics
44
import RData
5+
using Random
56

67
@testset "Basic Arrays" begin
78

@@ -102,3 +103,23 @@ import RData
102103
@test ParetoSmooth.naive_lpd(log_lik_arr) jul_loo.estimates(:naive_lpd, :total)
103104
@test ParetoSmooth.naive_lpd(log_lik_arr) r_eff_loo.estimates(:naive_lpd, :total)
104105
end
106+
107+
@testset "Log-weights vs. raw weights" begin
108+
Random.seed!(123)
109+
log_lik = randn(30)
110+
r_eff = 1.01
111+
112+
# Test with log_weights = true
113+
log_weights_true = copy(log_lik)
114+
xi_true = psis!(log_weights_true, r_eff; log_weights=true)
115+
116+
# Test with log_weights = false
117+
raw_weights_false = exp.(log_lik .- maximum(log_lik))
118+
xi_false = psis!(raw_weights_false, r_eff; log_weights=false)
119+
120+
# The shape parameters should be the same
121+
@test xi_true xi_false atol=1e-9
122+
123+
# The smoothed weights should also be equivalent
124+
@test exp.(log_weights_true .- maximum(log_lik)) raw_weights_false rtol=1e-9
125+
end

0 commit comments

Comments
 (0)