Skip to content

Commit 89bde95

Browse files
Merge branch 'main' into fix-solve-root
2 parents 9aa44d2 + 88c5654 commit 89bde95

File tree

3 files changed

+69
-64
lines changed

3 files changed

+69
-64
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ plotly = { version = "0.13.5" }
2525
argmin = { version = "0.11.0" }
2626
argmin-math = { version = "0.5.1" }
2727
argmin-observer-slog = { version = "0.2.0" }
28-
ort = "=2.0.0-rc.9"
29-
ort-sys = { version = "=2.0.0-rc.9", default-features = false }
28+
ort = "=2.0.0-rc.10"
29+
ort-sys = { version = "=2.0.0-rc.10", default-features = false }
3030

3131
[profile.profiling]
3232
inherits = "release"

examples/neural-ode-weather-prediction/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ ort = { workspace = true, optional = true }
2020
ort-sys = { workspace = true, optional = true }
2121
ndarray = "0.16.1"
2222
csv = "1.3.1"
23-
rand = "0.9.0"
23+
rand = "0.9.0"

examples/neural-ode-weather-prediction/src/main.rs

Lines changed: 66 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use ndarray::Array1;
1515
use ort::error::Result;
1616
use ort::inputs;
1717
use ort::session::{builder::GraphOptimizationLevel, Session};
18+
use ort::value::{TensorRef, ValueType};
1819
use plotly::common::{DashType, Line, Mode};
1920
use plotly::layout::{Axis, GridPattern, LayoutGrid};
2021
use plotly::{Layout, Plot, Scatter};
@@ -31,10 +32,10 @@ const BASE_DATA_DIR: &str = "examples/neural-ode-weather-prediction/src/data/";
3132
const BASE_OUTPUT_DIR: &str = "examples/neural-ode-weather-prediction/";
3233

