Skip to content

Commit d14a0bd

Browse files
committed
Convert the Path to OsString than to Vec<u8/u16> for model loading
1 parent 6237242 commit d14a0bd

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

onnxruntime/src/session.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
33
use std::{ffi::CString, fmt::Debug, path::Path};
44

5+
#[cfg(not(target_family = "windows"))]
6+
use std::os::unix::ffi::OsStrExt;
7+
#[cfg(target_family = "windows")]
8+
use std::os::windows::ffi::OsStrExt;
9+
510
#[cfg(feature = "model-fetching")]
611
use std::env;
712

@@ -170,14 +175,21 @@ impl SessionBuilder {
170175
filename: model_filepath.to_path_buf(),
171176
});
172177
}
173-
let model_path: CString =
174-
CString::new(
175-
model_filepath
176-
.to_str()
177-
.ok_or_else(|| OrtError::NonUtf8Path {
178-
path: model_filepath.to_path_buf(),
179-
})?,
180-
)?;
178+
179+
// Build an OsString than a vector of bytes to pass to C
180+
let model_path = std::ffi::OsString::from(model_filepath);
181+
#[cfg(target_family = "windows")]
182+
let model_path: Vec<u16> = model_path
183+
.encode_wide()
184+
.chain(std::iter::once(0)) // Make sure we have a null terminated string
185+
.collect();
186+
#[cfg(not(target_family = "windows"))]
187+
let model_path: Vec<std::os::raw::c_char> = model_path
188+
.as_bytes()
189+
.iter()
190+
.chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
191+
.map(|b| *b as std::os::raw::c_char)
192+
.collect();
181193

182194
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
183195

0 commit comments

Comments
 (0)