Skip to content

Commit 4d71a34

Browse files
committed
Proper embedding tokenization
1 parent d54cef9 commit 4d71a34

File tree

8 files changed

+207
-122
lines changed

8 files changed

+207
-122
lines changed

.github/workflows/release_on_push.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
run: |
3535
mkdir -p mabl/src/main/assets
3636
curl -L -f -o mabl/src/main/assets/minilm-l6-v2-qint8-arm64.onnx "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_qint8_arm64.onnx?download=true"
37+
curl -L -f -o mabl/src/main/assets/minilm-l6-v2-tokenizer.json "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json?download=true"
3738
3839
- name: Generate version for build
3940
run: |

build.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
if ! [ -f mabl/src/main/assets/minilm-l6-v2-qint8-arm64.onnx ]; then
22
curl -L -o mabl/src/main/assets/minilm-l6-v2-qint8-arm64.onnx https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_qint8_arm64.onnx?download=true
33
fi
4+
if ! [ -f mabl/src/main/assets/minilm-l6-v2-tokenizer.json ]; then
5+
curl -L -o mabl/src/main/assets/minilm-l6-v2-tokenizer.json https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json?download=true
6+
fi
47
./gradlew :plugins:demo:installDebug :plugins:aipinsystem:installDebug :plugins:system:installDebug :plugins:openai:installDebug :plugins:googlesearch:installDebug :mabl:installAipinDebug
58
adb shell pm grant com.penumbraos.mabl.pin android.permission.CAMERA
69
adb shell appops set com.penumbraos.mabl.pin MANAGE_EXTERNAL_STORAGE allow

gradle/libs.versions.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ktor-client = "3.0.0"
2121
kotlinx-serialization = "1.7.1"
2222
kotlinx-coroutines = "1.8.1"
2323
onnx-runtime = "1.20.0"
24+
sentence-embeddings = "v6"
2425
room = "2.7.2"
2526
jsoup = "1.17.2"
2627
# The first number needs to match the Kotlin version
@@ -64,6 +65,7 @@ ktor-serialization-kotlinx-json = { group = "io.ktor", name = "ktor-serializatio
6465
kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "kotlinx-serialization" }
6566
kotlinx-coroutines-android = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-android", version.ref = "kotlinx-coroutines" }
6667
onnx-runtime-android = { group = "com.microsoft.onnxruntime", name = "onnxruntime-android", version.ref = "onnx-runtime" }
68+
sentence-embeddings = { group = "io.gitlab.shubham0204", name = "sentence-embeddings", version.ref = "sentence-embeddings" }
6769
androidx-room-runtime = { group = "androidx.room", name = "room-runtime", version.ref = "room" }
6870
androidx-room-ktx = { group = "androidx.room", name = "room-ktx", version.ref = "room" }
6971
androidx-room-compiler = { group = "androidx.room", name = "room-compiler", version.ref = "room" }

