Skip to content

Commit a4493c3

Browse files
Another try
1 parent de0dae6 commit a4493c3

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

native/ortex/src/model.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,16 @@ use std::iter::zip;
1515

1616
use ort::execution_providers::ExecutionProviderDispatch;
1717
use ort::session::Session;
18-
use ort::util::Mutex;
1918
use ort::Error;
2019
use rustler::Atom;
2120
use rustler::ResourceArc;
2221
use std::error::Error as StdError;
23-
use std::sync::Arc;
22+
use std::sync::Mutex;
2423

2524
/// Holds the model state which include onnxruntime session and environment. All
2625
/// are threadsafe so this can be called concurrently from the beam.
2726
pub struct OrtexModel {
28-
pub session: Arc<Mutex<ort::session::Session>>,
27+
pub session: Mutex<ort::session::Session>,
2928
}
3029

3130
// Since we're only using the session for inference and
@@ -50,7 +49,7 @@ pub fn init(
5049
.commit_from_file(model_path)?;
5150

5251
let state = OrtexModel {
53-
session: Arc::new(Mutex::new(session)),
52+
session: session.into(),
5453
};
5554
Ok(state)
5655
}
@@ -64,8 +63,7 @@ pub fn show(
6463
Vec<(String, String, Option<Vec<i64>>)>,
6564
Vec<(String, String, Option<Vec<i64>>)>,
6665
) {
67-
let model: &OrtexModel = &*model;
68-
let session = model.session.lock();
66+
let session: &mut ort::session::Session = &mut model.session.lock().unwrap();
6967

7068
let mut inputs = Vec::new();
7169
for input in session.inputs.iter() {
@@ -92,7 +90,7 @@ pub fn run(
9290
model: ResourceArc<OrtexModel>,
9391
inputs: Vec<ResourceArc<OrtexTensor>>,
9492
) -> Result<Vec<(ResourceArc<OrtexTensor>, Vec<usize>, Atom, usize)>, Box<dyn StdError>> {
95-
let mut session = model.session.lock();
93+
let session: &mut ort::session::Session = &mut model.session.lock().unwrap();
9694

9795
let mut ortified_inputs: Vec<ort::session::SessionInputValue> = Vec::new();
9896

@@ -108,12 +106,10 @@ pub fn run(
108106
}
109107
}
110108

111-
let output_descriptors = session.outputs.clone();
112109
let outputs = session.run(&ortified_inputs[..])?;
113110
let mut collected_outputs = Vec::new();
114111

115-
for output_descriptor in output_descriptors {
116-
let output_name: &str = &output_descriptor.name;
112+
for output_name in outputs.keys() {
117113
let val = outputs.get(output_name).expect(
118114
&format!(
119115
"Expected {} to be in the outputs, but didn't find it",
@@ -124,10 +120,26 @@ pub fn run(
124120
let ortextensor: OrtexTensor = val.try_into()?;
125121
let shape = ortextensor.shape();
126122
let (dtype, bits) = ortextensor.dtype();
127-
128123
let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits);
129124
collected_outputs.push(collected_output);
130125
}
131126

127+
// for output_descriptor in &session.outputs {
128+
// let output_name: &str = &output_descriptor.name;
129+
// let val = outputs.get(output_name).expect(
130+
// &format!(
131+
// "Expected {} to be in the outputs, but didn't find it",
132+
// output_name
133+
// )[..],
134+
// );
135+
136+
// let ortextensor: OrtexTensor = val.try_into()?;
137+
// let shape = ortextensor.shape();
138+
// let (dtype, bits) = ortextensor.dtype();
139+
140+
// let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits);
141+
// collected_outputs.push(collected_output);
142+
// }
143+
132144
Ok(collected_outputs)
133145
}

0 commit comments

Comments
 (0)