Skip to content

Commit 9441506

Browse files
committed
Fix Lua String soundness when borrowing &str or &[u8].
Borrowing now goes through the new `BorrowedStr` and `BorrowedBytes` types, which hold a strong reference to the Lua state.
1 parent bba644e commit 9441506

File tree

8 files changed

+179
-56
lines changed

8 files changed

+179
-56
lines changed

examples/async_http_server.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ impl hyper::service::Service<Request<Incoming>> for Svc {
6868
if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? {
6969
for pair in headers.pairs::<String, LuaString>() {
7070
let (h, v) = pair?;
71-
resp = resp.header(&h, v.as_bytes());
71+
resp = resp.header(&h, &*v.as_bytes());
7272
}
7373
}
7474

7575
// Set body
7676
let body = lua_resp
7777
.get::<_, Option<LuaString>>("body")?
78-
.map(|b| Full::new(Bytes::copy_from_slice(b.as_bytes())).boxed())
78+
.map(|b| Full::new(Bytes::copy_from_slice(&b.as_bytes())).boxed())
7979
.unwrap_or_else(|| Empty::<Bytes>::new().boxed());
8080

8181
Ok(resp.body(body).unwrap())

src/conversion.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ impl FromLua for CString {
464464
message: Some("expected string or number".to_string()),
465465
})?;
466466

