Skip to content

Commit 821a8f2

Browse files
committed
Fix require
1 parent 1478e57 commit 821a8f2

File tree

3 files changed

+30
-19
lines changed

3 files changed

+30
-19
lines changed

crates/luars/src/lua_vm/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,8 +2420,7 @@ impl LuaVM {
24202420
}
24212421
Err(LuaError::Yield) => {
24222422
// Yield is not an error - propagate it
2423-
let values = self.take_yield_values();
2424-
Err(self.do_yield(values))
2423+
Err(LuaError::Yield)
24252424
}
24262425
Err(_) => {
24272426
// Real error: clean up frames and return false with error message

crates/luars/src/stdlib/basic.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,8 @@ fn lua_version(vm: &mut LuaVM) -> LuaResult<MultiValue> {
641641

642642
/// require(modname) - Load a module
643643
fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
644-
let modname_str = require_arg(vm, 1, "require")?;
645-
if !modname_str.is_string() {
644+
let modname_value = require_arg(vm, 1, "require")?;
645+
if !modname_value.is_string() {
646646
return Err(vm.error("module name must be a string".to_string()));
647647
}
648648

@@ -660,7 +660,7 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
660660
if let Some(loaded_table) = package_ref.raw_get(&loaded_key) {
661661
if let Some(loaded_id) = loaded_table.as_table_id() {
662662
if let Some(loaded_ref) = vm.object_pool.get_table(loaded_id) {
663-
loaded_ref.raw_get(&modname_str)
663+
loaded_ref.raw_get(&modname_value)
664664
} else {
665665
None
666666
}
@@ -721,7 +721,7 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
721721
// Try each searcher (1-based indexing)
722722
for searcher in searchers_values {
723723
// Call searcher with module name
724-
let (success, results) = vm.protected_call(searcher, vec![modname_str.clone()])?;
724+
let (success, results) = vm.protected_call(searcher, vec![modname_value.clone()])?;
725725

726726
if !success {
727727
let error_msg = results
@@ -745,15 +745,16 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
745745
if first_result.is_function() || first_result.is_cfunction() {
746746
// Call the loader
747747
let loader_args = if results.len() > 1 {
748-
vec![modname_str.clone(), results[1].clone()]
748+
vec![modname_value.clone(), results[1].clone()]
749749
} else {
750-
vec![modname_str.clone()]
750+
vec![modname_value.clone()]
751751
};
752752

753753
let (load_success, load_results) =
754754
vm.protected_call(first_result.clone(), loader_args)?;
755755

756756
if !load_success {
757+
let module_str = vm.value_to_string(&modname_value)?;
757758
let error_msg = load_results
758759
.first()
759760
.and_then(|v| {
@@ -766,7 +767,7 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
766767
.unwrap_or_else(|| "unknown error".to_string());
767768
return Err(vm.error(format!(
768769
"error loading module '{}': {}",
769-
modname_str, error_msg
770+
module_str, error_msg
770771
)));
771772
}
772773

@@ -784,7 +785,7 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
784785
if let Some(loaded_table) = package_ref.raw_get(&loaded_key) {
785786
if let Some(loaded_id) = loaded_table.as_table_id() {
786787
if let Some(loaded_ref) = vm.object_pool.get_table_mut(loaded_id) {
787-
loaded_ref.raw_set(modname_str, module_value.clone());
788+
loaded_ref.raw_set(modname_value, module_value.clone());
788789
}
789790
}
790791
}
@@ -800,14 +801,15 @@ fn lua_require(vm: &mut LuaVM) -> LuaResult<MultiValue> {
800801
}
801802
}
802803
}
804+
let module_str = vm.value_to_string(&modname_value)?;
803805

804806
// All searchers failed
805807
if error_messages.is_empty() {
806-
Err(vm.error(format!("module '{}' not found", modname_str)))
808+
Err(vm.error(format!("module '{}' not found", module_str)))
807809
} else {
808810
Err(vm.error(format!(
809811
"module '{}' not found:{}",
810-
modname_str,
812+
module_str,
811813
error_messages.join("")
812814
)))
813815
}

crates/luars/src/stdlib/package.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,11 @@ fn searcher_lua(vm: &mut LuaVM) -> LuaResult<MultiValue> {
185185
}
186186

187187
// Loader function for Lua files (called by searcher_lua)
188+
// Called as: loader(modname, filepath)
188189
fn lua_file_loader(vm: &mut LuaVM) -> LuaResult<MultiValue> {
189-
let filepath_val = require_arg(vm, 1, "Lua file loader")?;
190+
// First arg is modname, second arg is filepath (passed by searcher)
191+
let _modname_val = require_arg(vm, 1, "Lua file loader")?;
192+
let filepath_val = require_arg(vm, 2, "Lua file loader")?;
190193

191194
let Some(filepath_id) = filepath_val.as_string_id() else {
192195
return Err(vm.error("file path must be a string".to_string()));
@@ -198,21 +201,27 @@ fn lua_file_loader(vm: &mut LuaVM) -> LuaResult<MultiValue> {
198201
s.as_str().to_string()
199202
};
200203

204+
if !std::fs::metadata(&filepath_str).is_ok() {
205+
return Ok(MultiValue::empty())
206+
}
207+
201208
// Read the file
202209
let source = match std::fs::read_to_string(&filepath_str) {
203210
Ok(s) => s,
204-
Err(_) => {
205-
return Ok(MultiValue::single(LuaValue::nil()));
211+
Err(e) => {
212+
return Err(vm.error(format!("cannot open file '{}': {}", filepath_str, e)));
206213
}
207214
};
208215

209216
// Compile it using VM's string pool with chunk name
210217
let chunkname = format!("@{}", filepath_str);
211218
let chunk = vm.compile_with_name(&source, &chunkname)?;
212-
// Create a function from the chunk
213-
let func = vm.create_function(Rc::new(chunk), vec![]);
219+
220+
// Create a function from the chunk with _ENV upvalue
221+
let env_upvalue_id = vm.object_pool.create_upvalue_closed(vm.global_value);
222+
let func = vm.create_function(Rc::new(chunk), vec![env_upvalue_id]);
214223

215-
// Call the function
224+
// Call the function to execute the module
216225
let (success, results) = vm.protected_call(func, vec![])?;
217226

218227
if !success {
@@ -228,7 +237,8 @@ fn lua_file_loader(vm: &mut LuaVM) -> LuaResult<MultiValue> {
228237
)));
229238
}
230239

231-
// Get the result value
240+
// Get the result value - if the module returns a value, use it
241+
// Otherwise return true (standard Lua behavior)
232242
let module_value = if results.is_empty() || results[0].is_nil() {
233243
LuaValue::boolean(true)
234244
} else {

0 commit comments

Comments
 (0)