Skip to content

Commit d04f08c

Browse files
authored
Internally use widestring crate and CString similar extension for managing null terminated utf16 strings (#9)
1 parent f043852 commit d04f08c

File tree

6 files changed

+91
-79
lines changed

6 files changed

+91
-79
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wslplugins-rs/Cargo.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@ version = "0.58"
99
features = ["Win32_Foundation", "Win32_System", "Win32_Networking_WinSock"]
1010

1111
[dependencies]
12-
wslplugins-sys = { path = "../wslplugins-sys" }
13-
typed-path = ">0.1"
1412
bitflags = { version = ">0.1.0", optional = true }
15-
flagset = { version = ">0.1.0", optional = true }
1613
enumflags2 = { version = ">0.5", optional = true }
14+
flagset = { version = ">0.1.0", optional = true }
1715
log = { version = "*", optional = true }
1816
log-instrument = { version = "*", optional = true }
19-
wslplugins-macro = { path = "../wslplugins-macro", optional = true }
2017
thiserror = "2.0.7"
18+
typed-path = ">0.1"
19+
widestring = { version = "1", features = ["alloc"] }
20+
wslplugins-macro = { path = "../wslplugins-macro", optional = true }
21+
wslplugins-sys = { path = "../wslplugins-sys" }
2122

2223
[dependencies.semver]
2324
version = ">0.1"

wslplugins-rs/src/api/api_v1.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@ extern crate wslplugins_sys;
33
use super::Error;
44
use super::Result;
55
use crate::api::errors::require_update_error::Result as UpReqResult;
6-
use crate::utils::{cstring_from_str, encode_wide_null_terminated};
6+
use crate::cstring_ext::CstringExt;
77
use crate::wsl_session_information::WSLSessionInformation;
88
#[cfg(feature = "log-instrument")]
99
use log_instrument::instrument;
10-
use std::ffi::{CString, OsStr, OsString};
10+
use std::ffi::{CString, OsStr};
1111
use std::fmt::Debug;
1212
use std::iter::once;
1313
use std::mem::MaybeUninit;
1414
use std::net::TcpStream;
1515
use std::os::windows::io::FromRawSocket;
1616
use std::os::windows::raw::SOCKET;
1717
use std::path::Path;
18-
use std::str::FromStr;
1918
use typed_path::Utf8UnixPath;
19+
use widestring::U16CString;
2020
use windows::Win32::Networking::WinSock::SOCKET as WinSocket;
2121
use windows::{
2222
core::{Result as WinResult, GUID, PCSTR, PCWSTR},
@@ -99,13 +99,10 @@ impl ApiV1 {
9999
read_only: bool,
100100
name: &OsStr,
101101
) -> WinResult<()> {
102-
let encoded_windows_path = encode_wide_null_terminated(windows_path.as_ref().as_os_str());
103-
let encoded_linux_path = encode_wide_null_terminated(
104-
OsString::from_str(linux_path.as_ref().as_str())
105-
.unwrap()
106-
.as_os_str(),
107-
);
108-
let encoded_name = encode_wide_null_terminated(name);
102+
let encoded_windows_path =
103+
U16CString::from_os_str_truncate(windows_path.as_ref().as_os_str());
104+
let encoded_linux_path = U16CString::from_str_truncate(linux_path.as_ref().as_str());
105+
let encoded_name = U16CString::from_os_str_truncate(name);
109106
let result = unsafe {
110107
self.0.MountFolder.unwrap_unchecked()(
111108
session.id(),
@@ -163,7 +160,10 @@ impl ApiV1 {
163160
.copied()
164161
.chain(once(0))
165162
.collect();
166-
let c_args: Vec<CString> = args.iter().map(|&arg| cstring_from_str(arg)).collect();
163+
let c_args: Vec<CString> = args
164+
.iter()
165+
.map(|&arg| CString::from_str_truncate(arg))
166+
.collect();
167167
let mut args_ptrs: Vec<PCSTR> = c_args
168168
.iter()
169169
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))
@@ -188,8 +188,10 @@ impl ApiV1 {
188188
/// Set the error message to display to the user if the VM or distribution creation fails.
189189
#[cfg_attr(feature = "log-instrument", instrument)]
190190
pub(crate) fn plugin_error(&self, error: &OsStr) -> WinResult<()> {
191-
let error_vec = encode_wide_null_terminated(error);
192-
unsafe { self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_vec.as_ptr())).ok() }
191+
let error_utf16 = U16CString::from_os_str_truncate(error);
192+
unsafe {
193+
self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_utf16.as_ptr())).ok()
194+
}
193195
}
194196

195197
/// Execute a program in a user distribution
@@ -242,7 +244,10 @@ impl ApiV1 {
242244
.chain(once(0))
243245
.collect();
244246
let path_ptr = PCSTR::from_raw(c_path.as_ptr());
245-
let c_args: Vec<CString> = args.iter().map(|&arg| cstring_from_str(arg)).collect();
247+
let c_args: Vec<CString> = args
248+
.iter()
249+
.map(|&arg| CString::from_str_truncate(arg))
250+
.collect();
246251
let mut args_ptrs: Vec<PCSTR> = c_args
247252
.iter()
248253
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))

