@@ -204,7 +204,7 @@ pub struct EmpiricalDistribution(MaybeMultiplexed<crate::quant::EmpiricalDistrib
204204/// is true in general.
205205#[ pyclass]
206206#[ derive( Debug ) ]
207- pub struct RatedGrid ( MaybeMultiplexed < crate :: quant:: StaticRatedGrid > ) ;
207+ pub struct RatedGrid ( MaybeMultiplexed < crate :: quant:: DynamicRatedGrid > ) ;
208208
209209#[ derive( Debug ) ]
210210enum MaybeMultiplexed < T > {
@@ -1270,15 +1270,15 @@ impl EmpiricalDistribution {
12701270 pub fn rated_grid ( & self ) -> RatedGrid {
12711271 let rated_grid = match & self . 0 {
12721272 MaybeMultiplexed :: Single ( distribution) => {
1273- MaybeMultiplexed :: Single ( distribution. rated_grid ( ) )
1273+ MaybeMultiplexed :: Single ( distribution. dynamic_rated_grid ( ) )
12741274 }
12751275 MaybeMultiplexed :: Multiple {
12761276 distributions,
12771277 axis,
12781278 } => MaybeMultiplexed :: Multiple {
12791279 distributions : distributions
12801280 . iter ( )
1281- . map ( |distribution| distribution. rated_grid ( ) )
1281+ . map ( |distribution| distribution. dynamic_rated_grid ( ) )
12821282 . collect ( ) ,
12831283 axis : * axis,
12841284 } ,
@@ -1406,9 +1406,8 @@ impl RatedGrid {
14061406 py : Python < ' _ > ,
14071407 index : Option < usize > ,
14081408 ) -> PyResult < ( PyObject , PyObject ) > {
1409- self . 0 . extract_data ( py, index, move |grid| {
1410- grid. points_and_rates ( ) . iter ( ) . copied ( )
1411- } )
1409+ self . 0
1410+ . extract_data ( py, index, move |grid| grid. points_and_rates ( ) )
14121411 }
14131412}
14141413
@@ -1668,8 +1667,10 @@ fn vbq<'p>(
16681667 update_prior,
16691668 reference,
16701669 |prior, old, new| {
1671- prior. remove ( old, 1 ) ?;
1672- prior. insert ( new, 1 ) ;
1670+ if old != new {
1671+ prior. remove ( old, 1 ) ?;
1672+ prior. insert ( new, 1 ) ;
1673+ }
16731674 Ok ( ( ) )
16741675 } ,
16751676 )
@@ -1702,8 +1703,10 @@ fn vbq_(
17021703 update_prior,
17031704 reference,
17041705 |prior, old, new| {
1705- prior. remove ( old, 1 ) ?;
1706- prior. insert ( new, 1 ) ;
1706+ if old != new {
1707+ prior. remove ( old, 1 ) ?;
1708+ prior. insert ( new, 1 ) ;
1709+ }
17071710 Ok ( ( ) )
17081711 } ,
17091712 )
@@ -1750,6 +1753,25 @@ fn vbq_(
17501753/// counterparts). The argument `rate_penalty` is a convenience. Setting `rate_penalty` to a value
17511754/// different from `1.0` has the same effect as multiplying all entries of `posterior_variance` by
17521755/// `rate_penalty`.
1756+ /// - `update_prior`: optional boolean that decides whether the provided `prior` will be updated
1757+ /// after quantizing each entry of `unquantized`. Defaults to `false` if now `reference` is
1758+ /// provided. Providing a `reverence` implies `update_prior=True.`
1759+ /// Setting `update_prior=True` has two effects:
1760+ /// (i) once `vbq` terminates, all `unquantized` (or `reference`) points are removed from `prior`
1761+ /// and replaced by the (returned) quantized points; and
1762+ /// (ii) since the updates occur by piecemeal immediately once each entry was quantized, entries
1763+ /// towards the end of the array `unquantized` are quantized with a better estimate of the final
1764+ /// distribution of quantized points. However, this also means that each entry of `unquantized`
1765+ /// gets quantized with a different prior, and therefore potentially to a slightly different grid,
1766+ /// which can result in spurious clusters of grid points that lie very close to each other. For
1767+ /// this reason, setting `update_prior=True` is recommended only for intermediate runs of VBQ that
1768+ /// are part of some convergence process. Any final run of VBQ should set `update_prior=False`.
1769+ /// - `reference`: an optional array with same dimensions as `unquantized`. This array contains
1770+ /// the result from a previous quantization. If provided, then, after quantizing each value
1771+ /// `unquantized[indices]`, the `RatedGrid` will be updated by reducing the count of grid points
1772+ /// at grid point `reference[indices]` by one and increasing the count of grid points at the new
1773+ /// quantized value by one. These changes in counts lead to changes in the rates associated with
1774+ /// the grid points.
17531775///
17541776/// ## Example 1: quantization with a *global* grid (i.e. without `specialize_along_axis`)
17551777///
@@ -1892,6 +1914,7 @@ fn rate_distortion_quantization<'p>(
18921914 grid : Py < RatedGrid > ,
18931915 posterior_variance : PyReadonlyF32ArrayOrScalar < ' p > ,
18941916 rate_penalty : f32 ,
1917+ reference : Option < PyReadwriteArrayDyn < ' p , f32 > > ,
18951918) -> PyResult < & ' p PyArrayDyn < f32 > > {
18961919 quantize_out_of_place (
18971920 RateDistortionQuantization ,
@@ -1901,8 +1924,14 @@ fn rate_distortion_quantization<'p>(
19011924 posterior_variance,
19021925 rate_penalty,
19031926 None ,
1904- None ,
1905- |_grid, _old, _new| Ok ( ( ) ) ,
1927+ reference,
1928+ |grid : & mut crate :: quant:: DynamicRatedGrid , old, new| {
1929+ if old != new {
1930+ grid. remove ( old, 1 ) ?;
1931+ grid. insert ( new, 1 ) . expect ( "new grid point exists" ) ;
1932+ }
1933+ Ok ( ( ) )
1934+ } ,
19061935 )
19071936}
19081937
@@ -1922,6 +1951,7 @@ fn rate_distortion_quantization_(
19221951 grid : Py < RatedGrid > ,
19231952 posterior_variance : PyReadonlyF32ArrayOrScalar < ' _ > ,
19241953 rate_penalty : f32 ,
1954+ reference : Option < PyReadwriteArrayDyn < ' _ , f32 > > ,
19251955) -> PyResult < ( ) > {
19261956 quantize_in_place (
19271957 RateDistortionQuantization ,
@@ -1930,8 +1960,14 @@ fn rate_distortion_quantization_(
19301960 posterior_variance,
19311961 rate_penalty,
19321962 None ,
1933- None ,
1934- |_grid, _old, _new| Ok ( ( ) ) ,
1963+ reference,
1964+ |grid : & mut crate :: quant:: DynamicRatedGrid , old, new| {
1965+ if old != new {
1966+ grid. remove ( old, 1 ) ?;
1967+ grid. insert ( new, 1 ) . expect ( "new grid point exists" ) ;
1968+ }
1969+ Ok ( ( ) )
1970+ } ,
19351971 )
19361972}
19371973
0 commit comments