Skip to content

Commit 9479a58

Browse files
committed
allow dot notation for RustPython
1 parent d890eec commit 9479a58

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,13 @@ pub fn init_and_register<R: Runtime>(python_functions: Vec<&'static str>) -> Tau
118118
let functions = py_lib::read_variable(StringRequest {
119119
value: "_tauri_plugin_functions".into(),
120120
})
121-
.unwrap_or_default().replace("'","\"");
121+
.unwrap_or_default()
122+
.replace("'", "\""); // python arrays are serialized usings ' instead of "
122123

123124
// dbg!(&functions);
124125
if let Ok(python_functions) = serde_json::from_str::<Vec<String>>(&functions) {
125126
for function_name in python_functions {
126-
py_lib::register_function_str(function_name.into(), None).unwrap();
127+
py_lib::register_function_str(function_name, None).unwrap();
127128
}
128129
}
129130
Ok(())

src/py_lib.rs

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,19 @@ pub fn register_function(payload: RegisterRequest) -> crate::Result<()> {
5454
register_function_str(payload.python_function_call, payload.number_of_args)
5555
}
5656

57-
pub fn register_function_str(fn_name: String, number_of_args: Option<u8>) -> crate::Result<()> {
57+
pub fn register_function_str(
58+
function_name: String,
59+
number_of_args: Option<u8>,
60+
) -> crate::Result<()> {
5861
if INIT_BLOCKED.load(std::sync::atomic::Ordering::Relaxed) {
5962
return Err("Cannot register after function called".into());
6063
}
6164
rustpython_vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
62-
GLOBALS
63-
.globals
64-
.get_item(&fn_name, vm)
65-
.expect(&format!("Function {fn_name} not found"));
65+
let var_dot_split: Vec<&str> = function_name.split(".").collect();
66+
let func = GLOBALS.globals.get_item(var_dot_split[0], vm)?;
67+
if var_dot_split.len() > 1 {
68+
func.get_item(var_dot_split[1], vm)?;
69+
}
6670

6771
if let Some(num_args) = number_of_args {
6872
let py_analyze_sig = format!(
@@ -71,7 +75,7 @@ from inspect import signature
7175
if len(signature({}).parameters) != {}:
7276
raise Exception("Function parameters don't match in 'registerFunction'")
7377
"#,
74-
fn_name, num_args
78+
function_name, num_args
7579
);
7680

7781
let code_obj = vm
@@ -81,12 +85,13 @@ if len(signature({}).parameters) != {}:
8185
"<embedded>".to_owned(),
8286
)
8387
.map_err(|err| vm.new_syntax_error(&err, Some(&py_analyze_sig)))?;
84-
vm.run_code_obj(code_obj, GLOBALS.clone()).expect(&format!(
85-
"Number of args doesn't match signature of {fn_name}."
86-
));
88+
vm.run_code_obj(code_obj, GLOBALS.clone())
89+
.unwrap_or_else(|_| {
90+
panic!("Number of args doesn't match signature of {function_name}.")
91+
});
8792
}
88-
// dbg!(format!("Added '{fn_name}'"));
89-
FUNCTION_MAP.lock().unwrap().insert(fn_name);
93+
// dbg!(format!("Added '{function_name}'"));
94+
FUNCTION_MAP.lock().unwrap().insert(function_name);
9095
Ok(())
9196
})
9297
}
@@ -104,23 +109,29 @@ pub fn call_function(payload: RunRequest) -> crate::Result<String> {
104109
.into_iter()
105110
.map(|value| py_serde::deserialize(vm, value).unwrap())
106111
.collect();
107-
let res = GLOBALS
108-
.globals
109-
.get_item(&function_name, vm)?
110-
.call(posargs, vm)?
111-
.str(vm)?
112-
.to_string();
113-
Ok(res)
112+
let var_dot_split: Vec<&str> = function_name.split(".").collect();
113+
let func = GLOBALS.globals.get_item(var_dot_split[0], vm)?;
114+
Ok(if var_dot_split.len() > 1 {
115+
func.get_item(var_dot_split[1], vm)?
116+
} else {
117+
func
118+
}
119+
.call(posargs, vm)?
120+
.str(vm)?
121+
.to_string())
114122
})
115123
}
116124

117125
pub fn read_variable(payload: StringRequest) -> crate::Result<String> {
118126
rustpython_vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
119-
let res = GLOBALS
120-
.globals
121-
.get_item(&payload.value, vm)?
122-
.str(vm)?
123-
.to_string();
124-
Ok(res)
127+
let var_dot_split: Vec<&str> = payload.value.split(".").collect();
128+
let var = GLOBALS.globals.get_item(var_dot_split[0], vm)?;
129+
Ok(if var_dot_split.len() > 1 {
130+
var.get_item(var_dot_split[1], vm)?
131+
} else {
132+
var
133+
}
134+
.str(vm)?
135+
.to_string())
125136
})
126137
}

0 commit comments

Comments
 (0)