467-
match CStr::from_bytes_with_nul(string.as_bytes_with_nul()) {
467+
match CStr::from_bytes_with_nul(&string.as_bytes_with_nul()) {
468468
Ok(s) => Ok(s.into()),
469469
Err(_) => Err(Error::FromLuaConversionError {
470470
from: ty,
@@ -500,7 +500,7 @@ impl FromLua for BString {
500500
fn from_lua(value: Value, lua: &Lua) -> Result<Self> {
501501
let ty = value.type_name();
502502
match value {
503-
Value::String(s) => Ok(s.as_bytes().into()),
503+
Value::String(s) => Ok((*s.as_bytes()).into()),
504504
#[cfg(feature = "luau")]
505505
Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
506506
let lua = ud.0.lua.lock();
@@ -509,15 +509,15 @@ impl FromLua for BString {
509509
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
510510
Ok(slice::from_raw_parts(buf as *const u8, size).into())
511511
},
512-
_ => Ok(lua
512+
_ => Ok((*lua
513513
.coerce_string(value)?
514514
.ok_or_else(|| Error::FromLuaConversionError {
515515
from: ty,
516516
to: "BString",
517517
message: Some("expected string or number".to_string()),
518518
})?
519-
.as_bytes()
520-
.into()),
519+
.as_bytes())
520+
.into()),
521521
}
522522
}
523523

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub use crate::multi::Variadic;
111111
pub use crate::state::{GCMode, Lua, LuaOptions};
112112
// pub use crate::scope::Scope;
113113
pub use crate::stdlib::StdLib;
114-
pub use crate::string::String;
114+
pub use crate::string::{BorrowedBytes, BorrowedStr, String};
115115
pub use crate::table::{Table, TableExt, TablePairs, TableSequence};
116116
pub use crate::thread::{Thread, ThreadStatus};
117117
pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, Number, RegistryKey};

src/serde/de.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ impl<'de> serde::Deserializer<'de> for Deserializer {
135135
#[cfg(feature = "luau")]
136136
Value::Vector(_) => self.deserialize_seq(visitor),
137137
Value::String(s) => match s.to_str() {
138-
Ok(s) => visitor.visit_str(s),
139-
Err(_) => visitor.visit_bytes(s.as_bytes()),
138+
Ok(s) => visitor.visit_str(&s),
139+
Err(_) => visitor.visit_bytes(&s.as_bytes()),
140140
},
141141
Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
142142
Value::Table(_) => self.deserialize_map(visitor),

src/string.rs

Lines changed: 154 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use std::borrow::{Borrow, Cow};
1+
use std::borrow::Borrow;
22
use std::hash::{Hash, Hasher};
3+
use std::ops::Deref;
34
use std::os::raw::c_void;
45
use std::string::String as StdString;
5-
use std::{fmt, slice, str};
6+
use std::{cmp, fmt, slice, str};
67

78
#[cfg(feature = "serialize")]
89
use {
@@ -11,6 +12,7 @@ use {
1112
};
1213

1314
use crate::error::{Error, Result};
15+
use crate::state::LuaGuard;
1416
use crate::types::ValueRef;
1517

1618
/// Handle to an internal Lua string.
@@ -20,7 +22,7 @@ use crate::types::ValueRef;
2022
pub struct String(pub(crate) ValueRef);
2123

2224
impl String {
23-
/// Get a `&str` slice if the Lua string is valid UTF-8.
25+
/// Get a [`BorrowedStr`] if the Lua string is valid UTF-8.
2426
///
2527
/// # Examples
2628
///
@@ -39,15 +41,17 @@ impl String {
3941
/// # }
4042
/// ```
4143
#[inline]
42-
pub fn to_str(&self) -> Result<&str> {
43-
str::from_utf8(self.as_bytes()).map_err(|e| Error::FromLuaConversionError {
44+
pub fn to_str(&self) -> Result<BorrowedStr> {
45+
let BorrowedBytes(bytes, guard) = self.as_bytes();
46+
let s = str::from_utf8(bytes).map_err(|e| Error::FromLuaConversionError {
4447
from: "string",
4548
to: "&str",
4649
message: Some(e.to_string()),
47-
})
50+
})?;
51+
Ok(BorrowedStr(s, guard))
4852
}
4953

50-
/// Converts this string to a [`Cow<str>`].
54+
/// Converts this string to a [`StdString`].
5155
///
5256
/// Any non-Unicode sequences are replaced with [`U+FFFD REPLACEMENT CHARACTER`][U+FFFD].
5357
///
@@ -66,8 +70,8 @@ impl String {
6670
/// # }
6771
/// ```
6872
#[inline]
69-
pub fn to_string_lossy(&self) -> Cow<'_, str> {
70-
StdString::from_utf8_lossy(self.as_bytes())
73+
pub fn to_string_lossy(&self) -> StdString {
74+
StdString::from_utf8_lossy(&self.as_bytes()).into_owned()
7175
}
7276

7377
/// Get the bytes that make up this string.
@@ -88,13 +92,18 @@ impl String {
8892
/// # }
8993
/// ```
9094
#[inline]
91-
pub fn as_bytes(&self) -> &[u8] {
92-
let nulled = self.as_bytes_with_nul();
93-
&nulled[..nulled.len() - 1]
95+
pub fn as_bytes(&self) -> BorrowedBytes {
96+
let (bytes, guard) = unsafe { self.to_slice() };
97+
BorrowedBytes(&bytes[..bytes.len() - 1], guard)
9498
}
9599

96100
/// Get the bytes that make up this string, including the trailing nul byte.
97-
pub fn as_bytes_with_nul(&self) -> &[u8] {
101+
pub fn as_bytes_with_nul(&self) -> BorrowedBytes {
102+
let (bytes, guard) = unsafe { self.to_slice() };
103+
BorrowedBytes(bytes, guard)
104+
}
105+
106+
unsafe fn to_slice(&self) -> (&[u8], LuaGuard) {
98107
let lua = self.0.lua.lock();
99108
let ref_thread = lua.ref_thread();
100109
unsafe {
@@ -108,7 +117,7 @@ impl String {
108117
// string type
109118
let data = ffi::lua_tolstring(ref_thread, self.0.index, &mut size);
110119

111-
slice::from_raw_parts(data as *const u8, size + 1)
120+
(slice::from_raw_parts(data as *const u8, size + 1), lua)
112121
}
113122
}
114123

@@ -127,7 +136,7 @@ impl fmt::Debug for String {
127136
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
128137
let bytes = self.as_bytes();
129138
// Check if the string is valid utf8
130-
if let Ok(s) = str::from_utf8(bytes) {
139+
if let Ok(s) = str::from_utf8(&bytes) {
131140
return s.fmt(f);
132141
}
133142

@@ -152,22 +161,9 @@ impl fmt::Debug for String {
152161
}
153162
}
154163

155-
impl AsRef<[u8]> for String {
156-
fn as_ref(&self) -> &[u8] {
157-
self.as_bytes()
158-
}
159-
}
160-
161-
impl Borrow<[u8]> for String {
162-
fn borrow(&self) -> &[u8] {
163-
self.as_bytes()
164-
}
165-
}
166-
167164
// Lua strings are basically &[u8] slices, so implement PartialEq for anything resembling that.
168165
//
169-
// This makes our `String` comparable with `Vec<u8>`, `[u8]`, `&str`, `String` and `mlua::String`
170-
// itself.
166+
// This makes our `String` comparable with `Vec<u8>`, `[u8]`, `&str` and `String`.
171167
//
172168
// The only downside is that this disallows a comparison with `Cow<str>`, as that only implements
173169
// `AsRef<str>`, which collides with this impl. Requiring `AsRef<str>` would fix that, but limit us
@@ -181,6 +177,18 @@ where
181177
}
182178
}
183179

180+
impl PartialEq<String> for String {
181+
fn eq(&self, other: &String) -> bool {
182+
self.as_bytes() == other.as_bytes()
183+
}
184+
}
185+
186+
impl PartialEq<&String> for String {
187+
fn eq(&self, other: &&String) -> bool {
188+
self.as_bytes() == other.as_bytes()
189+
}
190+
}
191+
184192
impl Eq for String {}
185193

186194
impl Hash for String {
@@ -196,12 +204,127 @@ impl Serialize for String {
196204
S: Serializer,
197205
{
198206
match self.to_str() {
199-
Ok(s) => serializer.serialize_str(s),
200-
Err(_) => serializer.serialize_bytes(self.as_bytes()),
207+
Ok(s) => serializer.serialize_str(&s),
208+
Err(_) => serializer.serialize_bytes(&self.as_bytes()),
201209
}
202210
}
203211
}
204212

213+
/// A borrowed string (`&str`) that holds a strong reference to the Lua state.
214+
pub struct BorrowedStr<'a>(&'a str, #[allow(unused)] LuaGuard);
215+
216+
impl Deref for BorrowedStr<'_> {
217+
type Target = str;
218+
219+
#[inline(always)]
220+
fn deref(&self) -> &str {
221+
self.0
222+
}
223+
}
224+
225+
impl Borrow<str> for BorrowedStr<'_> {
226+
#[inline(always)]
227+
fn borrow(&self) -> &str {
228+
self.0
229+
}
230+
}
231+
232+
impl AsRef<str> for BorrowedStr<'_> {
233+
#[inline(always)]
234+
fn as_ref(&self) -> &str {
235+
self.0
236+
}
237+
}
238+
239+
impl fmt::Display for BorrowedStr<'_> {
240+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
241+
self.0.fmt(f)
242+
}
243+
}
244+
245+
impl fmt::Debug for BorrowedStr<'_> {
246+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
247+
self.0.fmt(f)
248+
}
249+
}
250+
251+
impl<T> PartialEq<T> for BorrowedStr<'_>
252+
where
253+
T: AsRef<str>,
254+
{
255+
fn eq(&self, other: &T) -> bool {
256+
self.0 == other.as_ref()
257+
}
258+
}
259+
260+
impl<T> PartialOrd<T> for BorrowedStr<'_>
261+
where
262+
T: AsRef<str>,
263+
{
264+
fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
265+
self.0.partial_cmp(other.as_ref())
266+
}
267+
}
268+
269+
/// A borrowed byte slice (`&[u8]`) that holds a strong reference to the Lua state.
270+
pub struct BorrowedBytes<'a>(&'a [u8], #[allow(unused)] LuaGuard);
271+
272+
impl Deref for BorrowedBytes<'_> {
273+
type Target = [u8];
274+
275+
#[inline(always)]
276+
fn deref(&self) -> &[u8] {
277+
self.0
278+
}
279+
}
280+
281+
impl Borrow<[u8]> for BorrowedBytes<'_> {
282+
#[inline(always)]
283+
fn borrow(&self) -> &[u8] {
284+
self.0
285+
}
286+
}
287+
288+
impl AsRef<[u8]> for BorrowedBytes<'_> {
289+
#[inline(always)]
290+
fn as_ref(&self) -> &[u8] {
291+
self.0
292+
}
293+
}
294+
295+
impl fmt::Debug for BorrowedBytes<'_> {
296+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297+
self.0.fmt(f)
298+
}
299+
}
300+
301+
impl<T> PartialEq<T> for BorrowedBytes<'_>
302+
where
303+
T: AsRef<[u8]>,
304+
{
305+
fn eq(&self, other: &T) -> bool {
306+
self.0 == other.as_ref()
307+
}
308+
}
309+
310+
impl<T> PartialOrd<T> for BorrowedBytes<'_>
311+
where
312+
T: AsRef<[u8]>,
313+
{
314+
fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
315+
self.0.partial_cmp(other.as_ref())
316+
}
317+
}
318+
319+
impl<'a> IntoIterator for BorrowedBytes<'a> {
320+
type Item = &'a u8;
321+
type IntoIter = slice::Iter<'a, u8>;
322+
323+
fn into_iter(self) -> Self::IntoIter {
324+
self.0.into_iter()
325+
}
326+
}
327+
205328
#[cfg(test)]
206329
mod assertions {
207330
use super::*;

src/value.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::cmp::Ordering;
32
use std::collections::{vec_deque, HashSet, VecDeque};
43
use std::ops::{Deref, DerefMut};
@@ -12,7 +11,7 @@ use num_traits::FromPrimitive;
1211
use crate::error::{Error, Result};
1312
use crate::function::Function;
1413
use crate::state::{Lua, RawLua};
15-
use crate::string::String;
14+
use crate::string::{BorrowedStr, String};
1615
use crate::table::Table;
1716
use crate::thread::Thread;
1817
use crate::types::{Integer, LightUserData, Number, SubtypeId};
@@ -330,19 +329,20 @@ impl Value {
330329
}
331330
}
332331

