Skip to content

Commit 616831e

Browse files
committed
optimize coroutine
1 parent e56f1bb commit 616831e

File tree

1 file changed

+93
-126
lines changed

1 file changed

+93
-126
lines changed

crates/luars/src/lua_vm/mod.rs

Lines changed: 93 additions & 126 deletions
Original file line number | Diff line number | Diff line change
@@ -390,17 +390,24 @@ impl LuaVM {
390390

391391
// ============ Frame Management (Lua 5.4 style) ============
392392
// Uses pre-allocated Vec for O(1) operations
393-
// Key optimization: Vec is pre-filled to MAX_CALL_DEPTH, so direct index access
393+
// Key optimization:
394+
// - Main VM: pre-filled to MAX_CALL_DEPTH for direct index access
395+
// - Coroutines: start small and grow on demand (like Lua's linked list CallInfo)
394396

395397
/// Push a new frame onto the call stack and return stable pointer
396-
/// ULTRA-OPTIMIZED: Direct index write to pre-filled Vec
398+
/// OPTIMIZED: Direct index write when capacity allows, grow on demand otherwise
397399
#[inline(always)]
398400
pub(crate) fn push_frame(&mut self, frame: LuaCallFrame) -> *mut LuaCallFrame {
399401
let idx = self.frame_count;
400402
debug_assert!(idx < MAX_CALL_DEPTH, "call stack overflow");
401403

402-
// Direct write - Vec is pre-filled to MAX_CALL_DEPTH
403-
self.frames[idx] = frame;
404+
// Fast path: direct write if Vec is pre-filled or has space
405+
if idx < self.frames.len() {
406+
self.frames[idx] = frame;
407+
} else {
408+
// Slow path: grow the Vec (for coroutines with on-demand allocation)
409+
self.frames.push(frame);
410+
}
404411
self.frame_count = idx + 1;
405412
&mut self.frames[idx] as *mut LuaCallFrame
406413
}
@@ -537,12 +544,15 @@ impl LuaVM {
537544
}
538545

539546
// ============ Coroutine Support ============
547+
548+
/// Initial call depth for coroutines (grows on demand, like Lua's linked list CallInfo)
549+
const INITIAL_COROUTINE_CALL_DEPTH: usize = 8;
540550

541551
/// Create a new thread (coroutine) - returns ThreadId-based LuaValue
542552
pub fn create_thread_value(&mut self, func: LuaValue) -> LuaValue {
543-
// Pre-allocate frames like the main VM does
544-
let mut frames = Vec::with_capacity(MAX_CALL_DEPTH);
545-
frames.resize_with(MAX_CALL_DEPTH, LuaCallFrame::default);
553+
// Only allocate capacity, don't pre-fill (unlike main VM)
554+
// Coroutines typically have shallow call stacks, so we grow on demand
555+
let frames = Vec::with_capacity(Self::INITIAL_COROUTINE_CALL_DEPTH);
546556

547557
let mut thread = LuaThread {
548558
status: CoroutineStatus::Suspended,
@@ -572,9 +582,9 @@ impl LuaVM {
572582
/// Create a new thread (coroutine) - legacy version returning Rc<RefCell<>>
573583
/// This is still needed for internal VM state tracking (current_thread)
574584
pub fn create_thread(&mut self, func: LuaValue) -> Rc<RefCell<LuaThread>> {
575-
// Pre-allocate frames like the main VM does
576-
let mut frames = Vec::with_capacity(MAX_CALL_DEPTH);
577-
frames.resize_with(MAX_CALL_DEPTH, LuaCallFrame::default);
585+
// Only allocate capacity, don't pre-fill (unlike main VM)
586+
// Coroutines typically have shallow call stacks, so we grow on demand
587+
let frames = Vec::with_capacity(Self::INITIAL_COROUTINE_CALL_DEPTH);
578588

579589
let thread = LuaThread {
580590
status: CoroutineStatus::Suspended,
@@ -602,6 +612,7 @@ impl LuaVM {
602612
}
603613

604614
/// Resume a coroutine using ThreadId-based LuaValue
615+
/// OPTIMIZED: Uses swap instead of take to avoid repeated allocations
605616
pub fn resume_thread(
606617
&mut self,
607618
thread_val: LuaValue,
@@ -636,39 +647,29 @@ impl LuaVM {
636647
_ => {}
637648
}
638649

639-
// Save current VM state
640-
let saved_frames = std::mem::take(&mut self.frames);
641-
let saved_frame_count = self.frame_count;
642-
self.frame_count = 0; // frames is now empty
643-
let saved_stack = std::mem::take(&mut self.register_stack);
644-
let saved_returns = std::mem::take(&mut self.return_values);
645-
let saved_upvalues = std::mem::take(&mut self.open_upvalues);
646-
let saved_frame_id = self.next_frame_id;
647-
let saved_thread = self.current_thread.take();
648-
let saved_thread_id = self.current_thread_id.take();
649-
650-
// Get thread state and check if first resume
651-
let is_first_resume = {
652-
let Some(thread) = self.object_pool.get_thread(thread_id) else {
653-
return Err(self.error("invalid thread".to_string()));
654-
};
655-
thread.frame_count == 0 // Use frame_count instead of frames.is_empty()
656-
};
657-
658-
// Load thread state into VM
650+
// OPTIMIZED: Use swap to exchange state with thread
651+
// This avoids allocation/deallocation overhead of take
652+
let is_first_resume;
659653
{
660654
let Some(thread) = self.object_pool.get_thread_mut(thread_id) else {
661655
return Err(self.error("invalid thread".to_string()));
662656
};
657+
658+
is_first_resume = thread.frame_count == 0;
663659
thread.status = CoroutineStatus::Running;
664-
self.frames = std::mem::take(&mut thread.frames);
665-
self.frame_count = thread.frame_count; // Use thread's frame_count
666-
self.register_stack = std::mem::take(&mut thread.register_stack);
667-
self.return_values = std::mem::take(&mut thread.return_values);
668-
self.open_upvalues = std::mem::take(&mut thread.open_upvalues);
669-
self.next_frame_id = thread.next_frame_id;
660+
661+
// Swap state between VM and thread (O(1) pointer swaps)
662+
std::mem::swap(&mut self.frames, &mut thread.frames);
663+
std::mem::swap(&mut self.register_stack, &mut thread.register_stack);
664+
std::mem::swap(&mut self.return_values, &mut thread.return_values);
665+
std::mem::swap(&mut self.open_upvalues, &mut thread.open_upvalues);
666+
std::mem::swap(&mut self.frame_count, &mut thread.frame_count);
667+
std::mem::swap(&mut self.next_frame_id, &mut thread.next_frame_id);
670668
}
671669

670+
// Save thread tracking info
671+
let saved_thread = self.current_thread.take();
672+
let saved_thread_id = self.current_thread_id.take();
672673
self.current_thread_id = Some(thread_id);
673674
self.current_thread_value = Some(thread_val.clone());
674675

@@ -683,15 +684,13 @@ impl LuaVM {
683684
match self.call_function_internal(func, args) {
684685
Ok(values) => Ok(values),
685686
Err(LuaError::Yield) => {
686-
// Function yielded - this is expected
687687
let values = self.take_yield_values();
688688
Ok(values)
689689
}
690690
Err(e) => Err(e),
691691
}
692692
} else {
693-
// Resumed from yield:
694-
// Use saved CALL instruction info to properly store return values
693+
// Resumed from yield: handle return values
695694
let (call_reg, call_nret) = {
696695
let Some(thread) = self.object_pool.get_thread(thread_id) else {
697696
return Err(self.error("invalid thread".to_string()));
@@ -700,31 +699,30 @@ impl LuaVM {
700699
};
701700

702701
if let (Some(a), Some(num_expected)) = (call_reg, call_nret) {
703-
let frame = &self.frames[self.frame_count - 1];
704-
let base_ptr = frame.base_ptr;
705-
let top = frame.top;
702+
if self.frame_count > 0 {
703+
let frame = &self.frames[self.frame_count - 1];
704+
let base_ptr = frame.base_ptr;
705+
let top = frame.top;
706706

707-
// Store resume args as return values of the yield call
708-
let num_returns = args.len();
709-
let n = if num_expected == usize::MAX {
710-
num_returns
711-
} else {
712-
num_expected.min(num_returns)
713-
};
707+
let num_returns = args.len();
708+
let n = if num_expected == usize::MAX {
709+
num_returns
710+
} else {
711+
num_expected.min(num_returns)
712+
};
714713

715-
for (i, value) in args.iter().take(n).enumerate() {
716-
if base_ptr + a + i < self.register_stack.len() && a + i < top {
717-
self.register_stack[base_ptr + a + i] = value.clone();
714+
for (i, value) in args.iter().take(n).enumerate() {
715+
if base_ptr + a + i < self.register_stack.len() && a + i < top {
716+
self.register_stack[base_ptr + a + i] = *value;
717+
}
718718
}
719-
}
720-
// Fill remaining expected registers with nil
721-
for i in num_returns..num_expected.min(top - a) {
722-
if base_ptr + a + i < self.register_stack.len() {
723-
self.register_stack[base_ptr + a + i] = LuaValue::nil();
719+
for i in num_returns..num_expected.min(top.saturating_sub(a)) {
720+
if base_ptr + a + i < self.register_stack.len() {
721+
self.register_stack[base_ptr + a + i] = LuaValue::nil();
722+
}
724723
}
725724
}
726725

727-
// Clear the saved info
728726
if let Some(thread) = self.object_pool.get_thread_mut(thread_id) {
729727
thread.yield_call_reg = None;
730728
thread.yield_call_nret = None;
@@ -733,89 +731,58 @@ impl LuaVM {
733731

734732
self.return_values = args;
735733

736-
// Continue execution from where it yielded
737734
match self.run() {
738-
Ok(_) => {
739-
// Normal completion - return the stored return values
740-
Ok(self.return_values.clone())
741-
}
742-
Err(LuaError::Yield) => {
743-
// Yield happened - this is expected, get the yield values
744-
Ok(self.take_yield_values())
745-
}
735+
Ok(_) => Ok(self.return_values.clone()),
736+
Err(LuaError::Yield) => Ok(self.take_yield_values()),
746737
Err(e) => Err(e),
747738
}
748739
};
749740

750-
// Check if thread yielded by examining the result
751-
let did_yield = match &result {
752-
Ok(_) if !self.frames_is_empty() => {
753-
// If frames are not empty after execution, it means we yielded
754-
true
755-
}
756-
_ => false,
757-
};
741+
// Check if thread yielded
742+
let did_yield = matches!(&result, Ok(_) if self.frame_count > 0);
758743

759-
// Save thread state back
760-
let final_result = if did_yield {
761-
// Thread yielded - save state and return yield values
762-
let Some(thread) = self.object_pool.get_thread_mut(thread_id) else {
763-
return Err(self.error("invalid thread".to_string()));
764-
};
765-
thread.frames = std::mem::take(&mut self.frames);
766-
thread.frame_count = self.frame_count; // Save frame_count to thread
767-
self.frame_count = 0; // Reset VM frame count
768-
thread.register_stack = std::mem::take(&mut self.register_stack);
769-
thread.return_values = std::mem::take(&mut self.return_values);
770-
thread.open_upvalues = std::mem::take(&mut self.open_upvalues);
771-
thread.next_frame_id = self.next_frame_id;
772-
thread.status = CoroutineStatus::Suspended;
773-
774-
let values = thread.yield_values.clone();
775-
thread.yield_values.clear();
776-
777-
Ok((true, values))
778-
} else {
779-
// Thread completed or error
744+
// Swap state back to thread
745+
let final_result = {
780746
let Some(thread) = self.object_pool.get_thread_mut(thread_id) else {
781747
return Err(self.error("invalid thread".to_string()));
782748
};
783-
thread.frames = std::mem::take(&mut self.frames);
784-
thread.frame_count = self.frame_count; // Save frame_count to thread
785-
self.frame_count = 0; // Reset VM frame count
786-
thread.register_stack = std::mem::take(&mut self.register_stack);
787-
thread.return_values = std::mem::take(&mut self.return_values);
788-
thread.open_upvalues = std::mem::take(&mut self.open_upvalues);
789-
thread.next_frame_id = self.next_frame_id;
790-
791-
match result {
792-
Ok(values) => {
793-
thread.status = CoroutineStatus::Dead;
794-
Ok((true, values))
795-
}
796-
Err(LuaError::Exit) => {
797-
// Normal exit - coroutine finished successfully
798-
thread.status = CoroutineStatus::Dead;
799-
Ok((true, thread.return_values.clone()))
800-
}
801-
Err(_) => {
802-
thread.status = CoroutineStatus::Dead;
803-
let error_msg = self.get_error_message().to_string();
804-
Ok((false, vec![self.create_string(&error_msg)]))
749+
750+
// Swap back (O(1) pointer swaps)
751+
std::mem::swap(&mut self.frames, &mut thread.frames);
752+
std::mem::swap(&mut self.register_stack, &mut thread.register_stack);
753+
std::mem::swap(&mut self.return_values, &mut thread.return_values);
754+
std::mem::swap(&mut self.open_upvalues, &mut thread.open_upvalues);
755+
std::mem::swap(&mut self.frame_count, &mut thread.frame_count);
756+
std::mem::swap(&mut self.next_frame_id, &mut thread.next_frame_id);
757+
758+
if did_yield {
759+
thread.status = CoroutineStatus::Suspended;
760+
let values = thread.yield_values.clone();
761+
thread.yield_values.clear();
762+
Ok((true, values))
763+
} else {
764+
match result {
765+
Ok(values) => {
766+
thread.status = CoroutineStatus::Dead;
767+
Ok((true, values))
768+
}
769+
Err(LuaError::Exit) => {
770+
thread.status = CoroutineStatus::Dead;
771+
Ok((true, thread.return_values.clone()))
772+
}
773+
Err(_) => {
774+
thread.status = CoroutineStatus::Dead;
775+
let error_msg = self.get_error_message().to_string();
776+
Ok((false, vec![self.create_string(&error_msg)]))
777+
}
805778
}
806779
}
807780
};
808781

809-
// Restore VM state
810-
self.frames = saved_frames;
811-
self.frame_count = saved_frame_count; // CRITICAL: restore frame_count
812-
self.register_stack = saved_stack;
813-
self.return_values = saved_returns;
814-
self.open_upvalues = saved_upvalues;
815-
self.next_frame_id = saved_frame_id;
782+
// Restore thread tracking
816783
self.current_thread = saved_thread;
817784
self.current_thread_id = saved_thread_id;
818-
self.current_thread_value = None; // Clear after resume completes
785+
self.current_thread_value = None;
819786

820787
final_result
821788
}

0 commit comments

Comments
 (0)