Skip to content

Commit 32e9935

Browse files
authored
Merge pull request #36 from krazijames/load-model-from-memory
Allow loading models from memory
2 parents b6383af + da17267 commit 32e9935

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

onnxruntime/src/session.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,61 @@ impl SessionBuilder {
231231
outputs,
232232
})
233233
}
234+
235+
/// Load an ONNX graph from memory and commit the session
236+
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session>
237+
where
238+
B: AsRef<[u8]>,
239+
{
240+
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
241+
}
242+
243+
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session> {
244+
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
245+
246+
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
247+
248+
let status = unsafe {
249+
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
250+
let model_data_length = model_bytes.len() as u64;
251+
g_ort().CreateSessionFromArray.unwrap()(
252+
env_ptr,
253+
model_data,
254+
model_data_length,
255+
self.session_options_ptr,
256+
&mut session_ptr,
257+
)
258+
};
259+
status_to_result(status).map_err(OrtError::Session)?;
260+
assert_eq!(status, std::ptr::null_mut());
261+
assert_ne!(session_ptr, std::ptr::null_mut());
262+
263+
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
264+
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
265+
status_to_result(status).map_err(OrtError::Allocator)?;
266+
assert_eq!(status, std::ptr::null_mut());
267+
assert_ne!(allocator_ptr, std::ptr::null_mut());
268+
269+
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
270+
271+
// Extract input and output properties
272+
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
273+
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
274+
let inputs = (0..num_input_nodes)
275+
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
276+
.collect::<Result<Vec<Input>>>()?;
277+
let outputs = (0..num_output_nodes)
278+
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
279+
.collect::<Result<Vec<Output>>>()?;
280+
281+
Ok(Session {
282+
session_ptr,
283+
allocator_ptr,
284+
memory_info,
285+
inputs,
286+
outputs,
287+
})
288+
}
234289
}
235290

236291
/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)

0 commit comments

Comments
 (0)