3334
struct NeuralOde {
34-
rhs: Session,
35-
rhs_jac_mul: Session,
36-
rhs_jac_transpose_mul: Session,
37-
rhs_sens_transpose_mul: Session,
35+
rhs: RefCell<Session>,
36+
rhs_jac_mul: RefCell<Session>,
37+
rhs_jac_transpose_mul: RefCell<Session>,
38+
rhs_sens_transpose_mul: RefCell<Session>,
3839
input_y: RefCell<Array1<f32>>,
3940
input_v: RefCell<Array1<f32>>,
4041
input_p: Array1<f32>,
@@ -58,8 +59,17 @@ impl NeuralOde {
5859
let mut nparams = 0;
5960
for input in rhs.inputs.iter() {
6061
if input.name == "p" {
61-
nparams = input.input_type.tensor_dimensions().unwrap()[0] as usize;
62-
break;
62+
if let ValueType::Tensor { shape, .. } = &input.input_type {
63+
let dim = shape
64+
.get(0)
65+
.copied()
66+
.expect("p input should have at least one dimension");
67+
if dim < 0 {
68+
panic!("p input has dynamic dimension; cannot infer parameter count");
69+
}
70+
nparams = dim as usize;
71+
break;
72+
}
6373
}
6474
}
6575
let mut rng = rand::rng();
@@ -69,10 +79,10 @@ impl NeuralOde {
6979

7080
Ok(Self {
7181
y0,
72-
rhs,
73-
rhs_jac_mul,
74-
rhs_jac_transpose_mul,
75-
rhs_sens_transpose_mul,
82+
rhs: RefCell::new(rhs),
83+
rhs_jac_mul: RefCell::new(rhs_jac_mul),
84+
rhs_jac_transpose_mul: RefCell::new(rhs_jac_transpose_mul),
85+
rhs_sens_transpose_mul: RefCell::new(rhs_sens_transpose_mul),
7686
input_p: params,
7787
input_v: RefCell::new(y0_ndarray.clone()),
7888
input_y: RefCell::new(y0_ndarray),
@@ -200,21 +210,19 @@ impl NonLinearOp for Rhs<'_> {
200210
.iter_mut()
201211
.zip(x.inner().iter())
202212
.for_each(|(y, x)| *y = *x as f32);
203-
let outputs = self
204-
.0
205-
.rhs
206-
.run(
207-
inputs![
208-
"p" => self.0.input_p.view(),
209-
"y" => y_input.view(),
210-
]
211-
.unwrap(),
212-
)
213+
let p_slice = self.0.input_p.as_slice().unwrap();
214+
let y_slice = y_input.as_slice().unwrap();
215+
let mut rhs = self.0.rhs.borrow_mut();
216+
let outputs = rhs
217+
.run(inputs![
218+
"p" => TensorRef::from_array_view(([p_slice.len()], p_slice)).unwrap(),
219+
"y" => TensorRef::from_array_view(([y_slice.len()], y_slice)).unwrap(),
220+
])
213221
.unwrap();
214-
let y_data = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
222+
let (_shape, y_data) = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
215223
y.inner_mut()
216224
.iter_mut()
217-
.zip(y_data.as_slice().unwrap())
225+
.zip(y_data.iter())
218226
.for_each(|(y, x)| *y = *x as f64);
219227
}
220228
}
@@ -231,22 +239,21 @@ impl NonLinearOpJacobian for Rhs<'_> {
231239
.iter_mut()
232240
.zip(v.inner().iter())
233241
.for_each(|(v, x)| *v = *x as f32);
234-
let outputs = self
235-
.0
236-
.rhs_jac_mul
237-
.run(
238-
inputs![
239-
"y" => y_input.view(),
240-
"v" => v_input.view(),
241-
"p" => self.0.input_p.view(),
242-
]
243-
.unwrap(),
244-
)
242+
let p_slice = self.0.input_p.as_slice().unwrap();
243+
let y_slice = y_input.as_slice().unwrap();
244+
let v_slice = v_input.as_slice().unwrap();
245+
let mut rhs_jac_mul = self.0.rhs_jac_mul.borrow_mut();
246+
let outputs = rhs_jac_mul
247+
.run(inputs![
248+
"y" => TensorRef::from_array_view(([y_slice.len()], y_slice)).unwrap(),
249+
"v" => TensorRef::from_array_view(([v_slice.len()], v_slice)).unwrap(),
250+
"p" => TensorRef::from_array_view(([p_slice.len()], p_slice)).unwrap(),
251+
])
245252
.unwrap();
246-
let y_data = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
253+
let (_shape, y_data) = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
247254
y.inner_mut()
248255
.iter_mut()
249-
.zip(y_data.as_slice().unwrap())
256+
.zip(y_data.iter())
250257
.for_each(|(y, x)| *y = *x as f64);
251258
}
252259
}
@@ -263,22 +270,21 @@ impl NonLinearOpAdjoint for Rhs<'_> {
263270
.iter_mut()
264271
.zip(v.inner().iter())
265272
.for_each(|(v, x)| *v = *x as f32);
266-
let outputs = self
267-
.0
268-
.rhs_jac_transpose_mul
269-
.run(
270-
inputs![
271-
"y" => y_input.view(),
272-
"v" => v_input.view(),
273-
"p" => self.0.input_p.view(),
274-
]
275-
.unwrap(),
276-
)
273+
let p_slice = self.0.input_p.as_slice().unwrap();
274+
let y_slice = y_input.as_slice().unwrap();
275+
let v_slice = v_input.as_slice().unwrap();
276+
let mut rhs_jac_transpose_mul = self.0.rhs_jac_transpose_mul.borrow_mut();
277+
let outputs = rhs_jac_transpose_mul
278+
.run(inputs![
279+
"y" => TensorRef::from_array_view(([y_slice.len()], y_slice)).unwrap(),
280+
"v" => TensorRef::from_array_view(([v_slice.len()], v_slice)).unwrap(),
281+
"p" => TensorRef::from_array_view(([p_slice.len()], p_slice)).unwrap(),
282+
])
277283
.unwrap();
278-
let y_data = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
284+
let (_shape, y_data) = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
279285
y.inner_mut()
280286
.iter_mut()
281-
.zip(y_data.as_slice().unwrap())
287+
.zip(y_data.iter())
282288
.for_each(|(y, x)| *y = *x as f64);
283289
}
284290
}
@@ -295,22 +301,21 @@ impl NonLinearOpSensAdjoint for Rhs<'_> {
295301
.iter_mut()
296302
.zip(v.inner().iter())
297303
.for_each(|(v, x)| *v = *x as f32);
298-
let outputs = self
299-
.0
300-
.rhs_sens_transpose_mul
301-
.run(
302-
inputs![
303-
"y" => y_input.view(),
304-
"v" => v_input.view(),
305-
"p" => self.0.input_p.view(),
306-
]
307-
.unwrap(),
308-
)
304+
let p_slice = self.0.input_p.as_slice().unwrap();
305+
let y_slice = y_input.as_slice().unwrap();
306+
let v_slice = v_input.as_slice().unwrap();
307+
let mut rhs_sens_transpose_mul = self.0.rhs_sens_transpose_mul.borrow_mut();
308+
let outputs = rhs_sens_transpose_mul
309+
.run(inputs![
310+
"y" => TensorRef::from_array_view(([y_slice.len()], y_slice)).unwrap(),
311+
"v" => TensorRef::from_array_view(([v_slice.len()], v_slice)).unwrap(),
312+
"p" => TensorRef::from_array_view(([p_slice.len()], p_slice)).unwrap(),
313+
])
309314
.unwrap();
310-
let y_data = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
315+
let (_shape, y_data) = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
311316
y.inner_mut()
312317
.iter_mut()
313-
.zip(y_data.as_slice().unwrap())
318+
.zip(y_data.iter())
314319
.for_each(|(y, x)| *y = *x as f64);
315320
}
316321
}

0 commit comments

Comments
 (0)