@@ -3,13 +3,11 @@ use std::marker::PhantomData;
33use rand:: Rng ;
44
55use crate :: {
6- chain:: AdaptStrategy ,
7- euclidean_hamiltonian:: { EuclideanHamiltonian , EuclideanPoint } ,
8- hamiltonian:: { Hamiltonian , Point } ,
6+ euclidean_hamiltonian:: EuclideanPoint ,
7+ hamiltonian:: Point ,
98 mass_matrix:: { DiagMassMatrix , DrawGradCollector , MassMatrix , RunningVariance } ,
10- nuts:: NutsOptions ,
9+ nuts:: { Collector , NutsOptions } ,
1110 sampler_stats:: SamplerStats ,
12- state:: State ,
1311 Math , NutsError , Settings ,
1412} ;
1513const LOWER_LIMIT : f64 = 1e-20f64 ;
@@ -43,8 +41,10 @@ pub struct Strategy<M: Math> {
4341 _phantom : PhantomData < M > ,
4442}
4543
46- pub trait MassMatrixAdaptStrategy < M : Math > : AdaptStrategy < M > {
44+ pub trait MassMatrixAdaptStrategy < M : Math > : SamplerStats < M > {
4745 type MassMatrix : MassMatrix < M > ;
46+ type Collector : Collector < M , EuclideanPoint < M > > ;
47+ type Options : std:: fmt:: Debug + Default + Clone + Send + Sync + Copy ;
4848
4949 fn update_estimators ( & mut self , math : & mut M , collector : & Self :: Collector ) ;
5050
@@ -55,11 +55,26 @@ pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
5555 fn background_count ( & self ) -> u64 ;
5656
5757 /// Give the opportunity to update the potential and return if it was changed
58- fn update_potential ( & self , math : & mut M , potential : & mut Self :: Hamiltonian ) -> bool ;
58+ fn adapt ( & self , math : & mut M , mass_matrix : & mut Self :: MassMatrix ) -> bool ;
59+
60+ fn new ( math : & mut M , options : Self :: Options , _num_tune : u64 , _chain : u64 ) -> Self ;
61+
62+ fn init < R : Rng + ?Sized > (
63+ & mut self ,
64+ math : & mut M ,
65+ _options : & mut NutsOptions ,
66+ mass_matrix : & mut Self :: MassMatrix ,
67+ point : & impl Point < M > ,
68+ _rng : & mut R ,
69+ ) -> Result < ( ) , NutsError > ;
70+
71+ fn new_collector ( & self , math : & mut M ) -> Self :: Collector ;
5972}
6073
6174impl < M : Math > MassMatrixAdaptStrategy < M > for Strategy < M > {
6275 type MassMatrix = DiagMassMatrix < M > ;
76+ type Collector = DrawGradCollector < M > ;
77+ type Options = DiagAdaptExpSettings ;
6378
6479 fn update_estimators ( & mut self , math : & mut M , collector : & DrawGradCollector < M > ) {
6580 if collector. is_good {
@@ -88,11 +103,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
88103 }
89104
90105 /// Give the opportunity to update the potential and return if it was changed
91- fn update_potential (
92- & self ,
93- math : & mut M ,
94- potential : & mut EuclideanHamiltonian < M , Self :: MassMatrix > ,
95- ) -> bool {
106+ fn adapt ( & self , math : & mut M , mass_matrix : & mut DiagMassMatrix < M > ) -> bool {
96107 if self . current_count ( ) < 3 {
97108 return false ;
98109 }
@@ -102,7 +113,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
102113 assert ! ( draw_scale == grad_scale) ;
103114
104115 if self . _settings . use_grad_based_estimate {
105- potential . mass_matrix . update_diag_draw_grad (
116+ mass_matrix. update_diag_draw_grad (
106117 math,
107118 draw_var,
108119 grad_var,
@@ -111,35 +122,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
111122 ) ;
112123 } else {
113124 let scale = ( self . exp_variance_draw . count ( ) as f64 ) . recip ( ) ;
114- potential. mass_matrix . update_diag_draw (
115- math,
116- draw_var,
117- scale,
118- None ,
119- ( LOWER_LIMIT , UPPER_LIMIT ) ,
120- ) ;
125+ mass_matrix. update_diag_draw ( math, draw_var, scale, None , ( LOWER_LIMIT , UPPER_LIMIT ) ) ;
121126 }
122127
123128 true
124129 }
125- }
126-
127- pub type Stats = ( ) ;
128- pub type StatsBuilder = ( ) ;
129-
130- impl < M : Math > SamplerStats < M > for Strategy < M > {
131- type Builder = Stats ;
132- type Stats = StatsBuilder ;
133-
134- fn new_builder ( & self , _settings : & impl Settings , _dim : usize ) -> Self :: Builder { }
135-
136- fn current_stats ( & self , _math : & mut M ) -> Self :: Stats { }
137- }
138-
139- impl < M : Math > AdaptStrategy < M > for Strategy < M > {
140- type Hamiltonian = EuclideanHamiltonian < M , DiagMassMatrix < M > > ;
141- type Collector = DrawGradCollector < M > ;
142- type Options = DiagAdaptExpSettings ;
143130
144131 fn new ( math : & mut M , options : Self :: Options , _num_tune : u64 , _chain : u64 ) -> Self {
145132 Self {
@@ -156,49 +143,37 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
156143 & mut self ,
157144 math : & mut M ,
158145 _options : & mut NutsOptions ,
159- hamiltonian : & mut Self :: Hamiltonian ,
160- position : & [ f64 ] ,
146+ mass_matrix : & mut Self :: MassMatrix ,
147+ point : & impl Point < M > ,
161148 _rng : & mut R ,
162149 ) -> Result < ( ) , NutsError > {
163- let state = hamiltonian. init_state ( math, position) ?;
164-
165- self . exp_variance_draw
166- . add_sample ( math, state. point ( ) . position ( ) ) ;
167- self . exp_variance_draw_bg
168- . add_sample ( math, state. point ( ) . position ( ) ) ;
169- self . exp_variance_grad
170- . add_sample ( math, state. point ( ) . gradient ( ) ) ;
171- self . exp_variance_grad_bg
172- . add_sample ( math, state. point ( ) . gradient ( ) ) ;
173-
174- hamiltonian. mass_matrix . update_diag_grad (
150+ self . exp_variance_draw . add_sample ( math, point. position ( ) ) ;
151+ self . exp_variance_draw_bg . add_sample ( math, point. position ( ) ) ;
152+ self . exp_variance_grad . add_sample ( math, point. gradient ( ) ) ;
153+ self . exp_variance_grad_bg . add_sample ( math, point. gradient ( ) ) ;
154+
155+ mass_matrix. update_diag_grad (
175156 math,
176- state . point ( ) . gradient ( ) ,
157+ point. gradient ( ) ,
177158 1f64 ,
178159 ( INIT_LOWER_LIMIT , INIT_UPPER_LIMIT ) ,
179160 ) ;
180161 Ok ( ( ) )
181162 }
182163
183- fn adapt < R : Rng + ?Sized > (
184- & mut self ,
185- _math : & mut M ,
186- _options : & mut NutsOptions ,
187- _potential : & mut Self :: Hamiltonian ,
188- _draw : u64 ,
189- _collector : & Self :: Collector ,
190- _state : & State < M , EuclideanPoint < M > > ,
191- _rng : & mut R ,
192- ) -> Result < ( ) , NutsError > {
193- // Must be controlled from a different meta strategy
194- Ok ( ( ) )
195- }
196-
197164 fn new_collector ( & self , math : & mut M ) -> Self :: Collector {
198165 DrawGradCollector :: new ( math)
199166 }
167+ }
200168
201- fn is_tuning ( & self ) -> bool {
202- unreachable ! ( )
203- }
169+ pub type Stats = ( ) ;
170+ pub type StatsBuilder = ( ) ;
171+
172+ impl < M : Math > SamplerStats < M > for Strategy < M > {
173+ type Builder = Stats ;
174+ type Stats = StatsBuilder ;
175+
176+ fn new_builder ( & self , _settings : & impl Settings , _dim : usize ) -> Self :: Builder { }
177+
178+ fn current_stats ( & self , _math : & mut M ) -> Self :: Stats { }
204179}
0 commit comments