Skip to content

Commit 41cba02

Browse files
authored
Merge pull request #97 from Peefy/feat-java-plugin
feat: impl kcl java plugin
2 parents 304558d + b2f0258 commit 41cba02

File tree

9 files changed

+210
-10
lines changed

9 files changed

+210
-10
lines changed

java/Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ fmt:
77
mvn formatter:format
88

99
pkg:
10-
mvn clean package
10+
mvn clean package -Dcargo-build.profile=release
1111

1212
deploy:
13-
mvn clean deploy
13+
mvn clean deploy -Dcargo-build.profile=release
1414

1515
test:
16-
mvn clean test
16+
mvn clean test -Dcargo-build.profile=release

java/src/lib.rs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,25 @@ extern crate once_cell;
88
extern crate prost;
99

1010
use anyhow::Result;
11-
use jni::objects::{JByteArray, JClass, JObject};
11+
use jni::objects::{GlobalRef, JByteArray, JClass, JObject, JString};
1212
use jni::sys::jbyteArray;
1313
use jni::JNIEnv;
14-
use kclvm_api::call;
14+
use jni::JavaVM;
15+
use kclvm_api::call_with_plugin_agent;
1516
use kclvm_api::gpyrpc::LoadPackageArgs;
1617
use kclvm_api::service::KclvmServiceImpl;
1718
use kclvm_parser::KCLModuleCache;
1819
use kclvm_sema::resolver::scope::KCLScopeCache;
1920
use lazy_static::lazy_static;
2021
use once_cell::sync::OnceCell;
2122
use prost::Message;
23+
use std::ffi::{CStr, CString};
24+
use std::os::raw::c_char;
2225
use std::sync::Mutex;
2326

