Skip to content

Commit f58ddfe

Browse files
authored
fix logsumexp and add a test (#11)
* fix logsumexp and add a test * drop printlns from test * bump version * format * changelog * add link
1 parent f6afef0 commit f58ddfe

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## [0.18.1] - 2025-01-07
4+
5+
### Fixed
6+
- Fixed logsumexp function for -inf values
7+
38
## [0.18.0] - 2024-06-24
49

510
### Added
@@ -215,6 +220,7 @@
215220
- Remove dependency on `quadrature` crate in favor of hand-rolled adaptive
216221
Simpson's rule, which handles multimodal distributions better.
217222

223+
[0.18.1]: https://github.com/promise-ai/rv/compare/v0.18.0...v0.18.1
218224
[0.18.0]: https://github.com/promise-ai/rv/compare/v0.17.0...v0.18.0
219225
[0.17.0]: https://github.com/promise-ai/rv/compare/v0.16.5...v0.17.0
220226
[0.16.5]: https://github.com/promise-ai/rv/compare/v0.16.4...v0.16.5

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "rv"
3-
version = "0.18.0"
3+
version = "0.18.1"
44
authors = ["Baxter Eaves", "Michael Schmidt", "Chad Scherrer"]
55
description = "Random variables"
66
repository = "https://github.com/promised-ai/rv"

src/misc/func.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ where
127127
let (alpha, r) =
128128
self.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| {
129129
let x = *x.borrow();
130-
if x <= alpha {
130+
131+
if x == f64::NEG_INFINITY {
132+
return (alpha, r);
133+
} else if x <= alpha {
131134
(alpha, r + (x - alpha).exp())
132135
} else {
133136
(x, (alpha - x).exp().mul_add(r, 1.0))
@@ -864,6 +867,27 @@ mod tests {
864867
assert_eq!(argmax(&xs), vec![4, 6]);
865868
}
866869

870+
#[test]
871+
fn logsumexp_nan_handling() {
872+
let a: f64 = -3.0;
873+
let b: f64 = -7.0;
874+
let target: f64 = logaddexp(a, b);
875+
let xs = [
876+
-f64::INFINITY,
877+
a,
878+
-f64::INFINITY,
879+
b,
880+
-f64::INFINITY,
881+
-f64::INFINITY,
882+
-f64::INFINITY,
883+
-f64::INFINITY,
884+
-f64::INFINITY,
885+
-f64::INFINITY,
886+
];
887+
let result = xs.iter().logsumexp();
888+
assert!((result - target).abs() < 1e-12);
889+
}
890+
867891
proptest! {
868892
#[test]
869893
fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) {

0 commit comments

Comments
 (0)