1
- use std:: { error:: Error , fmt:: Debug } ;
1
+ use std:: { error:: Error , fmt:: Debug , mem :: replace } ;
2
2
3
3
use faer:: { Col , Mat } ;
4
- use itertools:: izip;
4
+ use itertools:: { izip, Itertools } ;
5
5
use thiserror:: Error ;
6
6
7
7
use crate :: {
@@ -33,16 +33,35 @@ pub enum CpuMathError {
33
33
impl < F : CpuLogpFunc > Math for CpuMath < F > {
34
34
type Vector = Col < f64 > ;
35
35
type EigVectors = Mat < f64 > ;
36
- type EigValues = Mat < f64 > ;
36
+ type EigValues = Col < f64 > ;
37
37
type LogpErr = F :: LogpError ;
38
38
type Err = CpuMathError ;
39
39
40
40
fn new_array ( & self ) -> Self :: Vector {
41
41
Col :: zeros ( self . dim ( ) )
42
42
}
43
43
44
- fn logp ( & mut self , position : & [ f64 ] , gradient : & mut [ f64 ] ) -> Result < f64 , Self :: LogpErr > {
45
- self . logp_func . logp ( position, gradient)
44
+ fn new_eig_vectors < ' a > (
45
+ & ' a mut self ,
46
+ vals : impl ExactSizeIterator < Item = & ' a [ f64 ] > ,
47
+ ) -> Self :: EigVectors {
48
+ let ndim = self . dim ( ) ;
49
+ let nvecs = vals. len ( ) ;
50
+
51
+ let mut vectors: Mat < f64 > = Mat :: zeros ( ndim, nvecs) ;
52
+ vectors. col_iter_mut ( ) . zip_eq ( vals) . for_each ( |( col, vals) | {
53
+ col. try_as_slice_mut ( )
54
+ . expect ( "Array is not contiguous" )
55
+ . copy_from_slice ( vals)
56
+ } ) ;
57
+
58
+ vectors
59
+ }
60
+
61
+ fn new_eig_values ( & mut self , vals : & [ f64 ] ) -> Self :: EigValues {
62
+ let mut values: Col < f64 > = Col :: zeros ( vals. len ( ) ) ;
63
+ values. as_slice_mut ( ) . copy_from_slice ( vals) ;
64
+ values
46
65
}
47
66
48
67
fn logp_array (
@@ -54,6 +73,10 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
54
73
. logp ( position. as_slice ( ) , gradient. as_slice_mut ( ) )
55
74
}
56
75
76
+ fn logp ( & mut self , position : & [ f64 ] , gradient : & mut [ f64 ] ) -> Result < f64 , Self :: LogpErr > {
77
+ self . logp_func . logp ( position, gradient)
78
+ }
79
+
57
80
fn dim ( & self ) -> usize {
58
81
self . logp_func . dim ( )
59
82
}
@@ -136,6 +159,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
136
159
multiply ( array1. as_slice ( ) , array2. as_slice ( ) , dest. as_slice_mut ( ) )
137
160
}
138
161
162
+ fn array_mult_eigs (
163
+ & mut self ,
164
+ stds : & Self :: Vector ,
165
+ rhs : & Self :: Vector ,
166
+ dest : & mut Self :: Vector ,
167
+ vecs : & Self :: EigVectors ,
168
+ vals : & Self :: EigValues ,
169
+ ) {
170
+ let rhs = stds. column_vector_as_diagonal ( ) * rhs;
171
+ let trafo = vecs. transpose ( ) * ( & rhs) ;
172
+ let inner_prod = vecs * ( vals. column_vector_as_diagonal ( ) * ( & trafo) - ( & trafo) ) + rhs;
173
+ let scaled = stds. column_vector_as_diagonal ( ) * inner_prod;
174
+
175
+ let _ = replace ( dest, scaled) ;
176
+ }
177
+
139
178
fn array_vector_dot ( & mut self , array1 : & Self :: Vector , array2 : & Self :: Vector ) -> f64 {
140
179
vector_dot ( array1. as_slice ( ) , array2. as_slice ( ) )
141
180
}
@@ -156,6 +195,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
156
195
} ) ;
157
196
}
158
197
198
+ fn array_gaussian_eigs < R : rand:: Rng + ?Sized > (
199
+ & mut self ,
200
+ rng : & mut R ,
201
+ dest : & mut Self :: Vector ,
202
+ scale : & Self :: Vector ,
203
+ vals : & Self :: EigValues ,
204
+ vecs : & Self :: EigVectors ,
205
+ ) {
206
+ let mut draw: Col < f64 > = Col :: zeros ( self . dim ( ) ) ;
207
+ let dist = rand_distr:: StandardNormal ;
208
+ draw. as_slice_mut ( ) . iter_mut ( ) . for_each ( |p| {
209
+ * p = rng. sample ( dist) ;
210
+ } ) ;
211
+
212
+ let trafo = vecs. transpose ( ) * ( & draw) ;
213
+ let inner_prod = vecs * ( vals. column_vector_as_diagonal ( ) * ( & trafo) - ( & trafo) ) + draw;
214
+
215
+ let scaled = scale. column_vector_as_diagonal ( ) * inner_prod;
216
+
217
+ let _ = replace ( dest, scaled) ;
218
+ }
219
+
159
220
fn array_update_variance (
160
221
& mut self ,
161
222
mean : & mut Self :: Vector ,
@@ -177,6 +238,37 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
177
238
} )
178
239
}
179
240
241
+ fn array_update_var_inv_std_draw (
242
+ & mut self ,
243
+ variance_out : & mut Self :: Vector ,
244
+ inv_std : & mut Self :: Vector ,
245
+ draw_var : & Self :: Vector ,
246
+ scale : f64 ,
247
+ fill_invalid : Option < f64 > ,
248
+ clamp : ( f64 , f64 ) ,
249
+ ) {
250
+ self . arch . dispatch ( || {
251
+ izip ! (
252
+ variance_out. as_slice_mut( ) . iter_mut( ) ,
253
+ inv_std. as_slice_mut( ) . iter_mut( ) ,
254
+ draw_var. as_slice( ) . iter( ) ,
255
+ )
256
+ . for_each ( |( var_out, inv_std_out, & draw_var) | {
257
+ let draw_var = draw_var * scale;
258
+ if ( !draw_var. is_finite ( ) ) | ( draw_var == 0f64 ) {
259
+ if let Some ( fill_val) = fill_invalid {
260
+ * var_out = fill_val;
261
+ * inv_std_out = fill_val. recip ( ) . sqrt ( ) ;
262
+ }
263
+ } else {
264
+ let val = draw_var. clamp ( clamp. 0 , clamp. 1 ) ;
265
+ * var_out = val;
266
+ * inv_std_out = val. recip ( ) . sqrt ( ) ;
267
+ }
268
+ } ) ;
269
+ } ) ;
270
+ }
271
+
180
272
fn array_update_var_inv_std_draw_grad (
181
273
& mut self ,
182
274
variance_out : & mut Self :: Vector ,
@@ -232,56 +324,8 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
232
324
} ) ;
233
325
}
234
326
235
- fn array_update_var_inv_std_draw (
236
- & mut self ,
237
- variance_out : & mut Self :: Vector ,
238
- inv_std : & mut Self :: Vector ,
239
- draw_var : & Self :: Vector ,
240
- scale : f64 ,
241
- fill_invalid : Option < f64 > ,
242
- clamp : ( f64 , f64 ) ,
243
- ) {
244
- self . arch . dispatch ( || {
245
- izip ! (
246
- variance_out. as_slice_mut( ) . iter_mut( ) ,
247
- inv_std. as_slice_mut( ) . iter_mut( ) ,
248
- draw_var. as_slice( ) . iter( ) ,
249
- )
250
- . for_each ( |( var_out, inv_std_out, & draw_var) | {
251
- let draw_var = draw_var * scale;
252
- if ( !draw_var. is_finite ( ) ) | ( draw_var == 0f64 ) {
253
- if let Some ( fill_val) = fill_invalid {
254
- * var_out = fill_val;
255
- * inv_std_out = fill_val. recip ( ) . sqrt ( ) ;
256
- }
257
- } else {
258
- let val = draw_var. clamp ( clamp. 0 , clamp. 1 ) ;
259
- * var_out = val;
260
- * inv_std_out = val. recip ( ) . sqrt ( ) ;
261
- }
262
- } ) ;
263
- } ) ;
264
- }
265
-
266
- fn new_eig_vectors < ' a > (
267
- & ' a mut self ,
268
- vals : impl ExactSizeIterator < Item = & ' a [ f64 ] > ,
269
- ) -> Self :: EigVectors {
270
- todo ! ( )
271
- }
272
-
273
- fn new_eig_values ( & mut self , vals : & [ f64 ] ) -> Self :: EigValues {
274
- todo ! ( )
275
- }
276
-
277
- fn scaled_eigval_matmul (
278
- & mut self ,
279
- scale : & Self :: Vector ,
280
- vals : & Self :: EigValues ,
281
- vecs : & Self :: EigVectors ,
282
- out : & mut Self :: Vector ,
283
- ) {
284
- todo ! ( )
327
+ fn eigs_as_array ( & mut self , source : & Self :: EigValues ) -> Box < [ f64 ] > {
328
+ source. as_slice ( ) . to_vec ( ) . into ( )
285
329
}
286
330
}
287
331
0 commit comments