@@ -4,14 +4,13 @@ use argmin::{
44} ;
55use argmin_observer_slog:: SlogLogger ;
66use diffsol:: {
7- DiffSl , OdeBuilder , OdeEquations , OdeSolverMethod , OdeSolverProblem ,
8- SensitivitiesOdeSolverMethod ,
7+ DiffSl , MatrixCommon , OdeBuilder , OdeEquations , OdeSolverMethod , OdeSolverProblem , Op ,
8+ SensitivitiesOdeSolverMethod , Vector ,
99} ;
10- use nalgebra:: { DMatrix , DVector } ;
1110use std:: cell:: RefCell ;
1211
13- type M = DMatrix < f64 > ;
14- type V = DVector < f64 > ;
12+ type M = diffsol :: NalgebraMat < f64 > ;
13+ type V = diffsol :: NalgebraVec < f64 > ;
1514type T = f64 ;
1615type LS = diffsol:: NalgebraLU < f64 > ;
1716type CG = diffsol:: LlvmModule ;
@@ -29,15 +28,19 @@ impl CostFunction for Problem {
2928
3029 fn cost ( & self , param : & Self :: Param ) -> Result < Self :: Output , argmin_math:: Error > {
3130 let mut problem = self . problem . borrow_mut ( ) ;
32- problem. eqn_mut ( ) . set_params ( & V :: from_vec ( param. clone ( ) ) ) ;
31+ let ctx = * problem. eqn ( ) . context ( ) ;
32+ problem
33+ . eqn_mut ( )
34+ . set_params ( & V :: from_vec ( param. clone ( ) , ctx) ) ;
3335 let mut solver = problem. bdf :: < LS > ( ) . unwrap ( ) ;
3436 let ys = match solver. solve_dense ( & self . ts_data ) {
3537 Ok ( ys) => ys,
3638 Err ( _) => return Ok ( f64:: MAX / 1000. ) ,
3739 } ;
3840 let loss = ys
41+ . inner ( )
3942 . column_iter ( )
40- . zip ( self . ys_data . column_iter ( ) )
43+ . zip ( self . ys_data . inner ( ) . column_iter ( ) )
4144 . map ( |( a, b) | ( a - b) . norm_squared ( ) )
4245 . sum :: < f64 > ( ) ;
4346 Ok ( loss)
@@ -50,7 +53,10 @@ impl Gradient for Problem {
5053
5154 fn gradient ( & self , param : & Self :: Param ) -> Result < Self :: Gradient , argmin_math:: Error > {
5255 let mut problem = self . problem . borrow_mut ( ) ;
53- problem. eqn_mut ( ) . set_params ( & V :: from_vec ( param. clone ( ) ) ) ;
56+ let ctx = * problem. eqn ( ) . context ( ) ;
57+ problem
58+ . eqn_mut ( )
59+ . set_params ( & V :: from_vec ( param. clone ( ) , ctx) ) ;
5460 let mut solver = problem. bdf_sens :: < LS > ( ) . unwrap ( ) ;
5561 let ( ys, sens) = match solver. solve_dense_sensitivities ( & self . ts_data ) {
5662 Ok ( ( ys, sens) ) => ( ys, sens) ,
@@ -59,8 +65,13 @@ impl Gradient for Problem {
5965 let dlossdp = sens
6066 . into_iter ( )
6167 . map ( |s| {
62- s. column_iter ( )
63- . zip ( ys. column_iter ( ) . zip ( self . ys_data . column_iter ( ) ) )
68+ s. inner ( )
69+ . column_iter ( )
70+ . zip (
71+ ys. inner ( )
72+ . column_iter ( )
73+ . zip ( self . ys_data . inner ( ) . column_iter ( ) ) ,
74+ )
6475 . map ( |( si, ( yi, di) ) | 2.0 * ( yi - di) . dot ( & si) )
6576 . sum :: < f64 > ( )
6677 } )
@@ -70,10 +81,9 @@ impl Gradient for Problem {
7081}
7182
7283pub fn main ( ) {
73- let eqn = DiffSl :: < M , CG > :: compile (
74- "
75- in = [ b, d ]
76- a { 2.0/3.0 } b { 4.0/3.0 } c { 1.0 } d { 1.0 } x0 { 1.0 } y0 { 1.0 }
84+ let code = "
85+ in_i { b = 4.0/3.0, d = 1.0 }
86+ a { 2.0/3.0 } c { 1.0 } x0 { 1.0 } y0 { 1.0 }
7787 u_i {
7888 y1 = x0,
7989 y2 = y0,
@@ -82,9 +92,7 @@ pub fn main() {
8292 a * y1 - b * y1 * y2,
8393 c * y1 * y2 - d * y2,
8494 }
85- " ,
86- )
87- . unwrap ( ) ;
95+ " ;
8896
8997 let ( b_true, d_true) = ( 4.0 / 3.0 , 1.0 ) ;
9098 let t_data = ( 0 ..101 )
@@ -94,10 +102,12 @@ pub fn main() {
94102 . p ( [ b_true, d_true] )
95103 . sens_atol ( [ 1e-6 ] )
96104 . sens_rtol ( 1e-6 )
97- . build_from_eqn ( eqn )
105+ . build_from_diffsl ( code )
98106 . unwrap ( ) ;
99- let mut solver = problem. bdf :: < LS > ( ) . unwrap ( ) ;
100- let ys_data = solver. solve_dense ( & t_data) . unwrap ( ) ;
107+ let ys_data = {
108+ let mut solver = problem. bdf :: < LS > ( ) . unwrap ( ) ;
109+ solver. solve_dense ( & t_data) . unwrap ( )
110+ } ;
101111
102112 let cost = Problem {
103113 ys_data,
0 commit comments