Skip to content

Commit b63b807

Browse files
committed
working
1 parent 096bc4c commit b63b807

File tree

4 files changed

+78
-76
lines changed

4 files changed

+78
-76
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ internal constructor(
143143

144144
fun getBidiEndpoint(): String {
145145
val vertexAiUrl =
146-
"wss://daily-firebaseml.sandbox.googleapis.com/ws/google.firebase.machinelearning.v2beta.LlmBidiService/BidiGenerateContent"
146+
"wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations/us-central1"
147+
147148
return "$vertexAiUrl?key=$key"
148149
}
149150

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveContentResponse.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package com.google.firebase.vertexai.type
22

3+
public enum class Status {
4+
NORMAL,
5+
INTERRUPTED,
6+
TURNCOMPLETE;
7+
}
38

49

5-
public class LiveContentResponse internal constructor(public val data: Content?, public val interrupted: Boolean?, public val functionCalls: List<FunctionCallPart>?) {
10+
public class LiveContentResponse internal constructor(public val data: Content?, public val status: Status = Status.NORMAL, public val functionCalls: List<FunctionCallPart>?) {
611
/**
712
* Convenience field representing all the text parts in the response as a single string, if they
813
* exists.

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt

Lines changed: 68 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ import java.util.concurrent.ConcurrentLinkedQueue
1010
import kotlinx.coroutines.CoroutineScope
1111
import kotlinx.coroutines.Dispatchers
1212
import kotlinx.coroutines.cancel
13+
import kotlinx.coroutines.delay
1314
import kotlinx.coroutines.flow.Flow
1415
import kotlinx.coroutines.flow.flow
1516
import kotlinx.coroutines.launch
17+
import kotlinx.coroutines.runBlocking
1618
import kotlinx.serialization.SerialName
1719
import kotlinx.serialization.Serializable
1820
import kotlinx.serialization.encodeToString
@@ -27,6 +29,8 @@ internal constructor(
2729

2830
private val audioQueue = ConcurrentLinkedQueue<ByteArray>()
2931
private val playBackQueue = ConcurrentLinkedQueue<ByteArray>()
32+
private var stopReceiving = false
33+
private var startedReceiving = false
3034

3135
@Serializable
3236
internal data class ClientContent(
@@ -59,27 +63,30 @@ internal constructor(
5963
if (isRecording) {
6064
return
6165
}
66+
println("Started Receiving")
6267
isRecording = true
6368
audioHelper = AudioHelper()
6469
audioHelper!!.setupAudioTrack()
65-
66-
CoroutineScope(Dispatchers.Default).launch {
70+
val scope = CoroutineScope(Dispatchers.Default)
71+
CoroutineScope(Dispatchers.IO).launch {
6772
audioHelper!!.startRecording().collect {
6873
if (!isRecording) {
6974
cancel()
7075
}
76+
println(it)
7177
audioQueue.add(it)
7278
}
7379
}
74-
val minBufferSize =
75-
AudioTrack.getMinBufferSize(
76-
24000,
77-
AudioFormat.CHANNEL_OUT_MONO,
78-
AudioFormat.ENCODING_PCM_16BIT
79-
)
80-
var bytesRead = 0
81-
var recordedData = ByteArray(minBufferSize * 2)
82-
CoroutineScope(Dispatchers.Default).launch {
80+
81+
scope.launch {
82+
val minBufferSize =
83+
AudioTrack.getMinBufferSize(
84+
24000,
85+
AudioFormat.CHANNEL_OUT_MONO,
86+
AudioFormat.ENCODING_PCM_16BIT
87+
)
88+
var bytesRead = 0
89+
var recordedData = ByteArray(minBufferSize * 2)
8390
while (true) {
8491
if (!isRecording) {
8592
break
@@ -98,19 +105,19 @@ internal constructor(
98105
}
99106
}
100107
}
101-
CoroutineScope(Dispatchers.Default).launch {
108+
scope.launch {
102109
receive(listOf(ContentModality.AUDIO)).collect {
103110
if (!isRecording) {
104111
cancel()
105112
}
106-
if (it.interrupted == true) {
113+
if (it.status == Status.INTERRUPTED) {
107114
while (!playBackQueue.isEmpty()) playBackQueue.poll()
108-
} else {
115+
} else if(it.status == Status.NORMAL) {
109116
playBackQueue.add(it.data!!.parts[0].asInlineDataPartOrNull()!!.inlineData)
110117
}
111118
}
112119
}
113-
CoroutineScope(Dispatchers.Default).launch {
120+
CoroutineScope(Dispatchers.IO).launch {
114121
while (true) {
115122
if (!isRecording) {
116123
break
@@ -124,6 +131,7 @@ internal constructor(
124131
}
125132

126133
public fun stopAudioConversation() {
134+
stopReceiving()
127135
isRecording = false
128136
if (audioHelper != null) {
129137
while (!playBackQueue.isEmpty()) playBackQueue.poll()
@@ -133,37 +141,60 @@ internal constructor(
133141
}
134142
}
135143

144+
public fun stopReceiving() {
145+
if(!startedReceiving) {
146+
stopReceiving = false
147+
return
148+
}
149+
stopReceiving = true
150+
startedReceiving = false
151+
}
152+
153+
public class SessionAlreadyReceivingException: Exception("This session is already receiving. Please call stopReceiving() before calling this again.")
154+
136155
public suspend fun receive(
137156
outputModalities: List<ContentModality>
138157
): Flow<LiveContentResponse> {
158+
if(startedReceiving) {
159+
throw SessionAlreadyReceivingException()
160+
}
161+
139162
return flow {
163+
startedReceiving = true
140164
while (true) {
165+
println(stopReceiving)
166+
if(stopReceiving) {
167+
stopReceiving = false
168+
break
169+
}
141170
val message = session!!.incoming.receive()
142171
val receivedBytes = (message as Frame.Binary).readBytes()
143172
val receivedJson = receivedBytes.toString(Charsets.UTF_8)
173+
if (receivedJson.contains("interrupted")) {
174+
emit(LiveContentResponse(null, Status.INTERRUPTED, null))
175+
continue
176+
}
177+
if(receivedJson.contains("turnComplete")) {
178+
emit(LiveContentResponse(null, Status.TURNCOMPLETE, null))
179+
continue
180+
}
144181
try {
145182
val functionContent = Json.decodeFromString<ToolCallSetup>(receivedJson)
146-
// val y = functionContent.toolCall.functionCalls.map { it.toPublic() as FunctionCallPart }
147-
// emit(LiveContentResponse(null,false, y))
148-
//emit(LiveContentResponse(null, functionContent.toolCall.functionCalls.map { it.toPublic() as FunctionCallPart })))
149-
break
183+
emit(LiveContentResponse(null, Status.NORMAL, functionContent.toolCall.functionCalls.map { FunctionCallPart(it.name, it.args!!) }))
184+
continue
150185
} catch (_: Exception){ }
151-
152186
try {
153-
if (receivedJson.contains("interrupted")) {
154-
emit(LiveContentResponse(null, true, null))
155-
continue
156-
}
187+
157188
val serverContent = Json.decodeFromString<ServerContentSetup>(receivedJson)
158189
val data = serverContent.serverContent.modelTurn.toPublic()
159190
if (outputModalities.contains(ContentModality.AUDIO)) {
160191
if (data.parts[0].asInlineDataPartOrNull()?.mimeType?.equals("audio/pcm") == true) {
161-
emit(LiveContentResponse(data, false, null))
192+
emit(LiveContentResponse(data, Status.NORMAL, null))
162193
}
163194
}
164195
if (outputModalities.contains(ContentModality.TEXT)) {
165196
if (data.parts[0] is TextPart) {
166-
emit(LiveContentResponse(data, false, null))
197+
emit(LiveContentResponse(data, Status.NORMAL, null))
167198
}
168199
}
169200
} catch (e: Exception) {
@@ -178,6 +209,7 @@ internal constructor(
178209
) {
179210
val jsonString =
180211
Json.encodeToString(MediaStreamingSetup(MediaChunks(mediaChunks.map { it.toInternal() })))
212+
println(jsonString)
181213
session?.send(Frame.Text(jsonString))
182214
}
183215
/*
@@ -187,55 +219,18 @@ internal constructor(
187219
188220
*/
189221

190-
public fun send(content: Content, outputModalities: List<ContentModality>): Flow<LiveContentResponse> {
191-
return flow {
192-
val jsonString =
193-
Json.encodeToString(
194-
ClientContentSetup(
195-
ClientContent(listOf(content.toInternal()), true)
196-
)
222+
public suspend fun send(content: Content){
223+
val jsonString =
224+
Json.encodeToString(
225+
ClientContentSetup(
226+
ClientContent(listOf(content.toInternal()), true)
197227
)
198-
session?.send(Frame.Text(jsonString))
199-
while (true) {
200-
try {
201-
val message = session?.incoming?.receive() ?: continue
202-
val receivedBytes = (message as Frame.Binary).readBytes()
203-
val receivedJson = receivedBytes.toString(Charsets.UTF_8)
204-
println(receivedBytes)
205-
try {
206-
val functionContent = Json.decodeFromString<ToolCallSetup>(receivedJson)
207-
emit(LiveContentResponse(null, false, functionContent.toolCall.functionCalls.map { FunctionCallPart(it.name, it.args!!) }))
208-
// val y = functionContent.toolCall.functionCalls.map { it.toPublic() as FunctionCallPart }
209-
// emit(LiveContentResponse(null, false, y))
210-
//emit(LiveContentResponse(null, functionContent.toolCall.functionCalls.map { it.toPublic() as FunctionCallPart })))
211-
break
212-
} catch (e: Exception){
213-
println(e.message)
214-
}
215-
if (receivedJson.contains("turnComplete")) {
216-
break
217-
}
218-
val serverContent = Json.decodeFromString<ServerContentSetup>(receivedJson)
219-
val data = serverContent.serverContent.modelTurn.toPublic()
220-
221-
if (outputModalities.contains(ContentModality.AUDIO)) {
222-
if (data.parts[0].asInlineDataPartOrNull()?.mimeType?.equals("audio/pcm") == true) {
223-
emit(LiveContentResponse(data, false, listOf()))
224-
}
225-
}
226-
if (outputModalities.contains(ContentModality.TEXT)) {
227-
if (data.parts[0] is TextPart) {
228-
emit(LiveContentResponse(data, false, null))
229-
}
230-
}
231-
} catch (e: Exception) {
232-
println(e.message)
233-
}
234-
}
235-
}
228+
)
229+
println(jsonString)
230+
session?.send(Frame.Text(jsonString))
236231
}
237-
public fun send(text: String, outputModalities: List<ContentModality>): Flow<LiveContentResponse> {
238-
return send(Content.Builder().text(text).build(), outputModalities)
232+
public suspend fun send(text: String){
233+
send(Content.Builder().text(text).build())
239234

240235
}
241236

firebase-vertexai/src/test/java/com/google/firebase/vertexai/LiveModelTesting.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ class LiveModelTesting {
5757
)
5858
runBlocking{
5959
val session = generativeModel.connect()
60-
session!!.send("Tell me a story", listOf(ContentModality.TEXT)).collect {
60+
session!!.send("Tell me a story")
61+
session!!.receive(listOf(ContentModality.TEXT)).collect {
6162
println(it.text)
6263
}
6364

0 commit comments

Comments
 (0)