11package com.penumbraos.mabl.services
22
3- import ai.onnxruntime.OnnxTensor
4- import ai.onnxruntime.OrtEnvironment
5- import ai.onnxruntime.OrtSession
3+ import android.content.Context
64import android.util.Log
75import com.penumbraos.mabl.sdk.ToolDefinition
6+ import com.penumbraos.mabl.util.SentenceEmbedding
87import kotlinx.coroutines.CoroutineScope
98import kotlinx.coroutines.Dispatchers
109import kotlinx.coroutines.withContext
11- import java.nio.LongBuffer
10+ import java.io.ByteArrayOutputStream
1211import java.util.concurrent.ConcurrentHashMap
1312import kotlin.math.sqrt
1413
@@ -21,60 +20,75 @@ data class OfflineIntentClassificationResult(
2120)
2221
2322class ToolSimilarityService {
24- private var ortEnvironment: OrtEnvironment ? = null
25- private var ortSession: OrtSession ? = null
26- private val embeddingCache = ConcurrentHashMap <String , FloatArray >()
23+ private val sentenceEmbedding = SentenceEmbedding ()
2724 private val toolEmbeddingCache = ConcurrentHashMap <String , FloatArray >()
28- private var offlineCapableTools: List <ToolDefinition > = emptyList()
2925 private val intentExampleEmbeddingCache = ConcurrentHashMap <String , FloatArray >()
26+ private var offlineCapableTools: List <ToolDefinition > = emptyList()
3027 private val scope = CoroutineScope (Dispatchers .IO )
3128
3229 companion object {
33- private const val MAX_SEQUENCE_LENGTH = 512
3430 private const val SIMILARITY_THRESHOLD = 0.5f
3531 private const val INTENT_THRESHOLD = 0.55f
32+ // private const val TOOL_CONFIRMATION_MARGIN = 0.05f
3633 }
3734
38- suspend fun initialize (modelBytes : ByteArray ) {
35+ suspend fun initialize (context : Context ) {
3936 withContext(scope.coroutineContext) {
40- ortEnvironment = OrtEnvironment .getEnvironment()
41- ortSession = ortEnvironment?.createSession(modelBytes)
37+ val modelBytes = ByteArrayOutputStream ().use { outputStream ->
38+ context.assets.open(" minilm-l6-v2-qint8-arm64.onnx" ).copyTo(outputStream)
39+ outputStream.toByteArray()
40+ }
41+
42+ val tokenizerBytes = ByteArrayOutputStream ().use { outputStream ->
43+ context.assets.open(" minilm-l6-v2-tokenizer.json" ).copyTo(outputStream)
44+ outputStream.toByteArray()
45+ }
46+
47+ try {
48+ sentenceEmbedding.init (
49+ modelBytes = modelBytes,
50+ tokenizerBytes = tokenizerBytes,
51+ useTokenTypeIds = true ,
52+ outputTensorName = " last_hidden_state" ,
53+ normalizeEmbeddings = true
54+ )
55+ } catch (e: Exception ) {
56+ Log .w(TAG , " Failed to initialize tokenizer: ${e.message} " )
57+ null
58+ }
4259 }
4360 }
4461
4562 suspend fun precalculateToolEmbeddings (tools : List <ToolDefinition >) {
46- if (ortSession == null ) return
47-
4863 withContext(scope.coroutineContext) {
64+ val offlineCandidates = mutableListOf<ToolDefinition >()
65+ toolEmbeddingCache.clear()
66+ intentExampleEmbeddingCache.clear()
67+
4968 tools.forEach { tool ->
5069 val toolText = buildToolText(tool)
51- val embedding = getEmbedding(toolText)
52- toolEmbeddingCache[tool.name] = embedding
70+ toolEmbeddingCache[tool.name] = sentenceEmbedding.encode(toolText)
5371
5472 if (! tool.examples.isNullOrEmpty()) {
55- offlineCapableTools + = tool
73+ offlineCandidates + = tool
5674 }
5775
5876 tool.examples?.forEachIndexed { index, example ->
5977 if (! example.isNullOrBlank()) {
6078 val key = intentExampleKey(tool.name, index)
61- intentExampleEmbeddingCache[key] = getEmbedding (example)
79+ intentExampleEmbeddingCache[key] = sentenceEmbedding.encode (example)
6280 }
6381 }
6482 }
83+ offlineCapableTools = offlineCandidates
6584 }
6685 }
6786
6887 suspend fun classifyIntent (
6988 userQuery : String ,
7089 ): OfflineIntentClassificationResult ? {
71- if (ortSession == null ) {
72- Log .w(TAG , " Intent classification requested before model initialization" )
73- return null
74- }
75-
7690 return withContext(scope.coroutineContext) {
77- val queryEmbedding = getEmbedding (userQuery)
91+ val queryEmbedding = sentenceEmbedding.encode (userQuery)
7892 var bestMatch: OfflineIntentClassificationResult ? = null
7993
8094 offlineCapableTools.forEach { tool ->
@@ -90,12 +104,28 @@ class ToolSimilarityService {
90104
91105 val key = intentExampleKey(tool.name, index)
92106 val exampleEmbedding = intentExampleEmbeddingCache[key]
93- if (exampleEmbedding == null ) {
107+ ? : return @forEachIndexed
108+
109+ val score = cosineSimilarity(queryEmbedding, exampleEmbedding)
110+ if (score < INTENT_THRESHOLD ) {
94111 return @forEachIndexed
95112 }
96- val score = cosineSimilarity(queryEmbedding, exampleEmbedding)
97113
98- if (score >= INTENT_THRESHOLD && (bestMatch == null || score > bestMatch!! .similarity)) {
114+ // val toolEmbedding = toolEmbeddingCache[tool.name]
115+ // ?: sentenceEmbedding.encode(buildToolText(tool)).also {
116+ // toolEmbeddingCache[tool.name] = it
117+ // }
118+ // val toolScore = cosineSimilarity(queryEmbedding, toolEmbedding)
119+ // Log.e(
120+ // "ToolSimilarityService",
121+ // "Intent classification result: ${tool.name} $score $toolScore"
122+ // )
123+ //
124+ // if (toolScore < INTENT_THRESHOLD + TOOL_CONFIRMATION_MARGIN) {
125+ // return@forEachIndexed
126+ // }
127+
128+ if (bestMatch == null || score > bestMatch!! .similarity) {
99129 val parameters = extractBooleanParameters(tool, userQuery)
100130 bestMatch = OfflineIntentClassificationResult (tool, score, parameters)
101131 }
@@ -111,18 +141,14 @@ class ToolSimilarityService {
111141 userQuery : String ,
112142 maxTools : Int
113143 ): List <ToolDefinition > {
114- if (ortSession == null ) {
115- throw IllegalStateException (" Tool similarity service not initialized" )
116- }
117-
118144 return withContext(scope.coroutineContext) {
119- val queryEmbedding = getEmbedding (userQuery)
145+ val queryEmbedding = sentenceEmbedding.encode (userQuery)
120146
121147 val toolScores = tools.map { tool ->
122148 val toolEmbedding = toolEmbeddingCache[tool.name] ? : run {
123149 // Fallback: calculate embedding if not cached
124150 val toolText = buildToolText(tool)
125- getEmbedding (toolText)
151+ sentenceEmbedding.encode (toolText)
126152 }
127153 val similarity = cosineSimilarity(queryEmbedding, toolEmbedding)
128154
@@ -141,36 +167,6 @@ class ToolSimilarityService {
141167 }
142168 }
143169
144- private suspend fun getEmbedding (text : String ): FloatArray {
145- val cacheKey = text.hashCode().toString()
146- embeddingCache[cacheKey]?.let { return it }
147-
148- return withContext(scope.coroutineContext) {
149- val tokenIds = tokenizeText(text)
150- val inputIdsTensor = createInputTensor(tokenIds)
151- val tokenTypeIdsTensor = createTokenTypeIdsTensor(tokenIds.size)
152- val attentionMaskTensor = createAttentionMaskTensor(tokenIds.size)
153-
154- val inputs = mapOf (
155- " input_ids" to inputIdsTensor,
156- " token_type_ids" to tokenTypeIdsTensor,
157- " attention_mask" to attentionMaskTensor
158- )
159- val outputs = ortSession?.run (inputs)
160-
161- val embedding = outputs?.get(0 )?.value as Array <* >
162- val floatEmbedding = (embedding[0 ] as Array <FloatArray >)[0 ]
163-
164- inputIdsTensor.close()
165- tokenTypeIdsTensor.close()
166- attentionMaskTensor.close()
167- outputs.close()
168-
169- embeddingCache[cacheKey] = floatEmbedding
170- floatEmbedding
171- }
172- }
173-
174170 private fun buildToolText (tool : ToolDefinition ): String {
175171 val builder = StringBuilder ()
176172 builder.append(tool.name).append(" " )
@@ -238,57 +234,6 @@ class ToolSimilarityService {
238234 }
239235 }
240236
241- private fun tokenizeText (text : String ): IntArray {
242- val words = text.lowercase().split(Regex (" \\ W+" ))
243- val tokens = mutableListOf<Int >()
244-
245- words.forEach { word ->
246- if (word.isNotEmpty()) {
247- tokens.add(word.hashCode() % 30000 )
248- }
249- }
250-
251- return tokens.take(MAX_SEQUENCE_LENGTH ).toIntArray()
252- }
253-
254- private fun createInputTensor (tokenIds : IntArray ): OnnxTensor {
255- val shape = longArrayOf(1 , tokenIds.size.toLong())
256- val buffer = LongBuffer .allocate(tokenIds.size)
257-
258- tokenIds.forEach { id ->
259- buffer.put(id.toLong())
260- }
261- buffer.flip()
262-
263- return OnnxTensor .createTensor(ortEnvironment, buffer, shape)
264- }
265-
266- private fun createTokenTypeIdsTensor (sequenceLength : Int ): OnnxTensor {
267- val shape = longArrayOf(1 , sequenceLength.toLong())
268- val buffer = LongBuffer .allocate(sequenceLength)
269-
270- // All tokens are type 0 (single sentence)
271- repeat(sequenceLength) {
272- buffer.put(0L )
273- }
274- buffer.flip()
275-
276- return OnnxTensor .createTensor(ortEnvironment, buffer, shape)
277- }
278-
279- private fun createAttentionMaskTensor (sequenceLength : Int ): OnnxTensor {
280- val shape = longArrayOf(1 , sequenceLength.toLong())
281- val buffer = LongBuffer .allocate(sequenceLength)
282-
283- // All tokens get attention (no padding in our case)
284- repeat(sequenceLength) {
285- buffer.put(1L )
286- }
287- buffer.flip()
288-
289- return OnnxTensor .createTensor(ortEnvironment, buffer, shape)
290- }
291-
292237 private fun cosineSimilarity (a : FloatArray , b : FloatArray ): Float {
293238 if (a.size != b.size) return 0f
294239
@@ -307,7 +252,6 @@ class ToolSimilarityService {
307252 }
308253
309254 fun close () {
310- ortSession?.close()
311- ortEnvironment?.close()
255+ sentenceEmbedding.close()
312256 }
313257}
0 commit comments