diff --git a/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 22a3af0d39..cb7cd5e407 100644 --- a/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -66,7 +66,8 @@ public OnlineRecognizerResult getResult(OnlineStream s) { String text = (String) arr[0]; String[] tokens = (String[]) arr[1]; float[] timestamps = (float[]) arr[2]; - return new OnlineRecognizerResult(text, tokens, timestamps); + float[] ysProbs = (float[]) arr[3]; + return new OnlineRecognizerResult(text, tokens, timestamps, ysProbs); } diff --git a/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java b/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java index 65e15a95bb..980d299fb9 100644 --- a/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java +++ b/sherpa-onnx/java-api/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java @@ -6,11 +6,13 @@ public class OnlineRecognizerResult { private final String text; private final String[] tokens; private final float[] timestamps; + private final float[] ysProbs; - public OnlineRecognizerResult(String text, String[] tokens, float[] timestamps) { + public OnlineRecognizerResult(String text, String[] tokens, float[] timestamps, float[] ysProbs) { this.text = text; this.tokens = tokens; this.timestamps = timestamps; + this.ysProbs = ysProbs; } public String getText() { @@ -24,4 +26,9 @@ public String[] getTokens() { public float[] getTimestamps() { return timestamps; } + + public float[] getYsProbs() { + return ysProbs; + } + } diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index 2780b9c57c..b518ffe8e2 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -384,8 +384,9 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env, // [0]: text, jstring // [1]: tokens, array of jstring // [2]: timestamps, array of float + // [3]: ys_probs, array of float jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( - 3, env->FindClass("java/lang/Object"), nullptr); + 4, env->FindClass("java/lang/Object"), nullptr); jstring text = env->NewStringUTF(result.text.c_str()); env->SetObjectArrayElement(obj_arr, 0, text); @@ -408,5 +409,10 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env, env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + jfloatArray ys_probs_arr = env->NewFloatArray(result.ys_probs.size()); + env->SetFloatArrayRegion(ys_probs_arr, 0, result.ys_probs.size(), + result.ys_probs.data()); + env->SetObjectArrayElement(obj_arr, 3, ys_probs_arr); + return obj_arr; } diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index f12b11c014..b31f551386 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -83,6 +83,7 @@ data class OnlineRecognizerResult( val text: String, val tokens: Array, val timestamps: FloatArray, + val ysProbs: FloatArray, // TODO(fangjun): Add more fields ) @@ -124,8 +125,9 @@ class OnlineRecognizer( val text = objArray[0] as String val tokens = objArray[1] as Array val timestamps = objArray[2] as FloatArray + val ysProbs = objArray[3] as FloatArray - return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps) + return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps, ysProbs = ysProbs) } private external fun delete(ptr: Long)