@@ -62,15 +62,15 @@ use crate::{download::AvailableOnnxModel, error::OrtDownloadError};
6262/// # }
6363/// ```
6464#[ derive( Debug ) ]
65- pub struct SessionBuilder {
66- env : Environment ,
65+ pub struct SessionBuilder < ' a > {
66+ env : & ' a Environment ,
6767 session_options_ptr : * mut sys:: OrtSessionOptions ,
6868
6969 allocator : AllocatorType ,
7070 memory_type : MemType ,
7171}
7272
73- impl Drop for SessionBuilder {
73+ impl < ' a > Drop for SessionBuilder < ' a > {
7474 #[ tracing:: instrument]
7575 fn drop ( & mut self ) {
7676 debug ! ( "Dropping the session options." ) ;
@@ -79,8 +79,8 @@ impl Drop for SessionBuilder {
7979 }
8080}
8181
82- impl SessionBuilder {
83- pub ( crate ) fn new ( env : Environment ) -> Result < SessionBuilder > {
82+ impl < ' a > SessionBuilder < ' a > {
83+ pub ( crate ) fn new ( env : & ' a Environment ) -> Result < SessionBuilder < ' a > > {
8484 let mut session_options_ptr: * mut sys:: OrtSessionOptions = std:: ptr:: null_mut ( ) ;
8585 let status = unsafe { g_ort ( ) . CreateSessionOptions . unwrap ( ) ( & mut session_options_ptr) } ;
8686
@@ -97,7 +97,7 @@ impl SessionBuilder {
9797 }
9898
9999 /// Configure the session to use a number of threads
100- pub fn with_number_threads ( self , num_threads : i16 ) -> Result < SessionBuilder > {
100+ pub fn with_number_threads ( self , num_threads : i16 ) -> Result < SessionBuilder < ' a > > {
101101 // FIXME: Pre-built binaries use OpenMP, set env variable instead
102102
103103 // We use a u16 in the builder to cover the 16-bits positive values of a i32.
@@ -113,7 +113,7 @@ impl SessionBuilder {
113113 pub fn with_optimization_level (
114114 self ,
115115 opt_level : GraphOptimizationLevel ,
116- ) -> Result < SessionBuilder > {
116+ ) -> Result < SessionBuilder < ' a > > {
117117 // Sets graph optimization level
118118 unsafe {
119119 g_ort ( ) . SetSessionGraphOptimizationLevel . unwrap ( ) (
@@ -127,47 +127,44 @@ impl SessionBuilder {
127127 /// Set the session's allocator
128128 ///
129129 /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena)
130- pub fn with_allocator ( mut self , allocator : AllocatorType ) -> Result < SessionBuilder > {
130+ pub fn with_allocator ( mut self , allocator : AllocatorType ) -> Result < SessionBuilder < ' a > > {
131131 self . allocator = allocator;
132132 Ok ( self )
133133 }
134134
135135 /// Set the session's memory type
136136 ///
137137 /// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default)
138- pub fn with_memory_type ( mut self , memory_type : MemType ) -> Result < SessionBuilder > {
138+ pub fn with_memory_type ( mut self , memory_type : MemType ) -> Result < SessionBuilder < ' a > > {
139139 self . memory_type = memory_type;
140140 Ok ( self )
141141 }
142142
143143 /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
144144 #[ cfg( feature = "model-fetching" ) ]
145- pub fn with_model_downloaded < M > ( self , model : M ) -> Result < Session >
145+ pub fn with_model_downloaded < M > ( self , model : M ) -> Result < Session < ' a > >
146146 where
147147 M : Into < AvailableOnnxModel > ,
148148 {
149149 self . with_model_downloaded_monomorphized ( model. into ( ) )
150150 }
151151
152152 #[ cfg( feature = "model-fetching" ) ]
153- fn with_model_downloaded_monomorphized ( self , model : AvailableOnnxModel ) -> Result < Session > {
153+ fn with_model_downloaded_monomorphized ( self , model : AvailableOnnxModel ) -> Result < Session < ' a > > {
154154 let download_dir = env:: current_dir ( ) . map_err ( OrtDownloadError :: IoError ) ?;
155155 let downloaded_path = model. download_to ( download_dir) ?;
156- self . with_model_from_file_monomorphized ( downloaded_path. as_ref ( ) )
156+ self . with_model_from_file ( downloaded_path)
157157 }
158158
159159 // TODO: Add all functions changing the options.
160160 // See all OrtApi methods taking a `options: *mut OrtSessionOptions`.
161161
162162 /// Load an ONNX graph from a file and commit the session
163- pub fn with_model_from_file < P > ( self , model_filepath : P ) -> Result < Session >
163+ pub fn with_model_from_file < P > ( self , model_filepath_ref : P ) -> Result < Session < ' a > >
164164 where
165- P : AsRef < Path > ,
165+ P : AsRef < Path > + ' a ,
166166 {
167- self . with_model_from_file_monomorphized ( model_filepath. as_ref ( ) )
168- }
169-
170- fn with_model_from_file_monomorphized ( self , model_filepath : & Path ) -> Result < Session > {
167+ let model_filepath = model_filepath_ref. as_ref ( ) ;
171168 let mut session_ptr: * mut sys:: OrtSession = std:: ptr:: null_mut ( ) ;
172169
173170 if !model_filepath. exists ( ) {
@@ -224,6 +221,7 @@ impl SessionBuilder {
224221 . collect :: < Result < Vec < Output > > > ( ) ?;
225222
226223 Ok ( Session {
224+ env : self . env ,
227225 session_ptr,
228226 allocator_ptr,
229227 memory_info,
@@ -233,14 +231,14 @@ impl SessionBuilder {
233231 }
234232
235233 /// Load an ONNX graph from memory and commit the session
236- pub fn with_model_from_memory < B > ( self , model_bytes : B ) -> Result < Session >
234+ pub fn with_model_from_memory < B > ( self , model_bytes : B ) -> Result < Session < ' a > >
237235 where
238236 B : AsRef < [ u8 ] > ,
239237 {
240238 self . with_model_from_memory_monomorphized ( model_bytes. as_ref ( ) )
241239 }
242240
243- fn with_model_from_memory_monomorphized ( self , model_bytes : & [ u8 ] ) -> Result < Session > {
241+ fn with_model_from_memory_monomorphized ( self , model_bytes : & [ u8 ] ) -> Result < Session < ' a > > {
244242 let mut session_ptr: * mut sys:: OrtSession = std:: ptr:: null_mut ( ) ;
245243
246244 let env_ptr: * const sys:: OrtEnv = self . env . env_ptr ( ) ;
@@ -279,6 +277,7 @@ impl SessionBuilder {
279277 . collect :: < Result < Vec < Output > > > ( ) ?;
280278
281279 Ok ( Session {
280+ env : self . env ,
282281 session_ptr,
283282 allocator_ptr,
284283 memory_info,
@@ -290,7 +289,8 @@ impl SessionBuilder {
290289
291290/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
292291#[ derive( Debug ) ]
293- pub struct Session {
292+ pub struct Session < ' a > {
293+ env : & ' a Environment ,
294294 session_ptr : * mut sys:: OrtSession ,
295295 allocator_ptr : * mut sys:: OrtAllocator ,
296296 memory_info : MemoryInfo ,
@@ -348,7 +348,7 @@ impl Output {
348348 }
349349}
350350
351- impl Drop for Session {
351+ impl < ' a > Drop for Session < ' a > {
352352 #[ tracing:: instrument]
353353 fn drop ( & mut self ) {
354354 debug ! ( "Dropping the session." ) ;
@@ -360,7 +360,7 @@ impl Drop for Session {
360360 }
361361}
362362
363- impl Session {
363+ impl < ' a > Session < ' a > {
364364 /// Run the input data through the ONNX graph, performing inference.
365365 ///
366366 /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
@@ -562,7 +562,7 @@ impl Session {
562562
563563/// This module contains dangerous functions working on raw pointers.
564564/// Those functions are only to be used from inside the
565- /// `SessionBuilder::with_model_from_file_monomorphized ()` method.
565+ /// `SessionBuilder::with_model_from_file ()` method.
566566mod dangerous {
567567 use super :: * ;
568568
0 commit comments