Skip to content

Commit 36eaf3d

Browse files
feat: add mass state and input dependencies, filter sparse zeros (#95)
* feat: add mass deps * feat: track mass dependencies * feat: filter out zeros in sparse tensors
1 parent ce2c02e commit 36eaf3d

File tree

5 files changed

+156
-13
lines changed

5 files changed

+156
-13
lines changed

diffsl/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "diffsl"
3-
version = "0.8.0"
3+
version = "0.8.1"
44
edition.workspace = true
55
license.workspace = true
66
authors.workspace = true

diffsl/src/discretise/discrete_model.rs

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ pub struct DiscreteModel<'s> {
4444
is_algebraic: Vec<bool>,
4545
stop: Option<Tensor<'s>>,
4646
state0_input_deps: Vec<(usize, usize)>,
47+
dstate0_input_deps: Vec<(usize, usize)>,
4748
rhs_state_deps: Vec<(usize, usize)>,
4849
rhs_input_deps: Vec<(usize, usize)>,
4950
out_input_deps: Vec<(usize, usize)>,
5051
out_state_deps: Vec<(usize, usize)>,
52+
mass_state_deps: Vec<(usize, usize)>,
53+
mass_input_deps: Vec<(usize, usize)>,
5154
}
5255

5356
impl fmt::Display for DiscreteModel<'_> {
@@ -105,10 +108,13 @@ impl<'s> DiscreteModel<'s> {
105108
is_algebraic: Vec::new(),
106109
stop: None,
107110
state0_input_deps: Vec::new(),
111+
dstate0_input_deps: Vec::new(),
108112
rhs_input_deps: Vec::new(),
109113
rhs_state_deps: Vec::new(),
110114
out_input_deps: Vec::new(),
111115
out_state_deps: Vec::new(),
116+
mass_state_deps: Vec::new(),
117+
mass_input_deps: Vec::new(),
112118
}
113119
}
114120

