Skip to content

Commit 6ac4027

Browse files
committed
guard against arbitrary rust recursion depth via coroutine abuse
1 parent 76b8c23 commit 6ac4027

File tree

7 files changed

+117
-34
lines changed

7 files changed

+117
-34
lines changed

src/fuel.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
use std::ops::{Deref, DerefMut};
2+
3+
use thiserror::Error;
4+
5+
pub const RECURSION_LIMIT: u8 = u8::MAX;
6+
17
/// A counter for tracking the amount of time spent in `Thread::step` and in callbacks.
28
///
39
/// The fuel unit is *approximately* one VM instruction, but this is just a rough estimate
@@ -9,6 +15,7 @@
915
pub struct Fuel {
1016
fuel: i32,
1117
interrupted: bool,
18+
recursion_level: u8,
1219
}
1320

1421
impl Fuel {
@@ -20,6 +27,7 @@ impl Fuel {
2027
Self {
2128
fuel,
2229
interrupted: false,
30+
recursion_level: 0,
2331
}
2432
}
2533

@@ -79,4 +87,61 @@ impl Fuel {
7987
pub fn should_continue(&self) -> bool {
8088
self.fuel > 0 && !self.interrupted
8189
}
90+
91+
/// Mark that we are about to run a Rust callback that is potentially controlled by untrusted
92+
/// code.
93+
///
94+
/// Increments the current recursion level if it is below `RECURSION_LIMIT`. If the recursion
95+
/// level would rise above the limit, this returns a recursion error, otherwise returns a
96+
/// guard that restores the previous recursion level on drop. This prevents untrusted code from
97+
/// consuming arbitrary Rust stack depth and execution time by tricking Rust code into recursing
98+
/// endlessly.
99+
///
100+
/// By default, the recursion level is automatically incremented whenever `Thread::step` is
101+
/// about to trigger a callback, so this is should almost never be necessary to call explicitly.
102+
///
103+
/// With the normal stdlib, arbitrary recursion is only possible by (ab)using coroutines. Normal
104+
/// Lua recursion and Rust code "calling" Lua code via a `Sequence` poll does not actually use
105+
/// the real Rust call stack, and cannot lead to using unbounded time or unbounded Rust stack
106+
/// space. Coroutines create their own inner `Thread`s and step them inside a `Sequence`, so
107+
/// they can eventually trigger a recursion limit in pathological cases.
108+
pub fn recurse(&mut self) -> Result<Recurse<'_>, RecursionLimit> {
109+
if self.recursion_level == RECURSION_LIMIT {
110+
return Err(RecursionLimit);
111+
}
112+
113+
self.recursion_level += 1;
114+
Ok(Recurse(self))
115+
}
116+
117+
/// Returns the current Rust callback recursion level.
118+
pub fn recursion_level(&self) -> u8 {
119+
self.recursion_level
120+
}
121+
}
122+
123+
#[derive(Debug, Copy, Clone, Error)]
124+
#[error("callback recursion limit reached")]
125+
pub struct RecursionLimit;
126+
127+
pub struct Recurse<'a>(&'a mut Fuel);
128+
129+
impl<'a> Drop for Recurse<'a> {
130+
fn drop(&mut self) {
131+
self.recursion_level -= 1;
132+
}
133+
}
134+
135+
impl<'a> Deref for Recurse<'a> {
136+
type Target = Fuel;
137+
138+
fn deref(&self) -> &Fuel {
139+
self.0
140+
}
141+
}
142+
143+
impl<'a> DerefMut for Recurse<'a> {
144+
fn deref_mut(&mut self) -> &mut Fuel {
145+
self.0
146+
}
82147
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub use self::{
3939
stack::Stack,
4040
string::{String, StringError},
4141
table::{InvalidTableKey, Table},
42-
thread::{BadThreadMode, Thread, ThreadError, ThreadMode},
42+
thread::{BadThreadMode, Thread, ThreadMode, VMError},
4343
userdata::{AnyUserData, BadUserDataType},
4444
value::Value,
4545
};

src/thread/error.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ pub struct BadThreadMode {
5050
}
5151

5252
#[derive(Debug, Copy, Clone, Error)]
53-
pub enum ThreadError {
53+
pub enum VMError {
5454
#[error("{}", if *.0 {
5555
"operation expects variable stack"
5656
} else {
5757
"unexpected variable stack during operation"
5858
})]
59-
ExpectedVariable(bool),
59+
ExpectedVariableStack(bool),
6060
#[error(transparent)]
6161
BadType(#[from] TypeError),
62+
#[error("_ENV upvalue is only allowed on top-level closure")]
63+
BadEnvUpValue,
6264
}

src/thread/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod thread;
33
mod vm;
44

55
pub use self::{
6-
error::{BadThreadMode, BinaryOperatorError, ThreadError},
6+
error::{BadThreadMode, BinaryOperatorError, VMError},
77
thread::{Thread, ThreadMode},
88
};
99