333-
/// Cast the value to [`str`].
332+
/// Cast the value to [`BorrowedStr`].
334333
///
335-
/// If the value is a Lua [`String`], try to convert it to [`str`] or return `None` otherwise.
334+
/// If the value is a Lua [`String`], try to convert it to [`BorrowedStr`] or return `None`
335+
/// otherwise.
336336
#[inline]
337-
pub fn as_str(&self) -> Option<&str> {
337+
pub fn as_str(&self) -> Option<BorrowedStr> {
338338
self.as_string().and_then(|s| s.to_str().ok())
339339
}
340340

341-
/// Cast the value to [`Cow<str>`].
341+
/// Cast the value to [`StdString`].
342342
///
343-
/// If the value is a Lua [`String`], converts it to [`Cow<str>`] or returns `None` otherwise.
343+
/// If the value is a Lua [`String`], converts it to [`StdString`] or returns `None` otherwise.
344344
#[inline]
345-
pub fn as_string_lossy(&self) -> Option<Cow<str>> {
345+
pub fn as_string_lossy(&self) -> Option<StdString> {
346346
self.as_string().map(|s| s.to_string_lossy())
347347
}
348348

@@ -478,7 +478,7 @@ impl Value {
478478
(Value::Integer(_) | Value::Number(_), _) => Ordering::Less,
479479
(_, Value::Integer(_) | Value::Number(_)) => Ordering::Greater,
480480
// String
481-
(Value::String(a), Value::String(b)) => a.as_bytes().cmp(b.as_bytes()),
481+
(Value::String(a), Value::String(b)) => a.as_bytes().cmp(&b.as_bytes()),
482482
(Value::String(_), _) => Ordering::Less,
483483
(_, Value::String(_)) => Ordering::Greater,
484484
// Other variants can be randomly ordered

0 commit comments

Comments
 (0)