@@ -360,15 +360,15 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
360
360
)
361
361
}
362
362
363
- fn transformed_logp (
363
+ fn init_from_untransformed_position (
364
364
& mut self ,
365
365
params : & Self :: TransformParams ,
366
366
untransformed_position : & Self :: Vector ,
367
367
untransformed_gradient : & mut Self :: Vector ,
368
368
transformed_position : & mut Self :: Vector ,
369
369
transformed_gradient : & mut Self :: Vector ,
370
370
) -> Result < ( f64 , f64 ) , Self :: LogpErr > {
371
- self . logp_func . transformed_logp (
371
+ self . logp_func . init_from_untransformed_position (
372
372
params,
373
373
untransformed_position. as_slice ( ) ,
374
374
untransformed_gradient. as_slice_mut ( ) ,
@@ -377,11 +377,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
377
377
)
378
378
}
379
379
380
+ fn init_from_transformed_position (
381
+ & mut self ,
382
+ params : & Self :: TransformParams ,
383
+ untransformed_position : & mut Self :: Vector ,
384
+ untransformed_gradient : & mut Self :: Vector ,
385
+ transformed_position : & Self :: Vector ,
386
+ transformed_gradient : & mut Self :: Vector ,
387
+ ) -> Result < ( f64 , f64 ) , Self :: LogpErr > {
388
+ self . logp_func . init_from_transformed_position (
389
+ params,
390
+ untransformed_position. as_slice_mut ( ) ,
391
+ untransformed_gradient. as_slice_mut ( ) ,
392
+ transformed_position. as_slice ( ) ,
393
+ transformed_gradient. as_slice_mut ( ) ,
394
+ )
395
+ }
396
+
380
397
fn update_transformation < ' a , R : rand:: Rng + ?Sized > (
381
398
& ' a mut self ,
382
399
rng : & mut R ,
383
- untransformed_positions : impl Iterator < Item = & ' a Self :: Vector > ,
384
- untransformed_gradients : impl Iterator < Item = & ' a Self :: Vector > ,
400
+ untransformed_positions : impl ExactSizeIterator < Item = & ' a Self :: Vector > ,
401
+ untransformed_gradients : impl ExactSizeIterator < Item = & ' a Self :: Vector > ,
385
402
params : & ' a mut Self :: TransformParams ,
386
403
) -> Result < ( ) , Self :: LogpErr > {
387
404
self . logp_func . update_transformation (
@@ -392,18 +409,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
392
409
)
393
410
}
394
411
395
- fn new_transformation (
412
+ fn new_transformation < R : rand :: Rng + ? Sized > (
396
413
& mut self ,
414
+ rng : & mut R ,
397
415
untransformed_position : & Self :: Vector ,
398
416
untransfogmed_gradient : & Self :: Vector ,
417
+ chain : u64 ,
399
418
) -> Result < Self :: TransformParams , Self :: LogpErr > {
400
419
self . logp_func . new_transformation (
420
+ rng,
401
421
untransformed_position. as_slice ( ) ,
402
422
untransfogmed_gradient. as_slice ( ) ,
423
+ chain,
403
424
)
404
425
}
405
426
406
- fn transformation_id ( & self , params : & Self :: TransformParams ) -> i64 {
427
+ fn transformation_id ( & self , params : & Self :: TransformParams ) -> Result < i64 , Self :: LogpErr > {
407
428
self . logp_func . transformation_id ( params)
408
429
}
409
430
}
@@ -417,35 +438,58 @@ pub trait CpuLogpFunc {
417
438
418
439
fn inv_transform_normalize (
419
440
& mut self ,
420
- params : & Self :: TransformParams ,
421
- untransformed_position : & [ f64 ] ,
422
- untransofrmed_gradient : & [ f64 ] ,
423
- transformed_position : & mut [ f64 ] ,
424
- transformed_gradient : & mut [ f64 ] ,
425
- ) -> Result < f64 , Self :: LogpError > ;
441
+ _params : & Self :: TransformParams ,
442
+ _untransformed_position : & [ f64 ] ,
443
+ _untransformed_gradient : & [ f64 ] ,
444
+ _transformed_position : & mut [ f64 ] ,
445
+ _transformed_gradient : & mut [ f64 ] ,
446
+ ) -> Result < f64 , Self :: LogpError > {
447
+ unimplemented ! ( )
448
+ }
426
449
427
- fn transformed_logp (
450
+ fn init_from_untransformed_position (
428
451
& mut self ,
429
- params : & Self :: TransformParams ,
430
- untransformed_position : & [ f64 ] ,
431
- untransformed_gradient : & mut [ f64 ] ,
432
- transformed_position : & mut [ f64 ] ,
433
- transformed_gradient : & mut [ f64 ] ,
434
- ) -> Result < ( f64 , f64 ) , Self :: LogpError > ;
452
+ _params : & Self :: TransformParams ,
453
+ _untransformed_position : & [ f64 ] ,
454
+ _untransformed_gradient : & mut [ f64 ] ,
455
+ _transformed_position : & mut [ f64 ] ,
456
+ _transformed_gradient : & mut [ f64 ] ,
457
+ ) -> Result < ( f64 , f64 ) , Self :: LogpError > {
458
+ unimplemented ! ( )
459
+ }
460
+
461
+ fn init_from_transformed_position (
462
+ & mut self ,
463
+ _params : & Self :: TransformParams ,
464
+ _untransformed_position : & mut [ f64 ] ,
465
+ _untransformed_gradient : & mut [ f64 ] ,
466
+ _transformed_position : & [ f64 ] ,
467
+ _transformed_gradient : & mut [ f64 ] ,
468
+ ) -> Result < ( f64 , f64 ) , Self :: LogpError > {
469
+ unimplemented ! ( )
470
+ }
435
471
436
472
fn update_transformation < ' a , R : rand:: Rng + ?Sized > (
437
473
& ' a mut self ,
438
- rng : & mut R ,
439
- untransformed_positions : impl Iterator < Item = & ' a [ f64 ] > ,
440
- untransformed_gradients : impl Iterator < Item = & ' a [ f64 ] > ,
441
- params : & ' a mut Self :: TransformParams ,
442
- ) -> Result < ( ) , Self :: LogpError > ;
474
+ _rng : & mut R ,
475
+ _untransformed_positions : impl ExactSizeIterator < Item = & ' a [ f64 ] > ,
476
+ _untransformed_gradients : impl ExactSizeIterator < Item = & ' a [ f64 ] > ,
477
+ _params : & ' a mut Self :: TransformParams ,
478
+ ) -> Result < ( ) , Self :: LogpError > {
479
+ unimplemented ! ( )
480
+ }
443
481
444
- fn new_transformation (
482
+ fn new_transformation < R : rand :: Rng + ? Sized > (
445
483
& mut self ,
446
- untransformed_position : & [ f64 ] ,
447
- untransfogmed_gradient : & [ f64 ] ,
448
- ) -> Result < Self :: TransformParams , Self :: LogpError > ;
484
+ _rng : & mut R ,
485
+ _untransformed_position : & [ f64 ] ,
486
+ _untransformed_gradient : & [ f64 ] ,
487
+ _chain : u64 ,
488
+ ) -> Result < Self :: TransformParams , Self :: LogpError > {
489
+ unimplemented ! ( )
490
+ }
449
491
450
- fn transformation_id ( & self , params : & Self :: TransformParams ) -> i64 ;
492
+ fn transformation_id ( & self , _params : & Self :: TransformParams ) -> Result < i64 , Self :: LogpError > {
493
+ unimplemented ! ( )
494
+ }
451
495
}
0 commit comments