Skip to content

Commit 0a69e64

Browse files
committed
Python: add reference argument to R/D-quant
1 parent e59fe9e commit 0a69e64

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

src/pybindings/quant.rs

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ pub struct EmpiricalDistribution(MaybeMultiplexed<crate::quant::EmpiricalDistrib
204204
/// is true in general.
205205
#[pyclass]
206206
#[derive(Debug)]
207-
pub struct RatedGrid(MaybeMultiplexed<crate::quant::StaticRatedGrid>);
207+
pub struct RatedGrid(MaybeMultiplexed<crate::quant::DynamicRatedGrid>);
208208

209209
#[derive(Debug)]
210210
enum MaybeMultiplexed<T> {
@@ -1270,15 +1270,15 @@ impl EmpiricalDistribution {
12701270
pub fn rated_grid(&self) -> RatedGrid {
12711271
let rated_grid = match &self.0 {
12721272
MaybeMultiplexed::Single(distribution) => {
1273-
MaybeMultiplexed::Single(distribution.rated_grid())
1273+
MaybeMultiplexed::Single(distribution.dynamic_rated_grid())
12741274
}
12751275
MaybeMultiplexed::Multiple {
12761276
distributions,
12771277
axis,
12781278
} => MaybeMultiplexed::Multiple {
12791279
distributions: distributions
12801280
.iter()
1281-
.map(|distribution| distribution.rated_grid())
1281+
.map(|distribution| distribution.dynamic_rated_grid())
12821282
.collect(),
12831283
axis: *axis,
12841284
},
@@ -1406,9 +1406,8 @@ impl RatedGrid {
14061406
py: Python<'_>,
14071407
index: Option<usize>,
14081408
) -> PyResult<(PyObject, PyObject)> {
1409-
self.0.extract_data(py, index, move |grid| {
1410-
grid.points_and_rates().iter().copied()
1411-
})
1409+
self.0
1410+
.extract_data(py, index, move |grid| grid.points_and_rates())
14121411
}
14131412
}
14141413

@@ -1668,8 +1667,10 @@ fn vbq<'p>(
16681667
update_prior,
16691668
reference,
16701669
|prior, old, new| {
1671-
prior.remove(old, 1)?;
1672-
prior.insert(new, 1);
1670+
if old != new {
1671+
prior.remove(old, 1)?;
1672+
prior.insert(new, 1);
1673+
}
16731674
Ok(())
16741675
},
16751676
)
@@ -1702,8 +1703,10 @@ fn vbq_(
17021703
update_prior,
17031704
reference,
17041705
|prior, old, new| {
1705-
prior.remove(old, 1)?;
1706-
prior.insert(new, 1);
1706+
if old != new {
1707+
prior.remove(old, 1)?;
1708+
prior.insert(new, 1);
1709+
}
17071710
Ok(())
17081711
},
17091712
)
@@ -1750,6 +1753,25 @@ fn vbq_(
17501753
/// counterparts). The argument `rate_penalty` is a convenience. Setting `rate_penalty` to a value
17511754
/// different from `1.0` has the same effect as multiplying all entries of `posterior_variance` by
17521755
/// `rate_penalty`.
1756+
/// - `update_prior`: optional boolean that decides whether the provided `prior` will be updated
1757+
/// after quantizing each entry of `unquantized`. Defaults to `false` if no `reference` is
1758+
/// provided. Providing a `reference` implies `update_prior=True`.
1759+
/// Setting `update_prior=True` has two effects:
1760+
/// (i) once `vbq` terminates, all `unquantized` (or `reference`) points are removed from `prior`
1761+
/// and replaced by the (returned) quantized points; and
1762+
/// (ii) since the updates occur piecemeal, immediately after each entry is quantized, entries
1763+
/// towards the end of the array `unquantized` are quantized with a better estimate of the final
1764+
/// distribution of quantized points. However, this also means that each entry of `unquantized`
1765+
/// gets quantized with a different prior, and therefore potentially to a slightly different grid,
1766+
/// which can result in spurious clusters of grid points that lie very close to each other. For
1767+
/// this reason, setting `update_prior=True` is recommended only for intermediate runs of VBQ that
1768+
/// are part of some convergence process. Any final run of VBQ should set `update_prior=False`.
1769+
/// - `reference`: an optional array with same dimensions as `unquantized`. This array contains
1770+
/// the result from a previous quantization. If provided, then, after quantizing each value
1771+
/// `unquantized[indices]`, the `RatedGrid` will be updated by reducing the count of grid points
1772+
/// at grid point `reference[indices]` by one and increasing the count of grid points at the new
1773+
/// quantized value by one. These changes in counts lead to changes in the rates associated with
1774+
/// the grid points.
17531775
///
17541776
/// ## Example 1: quantization with a *global* grid (i.e. without `specialize_along_axis`)
17551777
///
@@ -1892,6 +1914,7 @@ fn rate_distortion_quantization<'p>(
18921914
grid: Py<RatedGrid>,
18931915
posterior_variance: PyReadonlyF32ArrayOrScalar<'p>,
18941916
rate_penalty: f32,
1917+
reference: Option<PyReadwriteArrayDyn<'p, f32>>,
18951918
) -> PyResult<&'p PyArrayDyn<f32>> {
18961919
quantize_out_of_place(
18971920
RateDistortionQuantization,
@@ -1901,8 +1924,14 @@ fn rate_distortion_quantization<'p>(
19011924
posterior_variance,
19021925
rate_penalty,
19031926
None,
1904-
None,
1905-
|_grid, _old, _new| Ok(()),
1927+
reference,
1928+
|grid: &mut crate::quant::DynamicRatedGrid, old, new| {
1929+
if old != new {
1930+
grid.remove(old, 1)?;
1931+
grid.insert(new, 1).expect("new grid point exists");
1932+
}
1933+
Ok(())
1934+
},
19061935
)
19071936
}
19081937

@@ -1922,6 +1951,7 @@ fn rate_distortion_quantization_(
19221951
grid: Py<RatedGrid>,
19231952
posterior_variance: PyReadonlyF32ArrayOrScalar<'_>,
19241953
rate_penalty: f32,
1954+
reference: Option<PyReadwriteArrayDyn<'_, f32>>,
19251955
) -> PyResult<()> {
19261956
quantize_in_place(
19271957
RateDistortionQuantization,
@@ -1930,8 +1960,14 @@ fn rate_distortion_quantization_(
19301960
posterior_variance,
19311961
rate_penalty,
19321962
None,
1933-
None,
1934-
|_grid, _old, _new| Ok(()),
1963+
reference,
1964+
|grid: &mut crate::quant::DynamicRatedGrid, old, new| {
1965+
if old != new {
1966+
grid.remove(old, 1)?;
1967+
grid.insert(new, 1).expect("new grid point exists");
1968+
}
1969+
Ok(())
1970+
},
19351971
)
19361972
}
19371973

0 commit comments

Comments
 (0)