@@ -44,10 +44,13 @@ pub struct DiscreteModel<'s> {
4444 is_algebraic : Vec < bool > ,
4545 stop : Option < Tensor < ' s > > ,
4646 state0_input_deps : Vec < ( usize , usize ) > ,
47+ dstate0_input_deps : Vec < ( usize , usize ) > ,
4748 rhs_state_deps : Vec < ( usize , usize ) > ,
4849 rhs_input_deps : Vec < ( usize , usize ) > ,
4950 out_input_deps : Vec < ( usize , usize ) > ,
5051 out_state_deps : Vec < ( usize , usize ) > ,
52+ mass_state_deps : Vec < ( usize , usize ) > ,
53+ mass_input_deps : Vec < ( usize , usize ) > ,
5154}
5255
5356impl fmt:: Display for DiscreteModel < ' _ > {
@@ -105,10 +108,13 @@ impl<'s> DiscreteModel<'s> {
105108 is_algebraic : Vec :: new ( ) ,
106109 stop : None ,
107110 state0_input_deps : Vec :: new ( ) ,
111+ dstate0_input_deps : Vec :: new ( ) ,
108112 rhs_input_deps : Vec :: new ( ) ,
109113 rhs_state_deps : Vec :: new ( ) ,
110114 out_input_deps : Vec :: new ( ) ,
111115 out_state_deps : Vec :: new ( ) ,
116+ mass_state_deps : Vec :: new ( ) ,
117+ mass_input_deps : Vec :: new ( ) ,
112118 }
113119 }
114120
@@ -245,6 +251,15 @@ impl<'s> DiscreteModel<'s> {
245251 match Layout :: concatenate ( & elmts, rank) {
246252 Ok ( layout) => {
247253 let layout = env. new_layout_ptr ( layout) ;
254+ // if sparse, filter out zero elements
255+ let elmts = if layout. is_sparse ( ) {
256+ elmts
257+ . into_iter ( )
258+ . filter ( |e| !matches ! ( e. expr( ) . kind, AstKind :: Number ( 0.0 ) ) )
259+ . collect :: < Vec < _ > > ( )
260+ } else {
261+ elmts
262+ } ;
248263 let tensor = Tensor :: new ( array. name ( ) , elmts, layout, array. indices ( ) . to_vec ( ) ) ;
249264
250265 //check that the number of indices matches the rank
@@ -335,7 +350,7 @@ impl<'s> DiscreteModel<'s> {
335350 }
336351 "dudt" => {
337352 if let Some ( built) =
338- Self :: build_array ( tensor, & mut env, true , TensorType :: Other )
353+ Self :: build_array ( tensor, & mut env, true , TensorType :: StateDot )
339354 {
340355 ret. state_dot = Some ( built) ;
341356 }
@@ -522,6 +537,7 @@ impl<'s> DiscreteModel<'s> {
522537
523538 // store the dependencies in the discrete model
524539 ret. state0_input_deps = map_dep ( & env. state0_input_deps ) ;
540+ ret. dstate0_input_deps = map_dep ( & env. dstate0_input_deps ) ;
525541 ret. rhs_input_deps = map_dep ( ret. rhs . layout ( ) . input_dependencies ( ) ) ;
526542 ret. rhs_state_deps = map_dep ( ret. rhs . layout ( ) . state_dependencies ( ) ) ;
527543 ret. out_input_deps = ret
@@ -534,6 +550,16 @@ impl<'s> DiscreteModel<'s> {
534550 . as_ref ( )
535551 . map ( |o| map_dep ( o. layout ( ) . state_dependencies ( ) ) )
536552 . unwrap_or_default ( ) ;
553+ ret. mass_state_deps = if let Some ( lhs) = & ret. lhs {
554+ map_dep ( lhs. layout ( ) . state_dependencies ( ) )
555+ } else {
556+ Vec :: new ( )
557+ } ;
558+ ret. mass_input_deps = if let Some ( lhs) = & ret. lhs {
559+ map_dep ( lhs. layout ( ) . input_dependencies ( ) )
560+ } else {
561+ Vec :: new ( )
562+ } ;
537563
538564 if env. errs ( ) . is_empty ( ) {
539565 Ok ( ret)
@@ -759,10 +785,13 @@ impl<'s> DiscreteModel<'s> {
759785 is_algebraic,
760786 stop,
761787 state0_input_deps : Vec :: new ( ) ,
788+ dstate0_input_deps : Vec :: new ( ) ,
762789 rhs_state_deps : Vec :: new ( ) ,
763790 rhs_input_deps : Vec :: new ( ) ,
764791 out_input_deps : Vec :: new ( ) ,
765792 out_state_deps : Vec :: new ( ) ,
793+ mass_state_deps : Vec :: new ( ) ,
794+ mass_input_deps : Vec :: new ( ) ,
766795 }
767796 }
768797
@@ -825,6 +854,10 @@ impl<'s> DiscreteModel<'s> {
825854 std:: mem:: take ( & mut self . state0_input_deps )
826855 }
827856
857+ pub fn take_dstate0_input_deps ( & mut self ) -> Vec < ( usize , usize ) > {
858+ std:: mem:: take ( & mut self . dstate0_input_deps )
859+ }
860+
828861 pub fn take_rhs_state_deps ( & mut self ) -> Vec < ( usize , usize ) > {
829862 std:: mem:: take ( & mut self . rhs_state_deps )
830863 }
@@ -840,6 +873,14 @@ impl<'s> DiscreteModel<'s> {
840873 pub fn take_out_state_deps ( & mut self ) -> Vec < ( usize , usize ) > {
841874 std:: mem:: take ( & mut self . out_state_deps )
842875 }
876+
877+ pub fn take_mass_state_deps ( & mut self ) -> Vec < ( usize , usize ) > {
878+ std:: mem:: take ( & mut self . mass_state_deps )
879+ }
880+
881+ pub fn take_mass_input_deps ( & mut self ) -> Vec < ( usize , usize ) > {
882+ std:: mem:: take ( & mut self . mass_input_deps )
883+ }
843884}
844885
845886#[ cfg( test) ]
@@ -1337,6 +1378,84 @@ mod tests {
13371378
13381379 ) ;
13391380
1381+ #[ test]
1382+ fn tensor_state_input_dep_mass_test ( ) {
1383+ let full_text = "
1384+ in_i { (0:2): p = 1 }
1385+ u_i { p_i }
1386+ dudt_i { p_i }
1387+ M_i { dudt_i[1] + p_i[0], dudt_i[0] + p_i[1] }
1388+ F_i { u_i }
1389+ " ;
1390+
1391+ let model = parse_ds_string ( full_text) . unwrap ( ) ;
1392+ let mut discrete_model =
1393+ DiscreteModel :: build ( "tensor_state_input_dep_mass_test" , & model) . unwrap ( ) ;
1394+ assert_eq ! (
1395+ discrete_model. take_state0_input_deps( ) ,
1396+ vec![ ( 0 , 0 ) , ( 1 , 1 ) ]
1397+ ) ;
1398+ assert_eq ! (
1399+ discrete_model. take_dstate0_input_deps( ) ,
1400+ vec![ ( 0 , 0 ) , ( 1 , 1 ) ]
1401+ ) ;
1402+ assert_eq ! (
1403+ discrete_model. take_rhs_state_deps( ) ,
1404+ vec![ ( 0 , 0 ) , ( 1 , 1 ) ] ,
1405+ "failed rhs_state_deps"
1406+ ) ;
1407+ assert_eq ! (
1408+ discrete_model. take_rhs_input_deps( ) ,
1409+ vec![ ] ,
1410+ "failed rhs_input_deps"
1411+ ) ;
1412+ assert_eq ! ( discrete_model. take_out_state_deps( ) , vec![ ] ) ;
1413+ assert_eq ! ( discrete_model. take_out_input_deps( ) , vec![ ] ) ;
1414+ assert_eq ! ( discrete_model. take_mass_state_deps( ) , vec![ ( 0 , 1 ) , ( 1 , 0 ) ] ) ;
1415+ assert_eq ! ( discrete_model. take_mass_input_deps( ) , vec![ ( 0 , 0 ) , ( 1 , 1 ) ] ) ;
1416+ }
1417+
1418+ #[ test]
1419+ fn tensor_state_input_dep_logistic_test ( ) {
1420+ let full_text = "
1421+ in_i { r = 1, k = 1 }
1422+ u_i {
1423+ y = 0.1,
1424+ z = 0,
1425+ }
1426+ dudt_i {
1427+ dydt = 0,
1428+ dzdt = 0,
1429+ }
1430+ M_i {
1431+ dydt,
1432+ 0,
1433+ }
1434+ F_i {
1435+ (r * y) * (1 - (y / k)),
1436+ (2 * y) - z,
1437+ }
1438+ out_i {
1439+ 3 * y,
1440+ 4 * z,
1441+ }
1442+ " ;
1443+ let model = parse_ds_string ( full_text) . unwrap ( ) ;
1444+ let mut discrete_model =
1445+ DiscreteModel :: build ( "tensor_state_input_dep_logistic_test" , & model) . unwrap ( ) ;
1446+ assert_eq ! ( discrete_model. take_state0_input_deps( ) , vec![ ] ) ;
1447+ assert_eq ! ( discrete_model. take_dstate0_input_deps( ) , vec![ ] ) ;
1448+ assert_eq ! (
1449+ discrete_model. take_rhs_state_deps( ) ,
1450+ vec![ ( 0 , 0 ) , ( 1 , 0 ) , ( 1 , 1 ) ]
1451+ ) ;
1452+ assert_eq ! ( discrete_model. take_rhs_input_deps( ) , vec![ ( 0 , 0 ) , ( 0 , 1 ) ] ) ;
1453+ assert_eq ! ( discrete_model. take_out_state_deps( ) , vec![ ( 0 , 0 ) , ( 1 , 1 ) ] ) ;
1454+ assert_eq ! ( discrete_model. take_out_input_deps( ) , vec![ ] ) ;
1455+ assert_eq ! ( discrete_model. take_mass_state_deps( ) , vec![ ( 0 , 0 ) ] ) ;
1456+ assert_eq ! ( discrete_model. take_mass_input_deps( ) , vec![ ] ) ;
1457+ }
1458+
13401459 macro_rules! tensor_state_input_dep_test {
13411460 ( $( $name: ident: $text: literal expect $expected_state_state_deps: expr ; $expected_state_inputs_deps: expr, ) * ) => {
13421461 $(
@@ -1389,6 +1508,7 @@ mod tests {
13891508 tsi_diag_mat_mul2: "A_ij { (0..2, 0..2): p_i } F_i { A_ij * u_j }" expect vec![ ( 0 , 0 ) , ( 1 , 1 ) ] ; vec![ ( 0 , 0 ) , ( 1 , 1 ) ] ,
13901509 tsi_concat: "F_i { (0): u_i[0], (1): p_i[0] }" expect vec![ ( 0 , 0 ) ] ; vec![ ( 1 , 0 ) ] ,
13911510 tsi_concat2: "a_ij { (0,0): u_i[0], (1,1): p_i[0] } F_i { a_ij }" expect vec![ ( 0 , 0 ) ] ; vec![ ( 1 , 0 ) ] ,
1511+ tsi_expr: "a_i { y = u_i[0], z = u_i[1] } b_i { r = p_i[0], k = p_i[1] } F_i { (r * y) * (1 - y / k), (2 * y) - z }" expect vec![ ( 0 , 0 ) , ( 1 , 0 ) , ( 1 , 1 ) ] ; vec![ ( 0 , 0 ) , ( 0 , 1 ) ] ,
13921512 }
13931513
13941514 #[ test]
0 commit comments