Skip to content

Commit 3b804d4

Browse files
authored
Merge pull request #46 from nbigaouette/42-tie-session-lifetime-to-environment
Store a reference to environment in session to tie its lifetime
2 parents 694bb7c + 7198ab1 commit 3b804d4

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

onnxruntime/src/environment.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl Environment {
135135
/// Create a new [`SessionBuilder`](../session/struct.SessionBuilder.html)
136136
/// used to create a new ONNX session.
137137
pub fn new_session_builder(&self) -> Result<SessionBuilder> {
138-
SessionBuilder::new(self.clone())
138+
SessionBuilder::new(self)
139139
}
140140
}
141141

onnxruntime/src/session.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ use crate::{download::AvailableOnnxModel, error::OrtDownloadError};
6262
/// # }
6363
/// ```
6464
#[derive(Debug)]
65-
pub struct SessionBuilder {
66-
env: Environment,
65+
pub struct SessionBuilder<'a> {
66+
env: &'a Environment,
6767
session_options_ptr: *mut sys::OrtSessionOptions,
6868

6969
allocator: AllocatorType,
7070
memory_type: MemType,
7171
}
7272

73-
impl Drop for SessionBuilder {
73+
impl<'a> Drop for SessionBuilder<'a> {
7474
#[tracing::instrument]
7575
fn drop(&mut self) {
7676
debug!("Dropping the session options.");
@@ -79,8 +79,8 @@ impl Drop for SessionBuilder {
7979
}
8080
}
8181

82-
impl SessionBuilder {
83-
pub(crate) fn new(env: Environment) -> Result<SessionBuilder> {
82+
impl<'a> SessionBuilder<'a> {
83+
pub(crate) fn new(env: &'a Environment) -> Result<SessionBuilder<'a>> {
8484
let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
8585
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };
8686

@@ -97,7 +97,7 @@ impl SessionBuilder {
9797
}
9898

9999
/// Configure the session to use a number of threads
100-
pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder> {
100+
pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder<'a>> {
101101
// FIXME: Pre-built binaries use OpenMP, set env variable instead
102102

103103
// We use a u16 in the builder to cover the 16-bits positive values of a i32.
@@ -113,7 +113,7 @@ impl SessionBuilder {
113113
pub fn with_optimization_level(
114114
self,
115115
opt_level: GraphOptimizationLevel,
116-
) -> Result<SessionBuilder> {
116+
) -> Result<SessionBuilder<'a>> {
117117
// Sets graph optimization level
118118
unsafe {
119119
g_ort().SetSessionGraphOptimizationLevel.unwrap()(
@@ -127,47 +127,44 @@ impl SessionBuilder {
127127
/// Set the session's allocator
128128
///
129129
/// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena)
130-
pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder> {
130+
pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder<'a>> {
131131
self.allocator = allocator;
132132
Ok(self)
133133
}
134134

135135
/// Set the session's memory type
136136
///
137137
/// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default)
138-
pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder> {
138+
pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder<'a>> {
139139
self.memory_type = memory_type;
140140
Ok(self)
141141
}
142142

143143
/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
144144
#[cfg(feature = "model-fetching")]
145-
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session>
145+
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
146146
where
147147
M: Into<AvailableOnnxModel>,
148148
{
149149
self.with_model_downloaded_monomorphized(model.into())
150150
}
151151

152152
#[cfg(feature = "model-fetching")]
153-
fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session> {
153+
fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session<'a>> {
154154
let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
155155
let downloaded_path = model.download_to(download_dir)?;
156-
self.with_model_from_file_monomorphized(downloaded_path.as_ref())
156+
self.with_model_from_file(downloaded_path)
157157
}
158158

159159
// TODO: Add all functions changing the options.
160160
// See all OrtApi methods taking a `options: *mut OrtSessionOptions`.
161161

162162
/// Load an ONNX graph from a file and commit the session
163-
pub fn with_model_from_file<P>(self, model_filepath: P) -> Result<Session>
163+
pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> Result<Session<'a>>
164164
where
165-
P: AsRef<Path>,
165+
P: AsRef<Path> + 'a,
166166
{
167-
self.with_model_from_file_monomorphized(model_filepath.as_ref())
168-
}
169-
170-
fn with_model_from_file_monomorphized(self, model_filepath: &Path) -> Result<Session> {
167+
let model_filepath = model_filepath_ref.as_ref();
171168
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
172169

173170
if !model_filepath.exists() {
@@ -224,6 +221,7 @@ impl SessionBuilder {
224221
.collect::<Result<Vec<Output>>>()?;
225222

226223
Ok(Session {
224+
env: self.env,
227225
session_ptr,
228226
allocator_ptr,
229227
memory_info,
@@ -233,14 +231,14 @@ impl SessionBuilder {
233231
}
234232

235233
/// Load an ONNX graph from memory and commit the session
236-
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session>
234+
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session<'a>>
237235
where
238236
B: AsRef<[u8]>,
239237
{
240238
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
241239
}
242240

243-
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session> {
241+
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session<'a>> {
244242
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
245243

246244
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
@@ -279,6 +277,7 @@ impl SessionBuilder {
279277
.collect::<Result<Vec<Output>>>()?;
280278

281279
Ok(Session {
280+
env: self.env,
282281
session_ptr,
283282
allocator_ptr,
284283
memory_info,
@@ -290,7 +289,8 @@ impl SessionBuilder {
290289

291290
/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
292291
#[derive(Debug)]
293-
pub struct Session {
292+
pub struct Session<'a> {
293+
env: &'a Environment,
294294
session_ptr: *mut sys::OrtSession,
295295
allocator_ptr: *mut sys::OrtAllocator,
296296
memory_info: MemoryInfo,
@@ -348,7 +348,7 @@ impl Output {
348348
}
349349
}
350350

351-
impl Drop for Session {
351+
impl<'a> Drop for Session<'a> {
352352
#[tracing::instrument]
353353
fn drop(&mut self) {
354354
debug!("Dropping the session.");
@@ -360,7 +360,7 @@ impl Drop for Session {
360360
}
361361
}
362362

363-
impl Session {
363+
impl<'a> Session<'a> {
364364
/// Run the input data through the ONNX graph, performing inference.
365365
///
366366
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
@@ -562,7 +562,7 @@ impl Session {
562562

563563
/// This module contains dangerous functions working on raw pointers.
564564
/// Those functions are only to be used from inside the
565-
/// `SessionBuilder::with_model_from_file_monomorphized()` method.
565+
/// `SessionBuilder::with_model_from_file()` method.
566566
mod dangerous {
567567
use super::*;
568568

0 commit comments

Comments
 (0)