Description
I tried porting `unbalanced_sinkhorn!` to Rust (which I don't know at all) using ChatGPT, just as an experiment¹. This is probably not very good or idiomatic Rust code, but it does produce the same values as the Julia package. When I tried porting the `sinkhorn_divergence` code, the Rust compiler alerted me to the aliasing issue in #11. It also seems to run the example in `main` a few times faster than my Julia code, but I don't know whether that comparison is fair, since the inputs are compiled into the binary rather than being supplied at runtime.
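One way to rule out the "compiled-in inputs" concern would be to take `n` and the seeds from the command line and pass the inputs through `std::hint::black_box`. Something like this (untested) variant of `main`, reusing `rand_measure` and `unbalanced_sinkhorn` from the listing below:

```rust
// Untested sketch: a variant of the `main` below where n and the RNG seeds
// come from the command line and the inputs pass through `black_box`, so the
// optimizer is discouraged from specializing on compile-time constants.
// `rand_measure` and `unbalanced_sinkhorn` are the functions from the full listing.
use std::hint::black_box;
use std::time::Instant;

fn main() {
    let mut args = std::env::args().skip(1);
    let n: usize = args.next().and_then(|s| s.parse().ok()).unwrap_or(5);
    let seed_a: u64 = args.next().and_then(|s| s.parse().ok()).unwrap_or(1);
    let seed_b: u64 = args.next().and_then(|s| s.parse().ok()).unwrap_or(2);

    let mut a = black_box(rand_measure(n, 10.0, seed_a));
    let mut b = black_box(rand_measure(n, 10.0, seed_b));
    let c = black_box(ndarray::Array::from_shape_fn((n, n), |(i, j)| (i + j) as f64));

    let start = Instant::now();
    let (iters, max_residual) =
        unbalanced_sinkhorn(1.0, &c, &mut a, &mut b, 1e-1, 1e-5, 10_000, true);
    println!(
        "{:?} for {} iterations (max residual {})",
        start.elapsed(),
        iters,
        max_residual
    );
}
```

The full port: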
```rust
use ndarray::Array;
use ndarray::{Array2, Axis};
use rand::{Rng, SeedableRng};
use std::time::Instant;

// Simple version of:
// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/UnbalancedOptimalTransport.jl#L16-L49
#[derive(Debug)]
struct DiscreteMeasure {
    log_density: Vec<f64>,
    dual_potential: Vec<f64>,
    cache: Vec<f64>,
}

impl DiscreteMeasure {
    fn new(log_density: Vec<f64>) -> Self {
        let n = log_density.len();
        let dual_potential = vec![0.0; n];
        let cache = vec![0.0; n];
        Self {
            log_density,
            dual_potential,
            cache,
        }
    }
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/test/runtests.jl#L11-L26
fn rand_measure(n: usize, scale: f64, seed: u64) -> DiscreteMeasure {
    let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
    let log_density = (0..n)
        .map(|_| scale * rng.gen::<f64>())
        .collect::<Vec<f64>>();
    DiscreteMeasure::new(log_density)
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/divergences.jl#L49
fn approx(ρ: f64, ϵ: f64, x: f64) -> f64 {
    (1.0 / (1.0 + ϵ / ρ)) * x
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/utilities.jl#L1-L18
fn logsumexp(w: &mut [f64]) -> f64 {
    let n = w.len();
    let (offset, maxind) = {
        let max_index = w
            .iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0;
        (w[max_index], max_index)
    };
    for elem in w.iter_mut() {
        *elem = (*elem - offset).exp();
    }
    let sum_except_max: f64 = {
        w[maxind] -= 1.0;
        let s = w.iter().sum();
        w[maxind] += 1.0;
        s
    };
    (sum_except_max).ln_1p() + offset
}
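
// Added sanity check (not part of the original port): `logsumexp` should agree
// with a direct ln(sum(exp(w))) and, as a side effect, leave each w[i] equal
// to exp(w[i] - max(w)).
#[cfg(test)]
mod logsumexp_tests {
    use super::logsumexp;

    #[test]
    fn matches_naive_formula() {
        let original = vec![-1.0_f64, 0.5, 2.0, -3.0];
        let mut w = original.clone();
        let got = logsumexp(&mut w);
        let naive = original.iter().map(|x| x.exp()).sum::<f64>().ln();
        assert!((got - naive).abs() < 1e-12);
        let max = original.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        for (wi, oi) in w.iter().zip(original.iter()) {
            assert!((wi - (oi - max).exp()).abs() < 1e-12);
        }
    }
}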
// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/sinkhorn.jl#L13-L105
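// Mirrors the in-place behaviour of the Julia `unbalanced_sinkhorn!`: the dual
// potentials stored in `a` and `b` are overwritten, `a.cache` / `b.cache` are
// used as scratch space for the log-sum-exp inner loops, and the return value
// is (iterations used, final maximum residual).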
fn unbalanced_sinkhorn(
    D: f64,
    C: &Array2<f64>,
    a: &mut DiscreteMeasure,
    b: &mut DiscreteMeasure,
    ϵ: f64,
    tol: f64,
    max_iters: usize,
    warn: bool,
) -> (usize, f64) {
    a.dual_potential.iter_mut().for_each(|x| *x = 0.0);
    b.dual_potential.iter_mut().for_each(|x| *x = 0.0);
    let mut max_residual = f64::INFINITY;
    let mut iters = 0;
    let f = &mut a.dual_potential;
    let mut tmp_f = &mut a.cache;
    let g = &mut b.dual_potential;
    let mut tmp_g = &mut b.cache;
    let min_length_a = a.log_density.len().min(tmp_f.len()).min(C.len_of(Axis(0)));
    let min_length_b = b.log_density.len().min(tmp_g.len()).min(C.len_of(Axis(1)));
    while iters < max_iters && max_residual > tol {
        iters += 1;
        max_residual = 0.0;
        for j in 0..g.len() {
            for i in 0..min_length_a {
                tmp_f[i] = a.log_density[i] + (f[i] - C[[i, j]]) / ϵ;
            }
            let new_g = -ϵ * logsumexp(&mut tmp_f);
            let new_g = -approx(D, ϵ, -new_g);
            let diff = (g[j] - new_g).abs();
            if diff > max_residual {
                max_residual = diff;
            }
            g[j] = new_g;
        }
        for j in 0..f.len() {
            for i in 0..min_length_b {
                tmp_g[i] = b.log_density[i] + (g[i] - C[[j, i]]) / ϵ;
            }
            let new_f = -ϵ * logsumexp(&mut tmp_g);
            let new_f = -approx(D, ϵ, -new_f);
            let diff = (f[j] - new_f).abs();
            if diff > max_residual {
                max_residual = diff;
            }
            f[j] = new_f;
        }
    }
    if warn && iters == max_iters {
        println!("Maximum iterations ({}) reached", max_iters);
    }
    (iters, max_residual)
}

fn main() {
    let n = 5; // Define the dimension n
    let scale = 10.0;
    let seed_a = 1;
    let seed_b = 2;

    // Generate inputs a and b using rand_measure
    let mut a = rand_measure(n, scale, seed_a);
    let mut b = rand_measure(n, scale, seed_b);

    // Create a sample cost matrix C (for demonstration purposes)
    let C = Array::from_shape_fn((n, n), |(i, j)| (i + j) as f64);
    println!("Input a: {:?}", a);
    println!("Input b: {:?}", b);
    println!("Input C: {:?}", C);

    let start_time = Instant::now();
    // Run the unbalanced_sinkhorn algorithm
    let (iters, max_residual) =
        unbalanced_sinkhorn(1.0, &C, &mut a, &mut b, 1e-1, 1e-5, 10_000, true);
    let elapsed_time = start_time.elapsed();
    println!("Elapsed time: {:?}", elapsed_time);
    println!("Iterations: {}", iters);
    println!("Max residual: {}", max_residual);
    println!("a.dual_potential: {:?}", a.dual_potential);
    println!("b.dual_potential: {:?}", b.dual_potential);
}
```

Footnotes

1. It does seem like a useful tool for code translation, but I think I would've been better off learning a bit more Rust syntax/concepts first, because I feel like I managed to get it running without getting a good understanding of the language or the code.