2427
lazy_static! {
28+
static ref JVM: Mutex<Option<JavaVM>> = Mutex::new(None);
29+
static ref CALLBACK_OBJ: Mutex<Option<GlobalRef>> = Mutex::new(None);
2530
static ref MODULE_CACHE: Mutex<OnceCell<KCLModuleCache>> = Mutex::new(OnceCell::new());
2631
static ref SCOPE_CACHE: Mutex<OnceCell<KCLScopeCache>> = Mutex::new(OnceCell::new());
2732
}
@@ -33,12 +38,20 @@ pub extern "system" fn Java_com_kcl_api_API_callNative(
3338
name: JByteArray,
3439
args: JByteArray,
3540
) -> jbyteArray {
36-
intern_call_native(&mut env, name, args).unwrap_or_else(|e| {
41+
intern_call_native_with_plugin(&mut env, name, args).unwrap_or_else(|e| {
3742
let _ = throw(&mut env, e);
3843
JObject::default().into_raw()
3944
})
4045
}
4146

47+
#[no_mangle]
48+
pub extern "system" fn Java_com_kcl_api_API_registerPluginContext(env: JNIEnv, obj: JObject) {
49+
let jvm = env.get_java_vm().unwrap();
50+
*JVM.lock().unwrap() = Some(jvm);
51+
let global_ref = env.new_global_ref(obj).unwrap();
52+
*CALLBACK_OBJ.lock().unwrap() = Some(global_ref);
53+
}
54+
4255
#[no_mangle]
4356
pub extern "system" fn Java_com_kcl_api_API_loadPackageWithCache(
4457
mut env: JNIEnv,
@@ -51,10 +64,14 @@ pub extern "system" fn Java_com_kcl_api_API_loadPackageWithCache(
5164
})
5265
}
5366

54-
fn intern_call_native(env: &mut JNIEnv, name: JByteArray, args: JByteArray) -> Result<jbyteArray> {
67+
fn intern_call_native_with_plugin(
68+
env: &mut JNIEnv,
69+
name: JByteArray,
70+
args: JByteArray,
71+
) -> Result<jbyteArray> {
5572
let name = env.convert_byte_array(name)?;
5673
let args = env.convert_byte_array(args)?;
57-
let result = call(&name, &args)?;
74+
let result = call_with_plugin_agent(&name, &args, plugin_agent as u64)?;
5875
let j_byte_array = env.byte_array_from_slice(&result)?;
5976
Ok(j_byte_array.into_raw())
6077
}
@@ -78,6 +95,47 @@ fn intern_load_package_with_cache(env: &mut JNIEnv, args: JByteArray) -> Result<
7895
Ok(j_byte_array.into_raw())
7996
}
8097

98+
#[no_mangle]
99+
extern "C" fn plugin_agent(
100+
method: *const c_char,
101+
args: *const c_char,
102+
kwargs: *const c_char,
103+
) -> *const c_char {
104+
let jvm = JVM.lock().unwrap();
105+
let jvm = jvm.as_ref().unwrap();
106+
let mut env = jvm.attach_current_thread().unwrap();
107+
108+
let callback_obj = CALLBACK_OBJ.lock().unwrap();
109+
let callback_obj = callback_obj.as_ref().unwrap();
110+
111+
let method = unsafe {
112+
env.new_string(CStr::from_ptr(method).to_string_lossy().into_owned())
113+
.expect("Failed to create Java string")
114+
};
115+
let args = unsafe {
116+
env.new_string(CStr::from_ptr(args).to_string_lossy().into_owned())
117+
.expect("Failed to create Java string")
118+
};
119+
let kwargs = unsafe {
120+
env.new_string(CStr::from_ptr(kwargs).to_string_lossy().into_owned())
121+
.expect("Failed to create Java string")
122+
};
123+
let params = &[(&method).into(), (&args).into(), (&kwargs).into()];
124+
let result = env
125+
.call_method(
126+
callback_obj,
127+
"callMethod",
128+
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;",
129+
params,
130+
)
131+
.unwrap();
132+
let result: JString = result.l().unwrap().into();
133+
let result: String = env.get_string(&result).unwrap().into();
134+
CString::new(result)
135+
.expect("Failed to create CString")
136+
.into_raw()
137+
}
138+
81139
fn throw(env: &mut JNIEnv, error: anyhow::Error) -> jni::errors::Result<()> {
82140
env.throw(("java/lang/Exception", error.to_string()))
83141
}

java/src/main/java/com/kcl/api/API.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import java.io.IOException;
55
import java.io.InputStream;
66
import java.io.UncheckedIOException;
7+
import java.util.Map;
78
import java.nio.file.Files;
89
import java.nio.file.StandardCopyOption;
910

1011
import com.kcl.api.Spec.*;
12+
import com.kcl.plugin.MethodFunction;
13+
import com.kcl.plugin.PluginContext;
1114

1215
public class API implements Service {
1316
static String LIB_NAME = "kcl_lib_jni";
@@ -30,7 +33,7 @@ private static void doLoadLibrary() throws IOException {
3033
System.loadLibrary(LIB_NAME);
3134
return;
3235
} catch (UnsatisfiedLinkError ignore) {
33-
// ignore - try to find native libraries from classpath
36+
// ignore - try to find native libraries from class path
3437
}
3538
doLoadBundledLibrary();
3639
}
@@ -59,7 +62,21 @@ private static String bundledLibraryPath() {
5962

6063
private native byte[] loadPackageWithCache(byte[] args);
6164

65+
private native void registerPluginContext(PluginContext ctx);
66+
67+
private static PluginContext pluginContext = new PluginContext();
68+
private static byte[] buffer = new byte[1024];
69+
70+
public static void registerPlugin(String name, Map<String, MethodFunction> methodMap) {
71+
pluginContext.registerPlugin(name, methodMap);
72+
}
73+
74+
private String callMethod(String method, String argsJson, String kwArgsJson) {
75+
return pluginContext.callMethod(method, argsJson, kwArgsJson);
76+
}
77+
6278
public API() {
79+
registerPluginContext(pluginContext);
6380
}
6481

6582
/**
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.kcl.plugin;
2+
3+
import java.util.Map;
4+
5+
@FunctionalInterface
6+
public interface MethodFunction {
7+
Object invoke(Object[] args, Map<String, Object> kwArgs);
8+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.kcl.plugin;
2+
3+
import java.util.Map;
4+
5+
public class Plugin {
6+
public String name;
7+
public Map<String, MethodFunction> methodMap;
8+
9+
public Map<String, MethodFunction> getMethodMap() {
10+
return methodMap;
11+
}
12+
13+
public Plugin(String name, Map<String, MethodFunction> methodMap) {
14+
this.name = name;
15+
this.methodMap = methodMap;
16+
}
17+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.kcl.plugin;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import java.util.HashMap;
6+
import java.util.Map;
7+
8+
public class PluginContext {
9+
private Map<String, Plugin> pluginMap = new HashMap<>();
10+
private ObjectMapper objectMapper = new ObjectMapper();
11+
12+
public String callMethod(String name, String argsJson, String kwArgsJson) {
13+
return callJavaMethod(name, argsJson, kwArgsJson);
14+
}
15+
16+
private String callJavaMethod(String name, String argsJson, String kwArgsJson) {
17+
try {
18+
return callJavaMethodUnsafe(name, argsJson, kwArgsJson);
19+
} catch (Exception e) {
20+
Map<String, Object> panicInfo = new HashMap<>();
21+
panicInfo.put("__kcl_PanicInfo__", true);
22+
panicInfo.put("message", e.getMessage());
23+
return convertToJson(panicInfo);
24+
}
25+
}
26+
27+
private String callJavaMethodUnsafe(String name, String argsJson, String kwArgsJson) {
28+
int dotIdx = name.lastIndexOf(".");
29+
if (dotIdx < 0) {
30+
return "";
31+
}
32+
String modulePath = name.substring(0, dotIdx);
33+
String methodName = name.substring(dotIdx + 1);
34+
String pluginName = modulePath.substring(modulePath.lastIndexOf(".") + 1);
35+
36+
MethodFunction methodFunc = this.pluginMap.getOrDefault(pluginName, new Plugin("", new HashMap<>()))
37+
.getMethodMap().get(methodName);
38+
39+
Object[] args = convertFromJson(argsJson, Object[].class);
40+
Map<String, Object> kwArgs = convertFromJson(kwArgsJson, HashMap.class);
41+
42+
Object result = null;
43+
if (methodFunc != null) {
44+
result = methodFunc.invoke(args, kwArgs);
45+
}
46+
47+
return convertToJson(result);
48+
}
49+
50+
public void registerPlugin(String name, Map<String, MethodFunction> methodMap) {
51+
this.pluginMap.put(name, new Plugin(name, methodMap));
52+
}
53+
54+
private String convertToJson(Object object) {
55+
try {
56+
return objectMapper.writeValueAsString(object);
57+
} catch (JsonProcessingException e) {
58+
e.printStackTrace();
59+
return "";
60+
}
61+
}
62+
63+
private <T> T convertFromJson(String json, Class<T> type) {
64+
try {
65+
return objectMapper.readValue(json, type);
66+
} catch (JsonProcessingException e) {
67+
e.printStackTrace();
68+
return null;
69+
}
70+
}
71+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.kcl;
2+
3+
import com.kcl.api.API;
4+
import com.kcl.api.Spec.ExecProgram_Args;
5+
import com.kcl.api.Spec.ExecProgram_Result;
6+
7+
import java.util.Collections;
8+
9+
import org.junit.Assert;
10+
import org.junit.Test;
11+
12+
public class PluginTest {
13+
@Test
14+
public void testExecProgramWithPlugin() throws Exception {
15+
// API instance
16+
API.registerPlugin("my_plugin", Collections.singletonMap("add", (args, kwArgs) -> {
17+
return (int) args[0] + (int) args[1];
18+
}));
19+
API api = new API();
20+
21+
ExecProgram_Result result = api
22+
.execProgram(ExecProgram_Args.newBuilder().addKFilenameList("./src/test_data/plugin/plugin.k").build());
23+
24+
Assert.assertEquals(result.getYamlResult(), "result: 2");
25+
}
26+
}

java/src/test_data/plugin/plugin.k

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import kcl_plugin.my_plugin
2+
3+
result = my_plugin.add(1, 1)

python/kcl_lib/plugin/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _call_py_method(self, name: str, args_json: str, kwargs_json: str) -> str:
2121
try:
2222
return self._call_py_method_unsafe(name, args_json, kwargs_json)
2323
except Exception as e:
24-
return json.dumps({"__kcl_PanicInfo__": f"{e}"})
24+
return json.dumps({"__kcl_PanicInfo__": True, "message": f"{e}"})
2525

2626
def _call_py_method_unsafe(
2727
self, name: str, args_json: str, kwargs_json: str

0 commit comments

Comments
 (0)