@@ -15,17 +15,16 @@ use std::iter::zip;
1515
1616use ort:: execution_providers:: ExecutionProviderDispatch ;
1717use ort:: session:: Session ;
18- use ort:: util:: Mutex ;
1918use ort:: Error ;
2019use rustler:: Atom ;
2120use rustler:: ResourceArc ;
2221use std:: error:: Error as StdError ;
23- use std:: sync:: Arc ;
22+ use std:: sync:: Mutex ;
2423
2524/// Holds the model state which include onnxruntime session and environment. All
2625/// are threadsafe so this can be called concurrently from the beam.
2726pub struct OrtexModel {
28- pub session : Arc < Mutex < ort:: session:: Session > > ,
27+ pub session : Mutex < ort:: session:: Session > ,
2928}
3029
3130// Since we're only using the session for inference and
@@ -50,7 +49,7 @@ pub fn init(
5049 . commit_from_file ( model_path) ?;
5150
5251 let state = OrtexModel {
53- session : Arc :: new ( Mutex :: new ( session) ) ,
52+ session : session. into ( ) ,
5453 } ;
5554 Ok ( state)
5655}
@@ -64,8 +63,7 @@ pub fn show(
6463 Vec < ( String , String , Option < Vec < i64 > > ) > ,
6564 Vec < ( String , String , Option < Vec < i64 > > ) > ,
6665) {
67- let model: & OrtexModel = & * model;
68- let session = model. session . lock ( ) ;
66+ let session: & mut ort:: session:: Session = & mut model. session . lock ( ) . unwrap ( ) ;
6967
7068 let mut inputs = Vec :: new ( ) ;
7169 for input in session. inputs . iter ( ) {
@@ -92,7 +90,7 @@ pub fn run(
9290 model : ResourceArc < OrtexModel > ,
9391 inputs : Vec < ResourceArc < OrtexTensor > > ,
9492) -> Result < Vec < ( ResourceArc < OrtexTensor > , Vec < usize > , Atom , usize ) > , Box < dyn StdError > > {
95- let mut session = model. session . lock ( ) ;
93+ let session : & mut ort :: session:: Session = & mut model. session . lock ( ) . unwrap ( ) ;
9694
9795 let mut ortified_inputs: Vec < ort:: session:: SessionInputValue > = Vec :: new ( ) ;
9896
@@ -108,12 +106,10 @@ pub fn run(
108106 }
109107 }
110108
111- let output_descriptors = session. outputs . clone ( ) ;
112109 let outputs = session. run ( & ortified_inputs[ ..] ) ?;
113110 let mut collected_outputs = Vec :: new ( ) ;
114111
115- for output_descriptor in output_descriptors {
116- let output_name: & str = & output_descriptor. name ;
112+ for output_name in outputs. keys ( ) {
117113 let val = outputs. get ( output_name) . expect (
118114 & format ! (
119115 "Expected {} to be in the outputs, but didn't find it" ,
@@ -124,10 +120,26 @@ pub fn run(
124120 let ortextensor: OrtexTensor = val. try_into ( ) ?;
125121 let shape = ortextensor. shape ( ) ;
126122 let ( dtype, bits) = ortextensor. dtype ( ) ;
127-
128123 let collected_output = ( ResourceArc :: new ( ortextensor) , shape, dtype, bits) ;
129124 collected_outputs. push ( collected_output) ;
130125 }
131126
127+ // for output_descriptor in &session.outputs {
128+ // let output_name: &str = &output_descriptor.name;
129+ // let val = outputs.get(output_name).expect(
130+ // &format!(
131+ // "Expected {} to be in the outputs, but didn't find it",
132+ // output_name
133+ // )[..],
134+ // );
135+
136+ // let ortextensor: OrtexTensor = val.try_into()?;
137+ // let shape = ortextensor.shape();
138+ // let (dtype, bits) = ortextensor.dtype();
139+
140+ // let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits);
141+ // collected_outputs.push(collected_output);
142+ // }
143+
132144 Ok ( collected_outputs)
133145}
0 commit comments