src/thread/thread.rs

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ use crate::{
1616
meta_ops,
1717
types::{RegisterIndex, VarCount},
1818
AnyCallback, AnySequence, BadThreadMode, CallbackReturn, Closure, Context, Error,
19-
FromMultiValue, Fuel, Function, IntoMultiValue, SequencePoll, Stack, ThreadError, TypeError,
20-
Value,
19+
FromMultiValue, Fuel, Function, IntoMultiValue, SequencePoll, Stack, TypeError, VMError, Value,
2120
};
2221

2322
use super::run_vm;
@@ -195,13 +194,22 @@ impl<'gc> Thread<'gc> {
195194
while state.mode() == ThreadMode::Normal {
196195
match state.frames.pop().expect("no frame to step") {
197196
Frame::Callback(callback) => {
198-
fuel.consume_fuel(FUEL_PER_CALLBACK);
197+
let mut rfuel = match fuel.recurse() {
198+
Ok(r) => r,
199+
Err(err) => {
200+
state.unwind(&ctx, err.into());
201+
continue;
202+
}
203+
};
204+
205+
rfuel.consume_fuel(FUEL_PER_CALLBACK);
199206
state.frames.push(Frame::Calling);
200207

201208
assert!(state.error.is_none());
202209
let mut stack = mem::replace(&mut state.external_stack, Stack::new(&ctx));
203210
drop(state);
204-
let seq = callback.call(ctx, fuel, &mut stack);
211+
let seq = callback.call(ctx, &mut rfuel, &mut stack);
212+
drop(rfuel);
205213
state = self.0.borrow_mut(&ctx);
206214
state.external_stack = stack;
207215

@@ -216,18 +224,27 @@ impl<'gc> Thread<'gc> {
216224
}
217225
}
218226
Frame::Sequence(mut sequence) => {
219-
fuel.consume_fuel(FUEL_PER_SEQ_STEP);
227+
let mut rfuel = match fuel.recurse() {
228+
Ok(r) => r,
229+
Err(err) => {
230+
state.unwind(&ctx, err.into());
231+
continue;
232+
}
233+
};
234+
235+
rfuel.consume_fuel(FUEL_PER_SEQ_STEP);
220236
state.frames.push(Frame::Calling);
221237

222238
let mut stack = mem::replace(&mut state.external_stack, Stack::new(&ctx));
223239
let error = state.error.take();
224240
drop(state);
225241
let fin = if let Some(error) = error {
226242
assert!(stack.is_empty());
227-
sequence.error(ctx, fuel, error, &mut stack)
243+
sequence.error(ctx, &mut rfuel, error, &mut stack)
228244
} else {
229-
sequence.poll(ctx, fuel, &mut stack)
245+
sequence.poll(ctx, &mut rfuel, &mut stack)
230246
};
247+
drop(rfuel);
231248
state = self.0.borrow_mut(&ctx);
232249
state.external_stack = stack;
233250

@@ -414,11 +431,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
414431
}
415432

416433
// Place the current frame's varargs at the given register, expecting the given count
417-
pub(crate) fn varargs(
418-
&mut self,
419-
dest: RegisterIndex,
420-
count: VarCount,
421-
) -> Result<(), ThreadError> {
434+
pub(crate) fn varargs(&mut self, dest: RegisterIndex, count: VarCount) -> Result<(), VMError> {
422435
match self.state.frames.last_mut() {
423436
Some(Frame::Lua {
424437
bottom,
@@ -427,7 +440,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
427440
..
428441
}) => {
429442
if *is_variable {
430-
return Err(ThreadError::ExpectedVariable(false));
443+
return Err(VMError::ExpectedVariableStack(false));
431444
}
432445

433446
let varargs_start = *bottom + 1;
@@ -459,7 +472,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
459472
mc: &Mutation<'gc>,
460473
table_base: RegisterIndex,
461474
count: VarCount,
462-
) -> Result<(), ThreadError> {
475+
) -> Result<(), VMError> {
463476
let Some(&mut Frame::Lua {
464477
base,
465478
ref mut is_variable,
@@ -471,7 +484,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
471484
};
472485

473486
if count.is_variable() != *is_variable {
474-
return Err(ThreadError::ExpectedVariable(count.is_variable()));
487+
return Err(VMError::ExpectedVariableStack(count.is_variable()));
475488
}
476489

477490
let table_ind = base + table_base.0 as usize;
@@ -527,7 +540,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
527540
func: RegisterIndex,
528541
args: VarCount,
529542
returns: VarCount,
530-
) -> Result<(), ThreadError> {
543+
) -> Result<(), VMError> {
531544
match self.state.frames.last_mut() {
532545
Some(Frame::Lua {
533546
expected_return,
@@ -536,7 +549,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
536549
..
537550
}) => {
538551
if *is_variable != args.is_variable() {
539-
return Err(ThreadError::ExpectedVariable(args.is_variable()));
552+
return Err(VMError::ExpectedVariableStack(args.is_variable()));
540553
}
541554

542555
*expected_return = Some(LuaReturn::Normal(returns));
@@ -598,7 +611,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
598611
func: RegisterIndex,
599612
arg_count: u8,
600613
returns: VarCount,
601-
) -> Result<(), ThreadError> {
614+
) -> Result<(), VMError> {
602615
match self.state.frames.last_mut() {
603616
Some(Frame::Lua {
604617
expected_return,
@@ -607,7 +620,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
607620
..
608621
}) => {
609622
if *is_variable {
610-
return Err(ThreadError::ExpectedVariable(false));
623+
return Err(VMError::ExpectedVariableStack(false));
611624
}
612625

613626
consume_call_fuel(self.fuel, arg_count as usize);
@@ -672,7 +685,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
672685
func: Function<'gc>,
673686
args: &[Value<'gc>],
674687
ret_index: Option<RegisterIndex>,
675-
) -> Result<(), ThreadError> {
688+
) -> Result<(), VMError> {
676689
match self.state.frames.last_mut() {
677690
Some(Frame::Lua {
678691
expected_return,
@@ -682,7 +695,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
682695
..
683696
}) => {
684697
if *is_variable {
685-
return Err(ThreadError::ExpectedVariable(false));
698+
return Err(VMError::ExpectedVariableStack(false));
686699
}
687700

688701
consume_call_fuel(self.fuel, args.len());
@@ -738,7 +751,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
738751
ctx: Context<'gc>,
739752
func: RegisterIndex,
740753
args: VarCount,
741-
) -> Result<(), ThreadError> {
754+
) -> Result<(), VMError> {
742755
match self.state.frames.last() {
743756
Some(&Frame::Lua {
744757
bottom,
@@ -747,7 +760,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
747760
..
748761
}) => {
749762
if is_variable != args.is_variable() {
750-
return Err(ThreadError::ExpectedVariable(args.is_variable()));
763+
return Err(VMError::ExpectedVariableStack(args.is_variable()));
751764
}
752765

753766
self.state.close_upvalues(&ctx, bottom);
@@ -814,7 +827,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
814827
mc: &Mutation<'gc>,
815828
start: RegisterIndex,
816829
count: VarCount,
817-
) -> Result<(), ThreadError> {
830+
) -> Result<(), VMError> {
818831
match self.state.frames.pop() {
819832
Some(Frame::Lua {
820833
bottom,
@@ -823,7 +836,7 @@ impl<'gc, 'a> LuaFrame<'gc, 'a> {
823836
..
824837
}) => {
825838
if is_variable != count.is_variable() {
826-
return Err(ThreadError::ExpectedVariable(count.is_variable()));
839+
return Err(VMError::ExpectedVariableStack(count.is_variable()));
827840
}
828841
self.state.close_upvalues(mc, bottom);
829842

src/thread/vm.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
Closure, Constant, Context, Function, RuntimeError, String, Table, Value,
1212
};
1313

14-
use super::{BinaryOperatorError, LuaFrame};
14+
use super::{BinaryOperatorError, LuaFrame, VMError};
1515

1616
// Runs the VM for the given number of instructions or until the current LuaFrame may have been
1717
// changed.
@@ -221,7 +221,7 @@ pub(crate) fn run_vm<'gc>(
221221
for &desc in proto.upvalues.iter() {
222222
match desc {
223223
UpValueDescriptor::Environment => {
224-
panic!("_ENV upvalue is only allowed on top-level closure");
224+
return Err(VMError::BadEnvUpValue.into());
225225
}
226226
UpValueDescriptor::ParentLocal(reg) => {
227227
upvalues.push(registers.open_upvalue(&ctx, reg));

tests/callback.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ fn callback() -> Result<(), StaticError> {
99
let mut lua = Lua::core();
1010

1111
lua.try_run(|ctx| {
12-
let callback = AnyCallback::from_fn(&ctx, |_, _, stack| {
12+
let callback = AnyCallback::from_fn(&ctx, |_, fuel, stack| {
13+
assert_eq!(fuel.recursion_level(), 1);
1314
stack.push_back(Value::Integer(42));
1415
Ok(CallbackReturn::Return)
1516
});
@@ -22,7 +23,9 @@ fn callback() -> Result<(), StaticError> {
2223
ctx,
2324
&br#"
2425
local a, b, c = callback(1, 2)
25-
return a == 1 and b == 2 and c == 42
26+
assert(a == 1 and b == 2 and c == 42)
27+
local d, e, f = callback(3, 4)
28+
assert(d == 3 and e == 4 and f == 42)
2629
"#[..],
2730
)?;
2831

@@ -31,7 +34,7 @@ fn callback() -> Result<(), StaticError> {
3134
Ok(ctx.state.registry.stash(&ctx, thread))
3235
})?;
3336

34-
assert!(lua.run_thread::<bool>(&thread)?);
37+
lua.run_thread::<()>(&thread)?;
3538
Ok(())
3639
}
3740

0 commit comments

Comments
 (0)