@@ -231,6 +231,61 @@ impl SessionBuilder {
231231 outputs,
232232 } )
233233 }
234+
235+ /// Load an ONNX graph from memory and commit the session
236+ pub fn with_model_from_memory < B > ( self , model_bytes : B ) -> Result < Session >
237+ where
238+ B : AsRef < [ u8 ] > ,
239+ {
240+ self . with_model_from_memory_monomorphized ( model_bytes. as_ref ( ) )
241+ }
242+
243+ fn with_model_from_memory_monomorphized ( self , model_bytes : & [ u8 ] ) -> Result < Session > {
244+ let mut session_ptr: * mut sys:: OrtSession = std:: ptr:: null_mut ( ) ;
245+
246+ let env_ptr: * const sys:: OrtEnv = self . env . env_ptr ( ) ;
247+
248+ let status = unsafe {
249+ let model_data = model_bytes. as_ptr ( ) as * const std:: ffi:: c_void ;
250+ let model_data_length = model_bytes. len ( ) as u64 ;
251+ g_ort ( ) . CreateSessionFromArray . unwrap ( ) (
252+ env_ptr,
253+ model_data,
254+ model_data_length,
255+ self . session_options_ptr ,
256+ & mut session_ptr,
257+ )
258+ } ;
259+ status_to_result ( status) . map_err ( OrtError :: Session ) ?;
260+ assert_eq ! ( status, std:: ptr:: null_mut( ) ) ;
261+ assert_ne ! ( session_ptr, std:: ptr:: null_mut( ) ) ;
262+
263+ let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
264+ let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
265+ status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
266+ assert_eq ! ( status, std:: ptr:: null_mut( ) ) ;
267+ assert_ne ! ( allocator_ptr, std:: ptr:: null_mut( ) ) ;
268+
269+ let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
270+
271+ // Extract input and output properties
272+ let num_input_nodes = dangerous:: extract_inputs_count ( session_ptr) ?;
273+ let num_output_nodes = dangerous:: extract_outputs_count ( session_ptr) ?;
274+ let inputs = ( 0 ..num_input_nodes)
275+ . map ( |i| dangerous:: extract_input ( session_ptr, allocator_ptr, i) )
276+ . collect :: < Result < Vec < Input > > > ( ) ?;
277+ let outputs = ( 0 ..num_output_nodes)
278+ . map ( |i| dangerous:: extract_output ( session_ptr, allocator_ptr, i) )
279+ . collect :: < Result < Vec < Output > > > ( ) ?;
280+
281+ Ok ( Session {
282+ session_ptr,
283+ allocator_ptr,
284+ memory_info,
285+ inputs,
286+ outputs,
287+ } )
288+ }
234289}
235290
236291/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
0 commit comments