wslplugins-rs/src/cstring_ext.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use std::ffi::CString;
2+
3+
pub(crate) trait CstringExt {
4+
/// Creates a `CString` from a string slice, truncating at the first null byte if present.
5+
fn from_str_truncate(value: &str) -> Self;
6+
}
7+
8+
impl CstringExt for CString {
9+
fn from_str_truncate(value: &str) -> Self {
10+
let bytes = value.as_bytes();
11+
let truncated_bytes = match bytes.iter().position(|&b| b == 0) {
12+
Some(pos) => &bytes[..pos],
13+
None => bytes,
14+
};
15+
// SAFETY: `truncated_bytes` is guaranteed not to contain null bytes.
16+
unsafe { Self::from_vec_unchecked(truncated_bytes.to_vec()) }
17+
}
18+
}
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use super::*;
23+
use std::ffi::CString;
24+
25+
#[test]
26+
fn test_from_str_truncate_no_null() {
27+
let input = "Hello, world!";
28+
let cstring = CString::from_str_truncate(input);
29+
assert_eq!(cstring.to_str().unwrap(), input);
30+
}
31+
32+
#[test]
33+
fn test_from_str_truncate_with_null() {
34+
let input = "Hello\0world!";
35+
let cstring = CString::from_str_truncate(input);
36+
assert_eq!(cstring.to_str().unwrap(), "Hello");
37+
}
38+
39+
#[test]
40+
fn test_from_str_truncate_empty() {
41+
let input = "";
42+
let cstring = CString::from_str_truncate(input);
43+
assert_eq!(cstring.to_str().unwrap(), "");
44+
}
45+
46+
#[test]
47+
fn test_from_str_truncate_null_only() {
48+
let input = "\0";
49+
let cstring = CString::from_str_truncate(input);
50+
assert_eq!(cstring.to_str().unwrap(), "");
51+
}
52+
53+
#[test]
54+
fn test_from_str_truncate_null_in_middle() {
55+
let input = "Rust\0is awesome!";
56+
let cstring = CString::from_str_truncate(input);
57+
assert_eq!(cstring.to_str().unwrap(), "Rust");
58+
}
59+
}

wslplugins-rs/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub mod api;
4141

4242
// Internal modules for managing specific WSL features.
4343
mod core_distribution_information;
44+
pub(crate) mod cstring_ext;
4445
mod distribution_information;
4546
mod offline_distribution_information;
4647
mod utils;

wslplugins-rs/src/utils.rs

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,5 @@
1-
//! # String Encoding Utilities
2-
//!
3-
//! This module provides utility functions to handle string encoding conversions, specifically for:
4-
//! - Encoding `OsStr` as wide, null-terminated UTF-16 strings.
5-
//! - Creating `CString` instances from Rust strings, filtering out null bytes.
6-
7-
use std::ffi::{CString, OsStr};
8-
use std::os::windows::ffi::OsStrExt;
9-
10-
pub fn encode_wide_null_terminated(input: &OsStr) -> Vec<u16> {
11-
input
12-
.encode_wide()
13-
.filter(|&c| c != 0)
14-
.chain(Some(0))
15-
.collect()
16-
}
17-
18-
pub fn cstring_from_str(input: &str) -> CString {
19-
let filtered_input: Vec<u8> = input.bytes().filter(|&c| c != 0).collect();
20-
unsafe { CString::from_vec_unchecked(filtered_input) }
21-
}
22-
231
#[cfg(test)]
242
pub(crate) fn test_transparence<T, U>() {
253
assert_eq!(align_of::<T>(), align_of::<U>());
264
assert_eq!(size_of::<T>(), size_of::<U>());
275
}
28-
29-
#[cfg(test)]
30-
mod tests {
31-
use super::*;
32-
use std::ffi::OsString;
33-
34-
/// Tests `encode_wide_null_terminated` with a string containing no null characters.
35-
#[test]
36-
fn test_encode_wide_null_terminated_no_nulls() {
37-
let input = OsString::from("Hello");
38-
let expected: Vec<u16> = "Hello\0".encode_utf16().collect();
39-
assert_eq!(encode_wide_null_terminated(&input), expected);
40-
}
41-
42-
/// Tests `encode_wide_null_terminated` with a string containing null characters.
43-
#[test]
44-
fn test_encode_wide_null_terminated_with_nulls() {
45-
let input = OsString::from("Hel\0lo");
46-
let expected: Vec<u16> = "Hello\0".encode_utf16().collect();
47-
assert_eq!(encode_wide_null_terminated(&input), expected);
48-
}
49-
50-
/// Tests `cstring_from_str` with a string containing no null characters.
51-
#[test]
52-
fn test_cstring_from_str_no_nulls() {
53-
let input = "Hello";
54-
let cstring = cstring_from_str(input);
55-
assert_eq!(cstring.to_str().unwrap(), input);
56-
}
57-
58-
/// Tests `cstring_from_str` with a string containing null characters.
59-
#[test]
60-
fn test_cstring_from_str_with_nulls() {
61-
let input = "Hel\0lo";
62-
let cstring = cstring_from_str(input);
63-
let expected = "Hello".as_bytes();
64-
assert_eq!(cstring.into_bytes(), expected);
65-
}
66-
}

0 commit comments

Comments
 (0)