@@ -15,6 +15,7 @@ use ndarray::Array1;
1515use ort:: error:: Result ;
1616use ort:: inputs;
1717use ort:: session:: { builder:: GraphOptimizationLevel , Session } ;
18+ use ort:: value:: { TensorRef , ValueType } ;
1819use plotly:: common:: { DashType , Line , Mode } ;
1920use plotly:: layout:: { Axis , GridPattern , LayoutGrid } ;
2021use plotly:: { Layout , Plot , Scatter } ;
@@ -31,10 +32,10 @@ const BASE_DATA_DIR: &str = "examples/neural-ode-weather-prediction/src/data/";
3132const BASE_OUTPUT_DIR : & str = "examples/neural-ode-weather-prediction/" ;
3233
3334struct NeuralOde {
34- rhs : Session ,
35- rhs_jac_mul : Session ,
36- rhs_jac_transpose_mul : Session ,
37- rhs_sens_transpose_mul : Session ,
35+ rhs : RefCell < Session > ,
36+ rhs_jac_mul : RefCell < Session > ,
37+ rhs_jac_transpose_mul : RefCell < Session > ,
38+ rhs_sens_transpose_mul : RefCell < Session > ,
3839 input_y : RefCell < Array1 < f32 > > ,
3940 input_v : RefCell < Array1 < f32 > > ,
4041 input_p : Array1 < f32 > ,
@@ -58,8 +59,17 @@ impl NeuralOde {
5859 let mut nparams = 0 ;
5960 for input in rhs. inputs . iter ( ) {
6061 if input. name == "p" {
61- nparams = input. input_type . tensor_dimensions ( ) . unwrap ( ) [ 0 ] as usize ;
62- break ;
62+ if let ValueType :: Tensor { shape, .. } = & input. input_type {
63+ let dim = shape
64+ . get ( 0 )
65+ . copied ( )
66+ . expect ( "p input should have at least one dimension" ) ;
67+ if dim < 0 {
68+ panic ! ( "p input has dynamic dimension; cannot infer parameter count" ) ;
69+ }
70+ nparams = dim as usize ;
71+ break ;
72+ }
6373 }
6474 }
6575 let mut rng = rand:: rng ( ) ;
@@ -69,10 +79,10 @@ impl NeuralOde {
6979
7080 Ok ( Self {
7181 y0,
72- rhs,
73- rhs_jac_mul,
74- rhs_jac_transpose_mul,
75- rhs_sens_transpose_mul,
82+ rhs : RefCell :: new ( rhs ) ,
83+ rhs_jac_mul : RefCell :: new ( rhs_jac_mul ) ,
84+ rhs_jac_transpose_mul : RefCell :: new ( rhs_jac_transpose_mul ) ,
85+ rhs_sens_transpose_mul : RefCell :: new ( rhs_sens_transpose_mul ) ,
7686 input_p : params,
7787 input_v : RefCell :: new ( y0_ndarray. clone ( ) ) ,
7888 input_y : RefCell :: new ( y0_ndarray) ,
@@ -200,21 +210,19 @@ impl NonLinearOp for Rhs<'_> {
200210 . iter_mut ( )
201211 . zip ( x. inner ( ) . iter ( ) )
202212 . for_each ( |( y, x) | * y = * x as f32 ) ;
203- let outputs = self
204- . 0
205- . rhs
206- . run (
207- inputs ! [
208- "p" => self . 0 . input_p. view( ) ,
209- "y" => y_input. view( ) ,
210- ]
211- . unwrap ( ) ,
212- )
213+ let p_slice = self . 0 . input_p . as_slice ( ) . unwrap ( ) ;
214+ let y_slice = y_input. as_slice ( ) . unwrap ( ) ;
215+ let mut rhs = self . 0 . rhs . borrow_mut ( ) ;
216+ let outputs = rhs
217+ . run ( inputs ! [
218+ "p" => TensorRef :: from_array_view( ( [ p_slice. len( ) ] , p_slice) ) . unwrap( ) ,
219+ "y" => TensorRef :: from_array_view( ( [ y_slice. len( ) ] , y_slice) ) . unwrap( ) ,
220+ ] )
213221 . unwrap ( ) ;
214- let y_data = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
222+ let ( _shape , y_data) = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
215223 y. inner_mut ( )
216224 . iter_mut ( )
217- . zip ( y_data. as_slice ( ) . unwrap ( ) )
225+ . zip ( y_data. iter ( ) )
218226 . for_each ( |( y, x) | * y = * x as f64 ) ;
219227 }
220228}
@@ -231,22 +239,21 @@ impl NonLinearOpJacobian for Rhs<'_> {
231239 . iter_mut ( )
232240 . zip ( v. inner ( ) . iter ( ) )
233241 . for_each ( |( v, x) | * v = * x as f32 ) ;
234- let outputs = self
235- . 0
236- . rhs_jac_mul
237- . run (
238- inputs ! [
239- "y" => y_input. view( ) ,
240- "v" => v_input. view( ) ,
241- "p" => self . 0 . input_p. view( ) ,
242- ]
243- . unwrap ( ) ,
244- )
242+ let p_slice = self . 0 . input_p . as_slice ( ) . unwrap ( ) ;
243+ let y_slice = y_input. as_slice ( ) . unwrap ( ) ;
244+ let v_slice = v_input. as_slice ( ) . unwrap ( ) ;
245+ let mut rhs_jac_mul = self . 0 . rhs_jac_mul . borrow_mut ( ) ;
246+ let outputs = rhs_jac_mul
247+ . run ( inputs ! [
248+ "y" => TensorRef :: from_array_view( ( [ y_slice. len( ) ] , y_slice) ) . unwrap( ) ,
249+ "v" => TensorRef :: from_array_view( ( [ v_slice. len( ) ] , v_slice) ) . unwrap( ) ,
250+ "p" => TensorRef :: from_array_view( ( [ p_slice. len( ) ] , p_slice) ) . unwrap( ) ,
251+ ] )
245252 . unwrap ( ) ;
246- let y_data = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
253+ let ( _shape , y_data) = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
247254 y. inner_mut ( )
248255 . iter_mut ( )
249- . zip ( y_data. as_slice ( ) . unwrap ( ) )
256+ . zip ( y_data. iter ( ) )
250257 . for_each ( |( y, x) | * y = * x as f64 ) ;
251258 }
252259}
@@ -263,22 +270,21 @@ impl NonLinearOpAdjoint for Rhs<'_> {
263270 . iter_mut ( )
264271 . zip ( v. inner ( ) . iter ( ) )
265272 . for_each ( |( v, x) | * v = * x as f32 ) ;
266- let outputs = self
267- . 0
268- . rhs_jac_transpose_mul
269- . run (
270- inputs ! [
271- "y" => y_input. view( ) ,
272- "v" => v_input. view( ) ,
273- "p" => self . 0 . input_p. view( ) ,
274- ]
275- . unwrap ( ) ,
276- )
273+ let p_slice = self . 0 . input_p . as_slice ( ) . unwrap ( ) ;
274+ let y_slice = y_input. as_slice ( ) . unwrap ( ) ;
275+ let v_slice = v_input. as_slice ( ) . unwrap ( ) ;
276+ let mut rhs_jac_transpose_mul = self . 0 . rhs_jac_transpose_mul . borrow_mut ( ) ;
277+ let outputs = rhs_jac_transpose_mul
278+ . run ( inputs ! [
279+ "y" => TensorRef :: from_array_view( ( [ y_slice. len( ) ] , y_slice) ) . unwrap( ) ,
280+ "v" => TensorRef :: from_array_view( ( [ v_slice. len( ) ] , v_slice) ) . unwrap( ) ,
281+ "p" => TensorRef :: from_array_view( ( [ p_slice. len( ) ] , p_slice) ) . unwrap( ) ,
282+ ] )
277283 . unwrap ( ) ;
278- let y_data = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
284+ let ( _shape , y_data) = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
279285 y. inner_mut ( )
280286 . iter_mut ( )
281- . zip ( y_data. as_slice ( ) . unwrap ( ) )
287+ . zip ( y_data. iter ( ) )
282288 . for_each ( |( y, x) | * y = * x as f64 ) ;
283289 }
284290}
@@ -295,22 +301,21 @@ impl NonLinearOpSensAdjoint for Rhs<'_> {
295301 . iter_mut ( )
296302 . zip ( v. inner ( ) . iter ( ) )
297303 . for_each ( |( v, x) | * v = * x as f32 ) ;
298- let outputs = self
299- . 0
300- . rhs_sens_transpose_mul
301- . run (
302- inputs ! [
303- "y" => y_input. view( ) ,
304- "v" => v_input. view( ) ,
305- "p" => self . 0 . input_p. view( ) ,
306- ]
307- . unwrap ( ) ,
308- )
304+ let p_slice = self . 0 . input_p . as_slice ( ) . unwrap ( ) ;
305+ let y_slice = y_input. as_slice ( ) . unwrap ( ) ;
306+ let v_slice = v_input. as_slice ( ) . unwrap ( ) ;
307+ let mut rhs_sens_transpose_mul = self . 0 . rhs_sens_transpose_mul . borrow_mut ( ) ;
308+ let outputs = rhs_sens_transpose_mul
309+ . run ( inputs ! [
310+ "y" => TensorRef :: from_array_view( ( [ y_slice. len( ) ] , y_slice) ) . unwrap( ) ,
311+ "v" => TensorRef :: from_array_view( ( [ v_slice. len( ) ] , v_slice) ) . unwrap( ) ,
312+ "p" => TensorRef :: from_array_view( ( [ p_slice. len( ) ] , p_slice) ) . unwrap( ) ,
313+ ] )
309314 . unwrap ( ) ;
310- let y_data = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
315+ let ( _shape , y_data) = outputs[ "Identity_1:0" ] . try_extract_tensor :: < f32 > ( ) . unwrap ( ) ;
311316 y. inner_mut ( )
312317 . iter_mut ( )
313- . zip ( y_data. as_slice ( ) . unwrap ( ) )
318+ . zip ( y_data. iter ( ) )
314319 . for_each ( |( y, x) | * y = * x as f64 ) ;
315320 }
316321}
0 commit comments