mabl/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ dependencies {
8686
implementation(libs.androidx.camera.camera2)
8787

8888
implementation(libs.kotlinx.serialization.json)
89+
8990
implementation(libs.onnx.runtime.android)
91+
implementation(libs.sentence.embeddings)
9092

9193
implementation(libs.androidx.core.ktx)
9294
implementation(libs.androidx.lifecycle.runtime.ktx)

mabl/src/main/assets/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
*.onnx
1+
*.onnx
2+
*.json

mabl/src/main/java/com/penumbraos/mabl/services/ToolOrchestrator.kt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import com.penumbraos.mabl.sdk.IToolService
99
import com.penumbraos.mabl.sdk.PluginType
1010
import com.penumbraos.mabl.sdk.ToolCall
1111
import com.penumbraos.mabl.sdk.ToolDefinition
12-
import java.io.ByteArrayOutputStream
1312
import java.util.concurrent.ConcurrentHashMap
1413

1514
private const val TAG = "ToolOrchestrator"
@@ -60,9 +59,7 @@ class ToolOrchestrator(
6059
allConnected.await()
6160

6261
try {
63-
val outputStream = ByteArrayOutputStream()
64-
context.assets.open("minilm-l6-v2-qint8-arm64.onnx").copyTo(outputStream)
65-
toolSimilarityService.initialize(outputStream.toByteArray())
62+
toolSimilarityService.initialize(context)
6663

6764
// Precalculate embeddings for all available tools
6865
buildToolDefinitionsMap()
@@ -73,7 +70,7 @@ class ToolOrchestrator(
7370
"Tool similarity service initialized successfully with ${allTools.size} tool embeddings precalculated"
7471
)
7572
} catch (e: Exception) {
76-
Log.w(TAG, "Failed to initialize similarity service: ${e.message}")
73+
Log.e(TAG, "Failed to initialize similarity service: $e")
7774
}
7875
}
7976

Lines changed: 60 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
package com.penumbraos.mabl.services
22

3-
import ai.onnxruntime.OnnxTensor
4-
import ai.onnxruntime.OrtEnvironment
5-
import ai.onnxruntime.OrtSession
3+
import android.content.Context
64
import android.util.Log
75
import com.penumbraos.mabl.sdk.ToolDefinition
6+
import com.penumbraos.mabl.util.SentenceEmbedding
87
import kotlinx.coroutines.CoroutineScope
98
import kotlinx.coroutines.Dispatchers
109
import kotlinx.coroutines.withContext
11-
import java.nio.LongBuffer
10+
import java.io.ByteArrayOutputStream
1211
import java.util.concurrent.ConcurrentHashMap
1312
import kotlin.math.sqrt
1413

@@ -21,60 +20,75 @@ data class OfflineIntentClassificationResult(
2120
)
2221

2322
class ToolSimilarityService {
24-
private var ortEnvironment: OrtEnvironment? = null
25-
private var ortSession: OrtSession? = null
26-
private val embeddingCache = ConcurrentHashMap<String, FloatArray>()
23+
private val sentenceEmbedding = SentenceEmbedding()
2724
private val toolEmbeddingCache = ConcurrentHashMap<String, FloatArray>()
28-
private var offlineCapableTools: List<ToolDefinition> = emptyList()
2925
private val intentExampleEmbeddingCache = ConcurrentHashMap<String, FloatArray>()
26+
private var offlineCapableTools: List<ToolDefinition> = emptyList()
3027
private val scope = CoroutineScope(Dispatchers.IO)
3128

3229
companion object {
33-
private const val MAX_SEQUENCE_LENGTH = 512
3430
private const val SIMILARITY_THRESHOLD = 0.5f
3531
private const val INTENT_THRESHOLD = 0.55f
32+
// private const val TOOL_CONFIRMATION_MARGIN = 0.05f
3633
}
3734

38-
suspend fun initialize(modelBytes: ByteArray) {
35+
suspend fun initialize(context: Context) {
3936
withContext(scope.coroutineContext) {
40-
ortEnvironment = OrtEnvironment.getEnvironment()
41-
ortSession = ortEnvironment?.createSession(modelBytes)
37+
val modelBytes = ByteArrayOutputStream().use { outputStream ->
38+
context.assets.open("minilm-l6-v2-qint8-arm64.onnx").copyTo(outputStream)
39+
outputStream.toByteArray()
40+
}
41+
42+
val tokenizerBytes = ByteArrayOutputStream().use { outputStream ->
43+
context.assets.open("minilm-l6-v2-tokenizer.json").copyTo(outputStream)
44+
outputStream.toByteArray()
45+
}
46+
47+
try {
48+
sentenceEmbedding.init(
49+
modelBytes = modelBytes,
50+
tokenizerBytes = tokenizerBytes,
51+
useTokenTypeIds = true,
52+
outputTensorName = "last_hidden_state",
53+
normalizeEmbeddings = true
54+
)
55+
} catch (e: Exception) {
56+
Log.w(TAG, "Failed to initialize tokenizer: ${e.message}")
57+
null
58+
}
4259
}
4360
}
4461

4562
suspend fun precalculateToolEmbeddings(tools: List<ToolDefinition>) {
46-
if (ortSession == null) return
47-
4863
withContext(scope.coroutineContext) {
64+
val offlineCandidates = mutableListOf<ToolDefinition>()
65+
toolEmbeddingCache.clear()
66+
intentExampleEmbeddingCache.clear()
67+
4968
tools.forEach { tool ->
5069
val toolText = buildToolText(tool)
51-
val embedding = getEmbedding(toolText)
52-
toolEmbeddingCache[tool.name] = embedding
70+
toolEmbeddingCache[tool.name] = sentenceEmbedding.encode(toolText)
5371

5472
if (!tool.examples.isNullOrEmpty()) {
55-
offlineCapableTools += tool
73+
offlineCandidates += tool
5674
}
5775

5876
tool.examples?.forEachIndexed { index, example ->
5977
if (!example.isNullOrBlank()) {
6078
val key = intentExampleKey(tool.name, index)
61-
intentExampleEmbeddingCache[key] = getEmbedding(example)
79+
intentExampleEmbeddingCache[key] = sentenceEmbedding.encode(example)
6280
}
6381
}
6482
}
83+
offlineCapableTools = offlineCandidates
6584
}
6685
}
6786

6887
suspend fun classifyIntent(
6988
userQuery: String,
7089
): OfflineIntentClassificationResult? {
71-
if (ortSession == null) {
72-
Log.w(TAG, "Intent classification requested before model initialization")
73-
return null
74-
}
75-
7690
return withContext(scope.coroutineContext) {
77-
val queryEmbedding = getEmbedding(userQuery)
91+
val queryEmbedding = sentenceEmbedding.encode(userQuery)
7892
var bestMatch: OfflineIntentClassificationResult? = null
7993

8094
offlineCapableTools.forEach { tool ->
@@ -90,12 +104,28 @@ class ToolSimilarityService {
90104

91105
val key = intentExampleKey(tool.name, index)
92106
val exampleEmbedding = intentExampleEmbeddingCache[key]
93-
if (exampleEmbedding == null) {
107+
?: return@forEachIndexed
108+
109+
val score = cosineSimilarity(queryEmbedding, exampleEmbedding)
110+
if (score < INTENT_THRESHOLD) {
94111
return@forEachIndexed
95112
}
96-
val score = cosineSimilarity(queryEmbedding, exampleEmbedding)
97113

98-
if (score >= INTENT_THRESHOLD && (bestMatch == null || score > bestMatch!!.similarity)) {
114+
// val toolEmbedding = toolEmbeddingCache[tool.name]
115+
// ?: sentenceEmbedding.encode(buildToolText(tool)).also {
116+
// toolEmbeddingCache[tool.name] = it
117+
// }
118+
// val toolScore = cosineSimilarity(queryEmbedding, toolEmbedding)
119+
// Log.e(
120+
// "ToolSimilarityService",
121+
// "Intent classification result: ${tool.name} $score $toolScore"
122+
// )
123+
//
124+
// if (toolScore < INTENT_THRESHOLD + TOOL_CONFIRMATION_MARGIN) {
125+
// return@forEachIndexed
126+
// }
127+
128+
if (bestMatch == null || score > bestMatch!!.similarity) {
99129
val parameters = extractBooleanParameters(tool, userQuery)
100130
bestMatch = OfflineIntentClassificationResult(tool, score, parameters)
101131
}
@@ -111,18 +141,14 @@ class ToolSimilarityService {
111141
userQuery: String,
112142
maxTools: Int
113143
): List<ToolDefinition> {
114-
if (ortSession == null) {
115-
throw IllegalStateException("Tool similarity service not initialized")
116-
}
117-
118144
return withContext(scope.coroutineContext) {
119-
val queryEmbedding = getEmbedding(userQuery)
145+
val queryEmbedding = sentenceEmbedding.encode(userQuery)
120146

121147
val toolScores = tools.map { tool ->
122148
val toolEmbedding = toolEmbeddingCache[tool.name] ?: run {
123149
// Fallback: calculate embedding if not cached
124150
val toolText = buildToolText(tool)
125-
getEmbedding(toolText)
151+
sentenceEmbedding.encode(toolText)
126152
}
127153
val similarity = cosineSimilarity(queryEmbedding, toolEmbedding)
128154

@@ -141,36 +167,6 @@ class ToolSimilarityService {
141167
}
142168
}
143169

144-
private suspend fun getEmbedding(text: String): FloatArray {
145-
val cacheKey = text.hashCode().toString()
146-
embeddingCache[cacheKey]?.let { return it }
147-
148-
return withContext(scope.coroutineContext) {
149-
val tokenIds = tokenizeText(text)
150-
val inputIdsTensor = createInputTensor(tokenIds)
151-
val tokenTypeIdsTensor = createTokenTypeIdsTensor(tokenIds.size)
152-
val attentionMaskTensor = createAttentionMaskTensor(tokenIds.size)
153-
154-
val inputs = mapOf(
155-
"input_ids" to inputIdsTensor,
156-
"token_type_ids" to tokenTypeIdsTensor,
157-
"attention_mask" to attentionMaskTensor
158-
)
159-
val outputs = ortSession?.run(inputs)
160-
161-
val embedding = outputs?.get(0)?.value as Array<*>
162-
val floatEmbedding = (embedding[0] as Array<FloatArray>)[0]
163-
164-
inputIdsTensor.close()
165-
tokenTypeIdsTensor.close()
166-
attentionMaskTensor.close()
167-
outputs.close()
168-
169-
embeddingCache[cacheKey] = floatEmbedding
170-
floatEmbedding
171-
}
172-
}
173-
174170
private fun buildToolText(tool: ToolDefinition): String {
175171
val builder = StringBuilder()
176172
builder.append(tool.name).append(" ")
@@ -238,57 +234,6 @@ class ToolSimilarityService {
238234
}
239235
}
240236

241-
private fun tokenizeText(text: String): IntArray {
242-
val words = text.lowercase().split(Regex("\\W+"))
243-
val tokens = mutableListOf<Int>()
244-
245-
words.forEach { word ->
246-
if (word.isNotEmpty()) {
247-
tokens.add(word.hashCode() % 30000)
248-
}
249-
}
250-
251-
return tokens.take(MAX_SEQUENCE_LENGTH).toIntArray()
252-
}
253-
254-
private fun createInputTensor(tokenIds: IntArray): OnnxTensor {
255-
val shape = longArrayOf(1, tokenIds.size.toLong())
256-
val buffer = LongBuffer.allocate(tokenIds.size)
257-
258-
tokenIds.forEach { id ->
259-
buffer.put(id.toLong())
260-
}
261-
buffer.flip()
262-
263-
return OnnxTensor.createTensor(ortEnvironment, buffer, shape)
264-
}
265-
266-
private fun createTokenTypeIdsTensor(sequenceLength: Int): OnnxTensor {
267-
val shape = longArrayOf(1, sequenceLength.toLong())
268-
val buffer = LongBuffer.allocate(sequenceLength)
269-
270-
// All tokens are type 0 (single sentence)
271-
repeat(sequenceLength) {
272-
buffer.put(0L)
273-
}
274-
buffer.flip()
275-
276-
return OnnxTensor.createTensor(ortEnvironment, buffer, shape)
277-
}
278-
279-
private fun createAttentionMaskTensor(sequenceLength: Int): OnnxTensor {
280-
val shape = longArrayOf(1, sequenceLength.toLong())
281-
val buffer = LongBuffer.allocate(sequenceLength)
282-
283-
// All tokens get attention (no padding in our case)
284-
repeat(sequenceLength) {
285-
buffer.put(1L)
286-
}
287-
buffer.flip()
288-
289-
return OnnxTensor.createTensor(ortEnvironment, buffer, shape)
290-
}
291-
292237
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
293238
if (a.size != b.size) return 0f
294239

@@ -307,7 +252,6 @@ class ToolSimilarityService {
307252
}
308253

309254
fun close() {
310-
ortSession?.close()
311-
ortEnvironment?.close()
255+
sentenceEmbedding.close()
312256
}
313257
}

0 commit comments

Comments
 (0)