Commit 8442dd6

finish bidi
1 parent 7ef4910 commit 8442dd6

File tree

6 files changed: +1254, -1197 lines

firebase-vertexai/gradle.properties

Lines changed: 1 addition & 1 deletion
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-version=17.0.0
+version=17.1.0
 latestReleasedVersion=16.2.0

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/AudioHelper.kt

Lines changed: 28 additions & 1150 deletions
Large diffs are not rendered by default.

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Base64Strings.kt

Lines changed: 1127 additions & 0 deletions
Large diffs are not rendered by default.

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt

Lines changed: 95 additions & 44 deletions
@@ -1,19 +1,22 @@
 package com.google.firebase.vertexai.type
 
 import android.media.AudioFormat
-import android.media.AudioManager
 import android.media.AudioTrack
-import android.util.Base64
 import io.ktor.client.plugins.websocket.ClientWebSocketSession
 import io.ktor.websocket.Frame
 import io.ktor.websocket.close
 import io.ktor.websocket.readBytes
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.cancel
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
 import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.encodeToString
 import kotlinx.serialization.json.Json
+import java.util.concurrent.ConcurrentLinkedQueue
 
 public class LiveSession
 internal constructor(
@@ -22,12 +25,16 @@ internal constructor(
   private var audioHelper: AudioHelper? = null
 ) {
 
+  private val audioQueue = ConcurrentLinkedQueue<ByteArray>()
+  private val playBackQueue = ConcurrentLinkedQueue<ByteArray>()
+
   @Serializable
   internal data class ClientContent(
     @SerialName("turns") val turns: List<Content.Internal>,
     @SerialName("turn_complete") val turnComplete: Boolean
   )
 
+
   @Serializable
   internal data class ClientContentSetup(
     @SerialName("client_content") val clientContent: ClientContent
@@ -49,70 +56,114 @@ internal constructor(
     if(isRecording) { return }
     isRecording = true
     audioHelper = AudioHelper()
-    val minBufferSize = AudioTrack.getMinBufferSize(24000, AudioFormat.CHANNEL_OUT_MONO, AudioFormat.ENCODING_PCM_16BIT)
-    var bytesRead = 0
-    val chunkSize = minBufferSize
-    var recordedData = ByteArray(2*chunkSize)
     audioHelper!!.setupAudioTrack()
 
-    audioHelper!!.startRecording().collect {
-      x ->
-      run {
-        bytesRead += x.size
-        recordedData += x
-        if(bytesRead>=0) {
-          println("BytesRead:")
-          println(Base64.encodeToString(recordedData, Base64.NO_WRAP))
-          sendMediaStream(listOf(MediaData("audio/pcm", x)), listOf(ContentModality.AUDIO)).collect {
-            y ->
-            run {
-              val audioData = y.parts[0].asInlineDataPartOrNull()!!.inlineData
-              audioHelper!!.playAudio(audioData)
-            }
-          }
-          recordedData = byteArrayOf()
-          bytesRead = 0
+    CoroutineScope(Dispatchers.Default).launch {
+      audioHelper!!.startRecording().collect {
+        if(!isRecording) {
+          cancel()
+        }
+        audioQueue.add(it)
+      }
+    }
+    val minBufferSize = AudioTrack.getMinBufferSize(24000, AudioFormat.CHANNEL_OUT_MONO, AudioFormat.ENCODING_PCM_16BIT)
+    var bytesRead = 0
+    var recordedData = ByteArray(minBufferSize * 2)
+    CoroutineScope(Dispatchers.Default).launch {
+      while(true) {
+        if(!isRecording) {
+          break
+        }
+        val byteArr = audioQueue.poll()
+        if(byteArr!=null) {
+          bytesRead += byteArr.size
+          recordedData += byteArr
+          if (bytesRead >= minBufferSize) {
+            sendMediaStream(
+              listOf(MediaData("audio/pcm", recordedData)),
+              listOf(ContentModality.AUDIO)
+            )
+            bytesRead = 0
+            recordedData = byteArrayOf()
+          }
+        } else {
+          continue
+        }
+      }
+    }
+    CoroutineScope(Dispatchers.Default).launch {
+      receiveMediaStream().collect {
+        if(!isRecording) {
+          cancel()
+        }
+        if(it.interrupted) {
+          while(!playBackQueue.isEmpty()) playBackQueue.poll()
+        } else {
+          playBackQueue.add(it.data!!.parts[0].asInlineDataPartOrNull()!!.inlineData)
+        }
+      }
+    }
+    CoroutineScope(Dispatchers.Default).launch {
+      while(true) {
+        if(!isRecording) {
+          break
+        }
+        val x = playBackQueue.poll()
+        if(x!=null) {
+          audioHelper!!.playAudio(x)
         }
       }
-
     }
-
   }
 
   public fun stopAudioConversation() {
     isRecording = false
     if(audioHelper!=null) {
+      while(!playBackQueue.isEmpty()) playBackQueue.poll()
+      while(!audioQueue.isEmpty()) audioQueue.poll()
       audioHelper!!.release()
+      audioHelper = null
    }
-
   }
-  public fun sendMediaStream(
-    mediaChunks: List<MediaData>,
-    outputModalities: List<ContentModality>
-  ): Flow<Content> {
+
+  public suspend fun receiveMediaStream(): Flow<StreamOutput> {
     return flow {
-      val jsonString = Json.encodeToString(MediaStreamingSetup(MediaChunks(mediaChunks.map { it.toInternal() })))
-      println("JsonString: $jsonString")
-      session?.send(Frame.Text(jsonString))
       while (true) {
+        val message = session!!.incoming.receive()
+        val receivedBytes =
+          (message as Frame.Binary).readBytes()
+        val receivedJson = receivedBytes.toString(Charsets.UTF_8)
         try {
-          val message = session?.incoming?.receive() ?: continue
-          val receivedBytes = (message as Frame.Binary).readBytes()
-          val receivedJson = receivedBytes.toString(Charsets.UTF_8)
-          println("Receivedjson: $receivedJson")
-          if (receivedJson.contains("turnComplete")) {
-            break
+          if (receivedJson.contains("interrupted")) {
+            emit(StreamOutput(true, null))
+            continue
          }
-          val serverContent = Json.decodeFromString<ServerContentSetup>(receivedJson)
-          val audioData = serverContent.serverContent.modelTurn.toPublic()
-          emit(audioData)
-        } catch (_: Exception) {}
+          val serverContent =
+            Json.decodeFromString<ServerContentSetup>(
+              receivedJson
+            )
+          val audioData =
+            serverContent.serverContent.modelTurn.toPublic()
+          emit(StreamOutput(false, audioData))
+
+
+        } catch (e: Exception) {
+          println("Exception: $e.message")
+        }
      }
    }
  }
 
-  public fun send(text: String, outputModalities: List<ContentModality>): Flow<Content> {
 
+  public suspend fun sendMediaStream(
+    mediaChunks: List<MediaData>,
+    outputModalities: List<ContentModality>
+  ) {
+    val jsonString = Json.encodeToString(MediaStreamingSetup(MediaChunks(mediaChunks.map { it.toInternal() })))
+    session?.send(Frame.Text(jsonString))
+  }
+
+  public fun send(text: String, outputModalities: List<ContentModality>): Flow<Content> {
     return flow {
      val jsonString =
        Json.encodeToString(
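
Taken together, the LiveSession changes replace the old single nested collect loop with a queue-based pipeline: one coroutine records microphone chunks into audioQueue, a second batches them up to minBufferSize bytes and ships each batch through sendMediaStream, a third maps incoming frames from receiveMediaStream to StreamOutput values and flushes playBackQueue when the model is interrupted, and a fourth drains playBackQueue into the speaker. The sketch below is a minimal, self-contained illustration of that producer/consumer pattern, not the SDK code: record, send, receive, and play are hypothetical stand-ins for AudioHelper.startRecording(), sendMediaStream(), receiveMediaStream(), and AudioHelper.playAudio(), and the short delay on an empty queue is an addition of the sketch (the commit simply spins with continue).

import java.util.concurrent.ConcurrentLinkedQueue
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch

// Hypothetical stand-ins: record ~ AudioHelper.startRecording(), send ~ sendMediaStream(),
// receive ~ receiveMediaStream() (interrupted flag + PCM bytes), play ~ AudioHelper.playAudio().
class BidiAudioLoop(
  private val record: () -> Flow<ByteArray>,
  private val send: suspend (ByteArray) -> Unit,
  private val receive: () -> Flow<Pair<Boolean, ByteArray?>>,
  private val play: (ByteArray) -> Unit,
  private val minBufferSize: Int,
) {
  @Volatile private var running = false
  private val micQueue = ConcurrentLinkedQueue<ByteArray>()
  private val playbackQueue = ConcurrentLinkedQueue<ByteArray>()

  fun start(scope: CoroutineScope) {
    if (running) return
    running = true

    // 1. Recorder: push raw PCM chunks from the microphone flow onto micQueue.
    scope.launch(Dispatchers.Default) {
      record().collect { chunk -> if (running) micQueue.add(chunk) }
    }

    // 2. Sender: accumulate chunks until at least minBufferSize bytes, then send one batch.
    scope.launch(Dispatchers.Default) {
      var buffered = ByteArray(0)
      while (isActive && running) {
        val chunk = micQueue.poll()
        if (chunk == null) { delay(10); continue } // the commit busy-polls with `continue`
        buffered += chunk
        if (buffered.size >= minBufferSize) {
          send(buffered)
          buffered = ByteArray(0)
        }
      }
    }

    // 3. Receiver: an interruption clears pending playback; otherwise enqueue model audio.
    scope.launch(Dispatchers.Default) {
      receive().collect { (interrupted, pcm) ->
        if (!running) return@collect
        if (interrupted) playbackQueue.clear() else pcm?.let { playbackQueue.add(it) }
      }
    }

    // 4. Player: drain playbackQueue into the audio sink as data arrives.
    scope.launch(Dispatchers.Default) {
      while (isActive && running) {
        val pcm = playbackQueue.poll()
        if (pcm == null) delay(10) else play(pcm)
      }
    }
  }

  fun stop() {
    running = false
    micQueue.clear()
    playbackQueue.clear()
  }
}

Decoupling capture, upload, download, and playback this way means a slow network read never stalls the recorder, which appears to be the point of the rewrite; the trade-off, both here and in the commit, is that the queue-draining loops poll rather than suspend on the queues.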

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt

Lines changed: 0 additions & 2 deletions
@@ -187,9 +187,7 @@ internal fun InternalPart.toPublic(): Part {
   return when (this) {
     is TextPart.Internal -> TextPart(text)
     is InlineDataPart.Internal -> {
-      println(inlineData.data)
       val data = android.util.Base64.decode(inlineData.data, android.util.Base64.DEFAULT)
-      println(data)
       if (inlineData.mimeType.contains("image")) {
         ImagePart(decodeBitmapFromImage(data))
       } else {

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+package com.google.firebase.vertexai.type
+
+public class StreamOutput(public val interrupted: Boolean,public val data: Content?)
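
The new StreamOutput wrapper lets receiveMediaStream() report an interruption (interrupted = true, data = null) separately from an audio payload. The snippet below is a hypothetical consumer, not SDK API: collectModelAudio and onPcm are illustrative names, the function is placed in the same package so the diff's types resolve, and it assumes inlineData is the raw PCM ByteArray that the diff above hands to AudioHelper.playAudio.

package com.google.firebase.vertexai.type

// Illustrative consumer of receiveMediaStream(); onPcm is a caller-supplied sink
// (for example an AudioTrack writer), not part of the SDK.
suspend fun collectModelAudio(session: LiveSession, onPcm: (ByteArray) -> Unit) {
  session.receiveMediaStream().collect { output ->
    if (output.interrupted) {
      // The model was cut off; the commit reacts by flushing its playback queue.
      return@collect
    }
    val pcm = output.data?.parts?.firstOrNull()?.asInlineDataPartOrNull()?.inlineData
    if (pcm != null) onPcm(pcm)
  }
}

Inside startAudioConversation the commit handles the interrupted case by draining playBackQueue; a standalone consumer would typically do the same with whatever buffer it feeds.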