@@ -245,6 +251,15 @@ impl<'s> DiscreteModel<'s> {
245251
match Layout::concatenate(&elmts, rank) {
246252
Ok(layout) => {
247253
let layout = env.new_layout_ptr(layout);
254+
// if sparse, filter out zero elements
255+
let elmts = if layout.is_sparse() {
256+
elmts
257+
.into_iter()
258+
.filter(|e| !matches!(e.expr().kind, AstKind::Number(0.0)))
259+
.collect::<Vec<_>>()
260+
} else {
261+
elmts
262+
};
248263
let tensor = Tensor::new(array.name(), elmts, layout, array.indices().to_vec());
249264

250265
//check that the number of indices matches the rank
@@ -335,7 +350,7 @@ impl<'s> DiscreteModel<'s> {
335350
}
336351
"dudt" => {
337352
if let Some(built) =
338-
Self::build_array(tensor, &mut env, true, TensorType::Other)
353+
Self::build_array(tensor, &mut env, true, TensorType::StateDot)
339354
{
340355
ret.state_dot = Some(built);
341356
}
@@ -522,6 +537,7 @@ impl<'s> DiscreteModel<'s> {
522537

523538
// store the dependencies in the discrete model
524539
ret.state0_input_deps = map_dep(&env.state0_input_deps);
540+
ret.dstate0_input_deps = map_dep(&env.dstate0_input_deps);
525541
ret.rhs_input_deps = map_dep(ret.rhs.layout().input_dependencies());
526542
ret.rhs_state_deps = map_dep(ret.rhs.layout().state_dependencies());
527543
ret.out_input_deps = ret
@@ -534,6 +550,16 @@ impl<'s> DiscreteModel<'s> {
534550
.as_ref()
535551
.map(|o| map_dep(o.layout().state_dependencies()))
536552
.unwrap_or_default();
553+
ret.mass_state_deps = if let Some(lhs) = &ret.lhs {
554+
map_dep(lhs.layout().state_dependencies())
555+
} else {
556+
Vec::new()
557+
};
558+
ret.mass_input_deps = if let Some(lhs) = &ret.lhs {
559+
map_dep(lhs.layout().input_dependencies())
560+
} else {
561+
Vec::new()
562+
};
537563

538564
if env.errs().is_empty() {
539565
Ok(ret)
@@ -759,10 +785,13 @@ impl<'s> DiscreteModel<'s> {
759785
is_algebraic,
760786
stop,
761787
state0_input_deps: Vec::new(),
788+
dstate0_input_deps: Vec::new(),
762789
rhs_state_deps: Vec::new(),
763790
rhs_input_deps: Vec::new(),
764791
out_input_deps: Vec::new(),
765792
out_state_deps: Vec::new(),
793+
mass_state_deps: Vec::new(),
794+
mass_input_deps: Vec::new(),
766795
}
767796
}
768797

@@ -825,6 +854,10 @@ impl<'s> DiscreteModel<'s> {
825854
std::mem::take(&mut self.state0_input_deps)
826855
}
827856

857+
pub fn take_dstate0_input_deps(&mut self) -> Vec<(usize, usize)> {
858+
std::mem::take(&mut self.dstate0_input_deps)
859+
}
860+
828861
pub fn take_rhs_state_deps(&mut self) -> Vec<(usize, usize)> {
829862
std::mem::take(&mut self.rhs_state_deps)
830863
}
@@ -840,6 +873,14 @@ impl<'s> DiscreteModel<'s> {
840873
pub fn take_out_state_deps(&mut self) -> Vec<(usize, usize)> {
841874
std::mem::take(&mut self.out_state_deps)
842875
}
876+
877+
pub fn take_mass_state_deps(&mut self) -> Vec<(usize, usize)> {
878+
std::mem::take(&mut self.mass_state_deps)
879+
}
880+
881+
pub fn take_mass_input_deps(&mut self) -> Vec<(usize, usize)> {
882+
std::mem::take(&mut self.mass_input_deps)
883+
}
843884
}
844885

845886
#[cfg(test)]
@@ -1337,6 +1378,84 @@ mod tests {
13371378

13381379
);
13391380

1381+
#[test]
1382+
fn tensor_state_input_dep_mass_test() {
1383+
let full_text = "
1384+
in_i { (0:2): p = 1 }
1385+
u_i { p_i }
1386+
dudt_i { p_i }
1387+
M_i { dudt_i[1] + p_i[0], dudt_i[0] + p_i[1] }
1388+
F_i { u_i }
1389+
";
1390+
1391+
let model = parse_ds_string(full_text).unwrap();
1392+
let mut discrete_model =
1393+
DiscreteModel::build("tensor_state_input_dep_mass_test", &model).unwrap();
1394+
assert_eq!(
1395+
discrete_model.take_state0_input_deps(),
1396+
vec![(0, 0), (1, 1)]
1397+
);
1398+
assert_eq!(
1399+
discrete_model.take_dstate0_input_deps(),
1400+
vec![(0, 0), (1, 1)]
1401+
);
1402+
assert_eq!(
1403+
discrete_model.take_rhs_state_deps(),
1404+
vec![(0, 0), (1, 1)],
1405+
"failed rhs_state_deps"
1406+
);
1407+
assert_eq!(
1408+
discrete_model.take_rhs_input_deps(),
1409+
vec![],
1410+
"failed rhs_input_deps"
1411+
);
1412+
assert_eq!(discrete_model.take_out_state_deps(), vec![]);
1413+
assert_eq!(discrete_model.take_out_input_deps(), vec![]);
1414+
assert_eq!(discrete_model.take_mass_state_deps(), vec![(0, 1), (1, 0)]);
1415+
assert_eq!(discrete_model.take_mass_input_deps(), vec![(0, 0), (1, 1)]);
1416+
}
1417+
1418+
#[test]
1419+
fn tensor_state_input_dep_logistic_test() {
1420+
let full_text = "
1421+
in_i { r = 1, k = 1 }
1422+
u_i {
1423+
y = 0.1,
1424+
z = 0,
1425+
}
1426+
dudt_i {
1427+
dydt = 0,
1428+
dzdt = 0,
1429+
}
1430+
M_i {
1431+
dydt,
1432+
0,
1433+
}
1434+
F_i {
1435+
(r * y) * (1 - (y / k)),
1436+
(2 * y) - z,
1437+
}
1438+
out_i {
1439+
3 * y,
1440+
4 * z,
1441+
}
1442+
";
1443+
let model = parse_ds_string(full_text).unwrap();
1444+
let mut discrete_model =
1445+
DiscreteModel::build("tensor_state_input_dep_logistic_test", &model).unwrap();
1446+
assert_eq!(discrete_model.take_state0_input_deps(), vec![]);
1447+
assert_eq!(discrete_model.take_dstate0_input_deps(), vec![]);
1448+
assert_eq!(
1449+
discrete_model.take_rhs_state_deps(),
1450+
vec![(0, 0), (1, 0), (1, 1)]
1451+
);
1452+
assert_eq!(discrete_model.take_rhs_input_deps(), vec![(0, 0), (0, 1)]);
1453+
assert_eq!(discrete_model.take_out_state_deps(), vec![(0, 0), (1, 1)]);
1454+
assert_eq!(discrete_model.take_out_input_deps(), vec![]);
1455+
assert_eq!(discrete_model.take_mass_state_deps(), vec![(0, 0)]);
1456+
assert_eq!(discrete_model.take_mass_input_deps(), vec![]);
1457+
}
1458+
13401459
macro_rules! tensor_state_input_dep_test {
13411460
($($name:ident: $text:literal expect $expected_state_state_deps:expr ; $expected_state_inputs_deps:expr,)*) => {
13421461
$(
@@ -1389,6 +1508,7 @@ mod tests {
13891508
tsi_diag_mat_mul2: "A_ij { (0..2, 0..2): p_i } F_i { A_ij * u_j }" expect vec![(0,0), (1,1)] ; vec![(0,0), (1,1)],
13901509
tsi_concat: "F_i { (0): u_i[0], (1): p_i[0] }" expect vec![(0,0)] ; vec![(1,0)],
13911510
tsi_concat2: "a_ij { (0,0): u_i[0], (1,1): p_i[0] } F_i { a_ij }" expect vec![(0,0)] ; vec![(1,0)],
1511+
tsi_expr: "a_i { y = u_i[0], z = u_i[1] } b_i { r = p_i[0], k = p_i[1] } F_i { (r * y) * (1 - y / k), (2 * y) - z }" expect vec![(0,0), (1,0), (1,1)] ; vec![(0,0), (0,1)],
13921512
}
13931513

13941514
#[test]

diffsl/src/discretise/env.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pub struct Env {
5353
errs: ValidationErrors,
5454
vars: HashMap<String, EnvVar>,
5555
pub(crate) state0_input_deps: Vec<NonZero>,
56+
pub(crate) dstate0_input_deps: Vec<NonZero>,
5657
}
5758

5859
impl Env {
@@ -74,6 +75,7 @@ impl Env {
7475
vars,
7576
current_span: None,
7677
state0_input_deps: vec![],
78+
dstate0_input_deps: vec![],
7779
}
7880
}
7981

diffsl/src/discretise/layout.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ use std::{
55
convert::AsRef,
66
fmt,
77
hash::{Hash, Hasher},
8+
iter::zip,
89
mem,
910
ops::Deref,
1011
sync::Arc,
1112
};
1213

13-
use crate::discretise::Env;
14+
use crate::{ast::AstKind, discretise::Env};
1415

1516
use super::{broadcast_shapes, shape::Shape, Index, TensorBlock};
1617

@@ -24,6 +25,7 @@ pub enum LayoutKind {
2425
#[derive(Debug, Clone, Copy)]
2526
pub enum TensorType {
2627
State,
28+
StateDot,
2729
Input,
2830
Other,
2931
}
@@ -163,12 +165,12 @@ impl Layout {
163165
/// Add state or input dependencies to this layout based on the tensor type.
164166
pub fn add_tensor_dependencies(&mut self, tensor_type: TensorType, start: i64, env: &mut Env) {
165167
let indices = match tensor_type {
166-
TensorType::State | TensorType::Input => {
168+
TensorType::State | TensorType::StateDot | TensorType::Input => {
167169
let mut deps = Vec::new();
168-
let n_states = *self.shape().get(0).unwrap_or(&1);
169-
for i in 0..n_states {
170-
let index = Index::from(vec![(i as i64) + start]);
171-
deps.push((index, i));
170+
let n_states = *self.shape().get(0).unwrap_or(&1) as i64;
171+
for i in 0_i64..n_states {
172+
let index = Index::from(vec![i]);
173+
deps.push((index, (i + start) as usize));
172174
}
173175
deps
174176
}
@@ -184,6 +186,15 @@ impl Layout {
184186
// store the state0 input dependencies in the env since we don't want to propagate them further
185187
env.state0_input_deps = mem::take(&mut self.input_deps);
186188
}
189+
TensorType::StateDot => {
190+
assert!(
191+
self.state_deps.is_empty(),
192+
"state dot tensor layout should not already have state dependencies",
193+
);
194+
self.state_deps = indices;
195+
// store the dstate0 input dependencies in the env since we don't want to propagate them further
196+
env.dstate0_input_deps = mem::take(&mut self.input_deps);
197+
}
187198
TensorType::Input => {
188199
assert! {
189200
self.input_deps.is_empty(),
@@ -657,6 +668,10 @@ impl Layout {
657668
.iter()
658669
.map(|x| x.layout().as_ref())
659670
.collect::<Vec<_>>();
671+
let is_zero = elmts
672+
.iter()
673+
.map(|x| matches!(x.expr().kind, AstKind::Number(0.0)))
674+
.collect::<Vec<_>>();
660675
let starts = elmts.iter().map(|x| x.start()).collect::<Vec<_>>();
661676

662677
// if there are no layouts then return an empty layout
@@ -800,7 +815,13 @@ impl Layout {
800815
kind: LayoutKind::Sparse,
801816
n_dense_axes,
802817
};
803-
for (layout, start) in std::iter::zip(layouts.iter(), starts.iter()) {
818+
for ((layout, start), is_zero) in
819+
zip(zip(layouts.iter(), starts.iter()), is_zero.iter())
820+
{
821+
if *is_zero {
822+
continue;
823+
}
824+
804825
// convert to sparse
805826
new_layout.indices.extend(layout.indices().map(|mut x| {
806827
for (i, xi) in x.iter_mut().enumerate() {

diffsl/src/execution/compiler.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,9 +1532,9 @@ mod tests {
15321532
sparse_mat_mul12: "A_ij { (0,0):1, (0,1):1, (1,1):1, (2,2):1, (3,3):1, (4,4):1, (5,5):1 } b2_i { (1:6): 1 } r_i { A_ij * b2_j }" expect "r" vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
15331533
sparse_mat_mul8: "A_ij { (0,1): -0.5, (0,0): 1.5, (6,5): 1.5, (6,4): -0.5 } b_i { (0:6): 1 } r_i { A_ij * b_j } A1_ij { (0,1): 1, (0,2): 1, (1,1): 1, (2,2): 1, (3,3): 1, (4,4): 1, (5,5): 1, (6,6): 1 } b2_i { (0:7): 1 } b3_i { (0:7):2 } r2_i { A1_ij * (r_j + b2_j) * b3_j }" expect "r2" vec![4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0],
15341534
sparse_dense_concat: "a_i { (0): 1, (2): 3 } b_i { 4, 5 } r_i { a_i, b_i }" expect "r" vec![1., 3., 4., 5.],
1535-
sparse_mat_mul9: "A_ij { (0,1): -0.5, (0,0): 1.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,5): 1.5, (6,4): -0.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
1535+
sparse_mat_mul9: "A_ij { (0,1): -0.5, (0,0): 1.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,5): 1.5, (6,4): -0.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
15361536
sparse_mat_mul10: "A_ij { (0,0): 1.5, (0,1): -0.5, (6,4): -0.5, (6,5): 1.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
1537-
sparse_mat_mul11: "A_ij { (0,0): 1.5, (0,1): -0.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,4): -0.5, (6,5): 1.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
1537+
sparse_mat_mul11: "A_ij { (0,0): 1.5, (0,1): -0.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,4): -0.5, (6,5): 1.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
15381538
sparse_nonsquare_mat_vec_mul: "A_ij { (0, 0): 1, (0, 1): 4, (1, 2): 2 } b_j { (0:3): 5 } r_i { A_ij * b_j }" expect "r" vec![25.0, 10.0],
15391539
sparse_nonsquare_mat_vec_mul2: "A_ij { (0, 0): 1, (0, 1): 4, (1, 2): 2 } b_j { (2): 5 } r_i { A_ij * b_j } B_ij { (0..2,0..2): 1 } s_i { B_ij * max(r_j, 1) }" expect "s" vec![1.0, 10.0],
15401540
max_sparse_vec: "A_ij { (0..2,0..2): 1 } b_j { (1): 5 } r_i { A_ij * max(b_j, 1) }" expect "r" vec![1.0, 5.0],
@@ -1543,9 +1543,9 @@ mod tests {
15431543
contract_to_mat_vec: "A_ij { (0, 0): 1, (1, 0): 3, (1, 1): 4 } B_ij { (1, 1): 2 } b_i { B_ij } r_i { A_ij * b_j }" expect "r" vec![8.],
15441544
sparse_mat_vec_mul7: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 1 } b_i { (2): 5 } r_i { A_ij * (1 + b_j) }" expect "r" vec![4., 12., 6.],
15451545
sparse_mat_vec_mul6: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 1 } b_i { (2): 5 } r_i { A_ij * (1 * b_j) }" expect "r" vec![10., 5.],
1546-
sparse_mat_vec_mul3: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12., 0.],
1546+
sparse_mat_vec_mul3: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12.],
15471547
sparse_mat_vec_mul5: "A_ij { (1, 1): 2 } b_j { (1): 3 } r_i { A_ij * (1 * b_j) }" expect "r" vec![6.0],
1548-
sparse_mat_vec_mul4: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (0): 2, (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12., 0.],
1548+
sparse_mat_vec_mul4: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (0): 2, (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12.],
15491549
sparse_mat_vec_mul2: "A_ij { (0, 1): 4, (1, 0): 2 } b_i { (1): 5 } c_j { (0:1): 1, (1:2): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![24.0, 2.0],
15501550
sparse_mat_vec_mul: "A_ij { (1, 1): 2 } b_j { (1): 3 } r_i { A_ij * b_j }" expect "r" vec![6.0],
15511551
sparse_broadcast_to_sparse: "A_i { (1): 2 } B_ij { (0:2, 0:2): A_i }" expect "B" vec![2.0],

0 commit comments

Comments
 (0)