Skip to content

Commit 4b4ec66

Browse files
committed
update
1 parent 1379e3d commit 4b4ec66

File tree

2 files changed

+21
-34
lines changed

2 files changed

+21
-34
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/java/LiveSessionFutures.kt

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,13 @@ public abstract class LiveSessionFutures internal constructor() {
4040
/**
4141
* Starts an audio conversation with the Gemini server, which can only be stopped using
4242
* stopAudioConversation.
43+
*
44+
* @param functionCallsHandler A callback function that is invoked whenever the server receives a
45+
* function call.
4346
*/
44-
public abstract fun startAudioConversation(): ListenableFuture<Unit>
47+
public abstract fun startAudioConversation(
48+
functionCallsHandler: ((List<FunctionCallPart>) -> List<FunctionResponsePart>)?
49+
): ListenableFuture<Unit>
4550

4651
/** Stops the audio conversation with the Gemini Server. */
4752
public abstract fun stopAudioConversation(): ListenableFuture<Unit>
@@ -96,22 +101,11 @@ public abstract class LiveSessionFutures internal constructor() {
96101
outputModalities: List<ContentModality>
97102
): Publisher<LiveContentResponse>
98103

99-
/**
100-
* Receives all function call responses from the server for the audio conversation feature..
101-
*
102-
* @return A [Publisher] which will emit list of [FunctionCallPart] as they are returned by the
103-
* model.
104-
*/
105-
public abstract fun receiveAudioConversationFunctionCalls(): Publisher<List<FunctionCallPart>>
106-
107104
private class FuturesImpl(private val session: LiveSession) : LiveSessionFutures() {
108105

109106
override fun receive(outputModalities: List<ContentModality>): Publisher<LiveContentResponse> =
110107
session.receive(outputModalities).asPublisher()
111108

112-
override fun receiveAudioConversationFunctionCalls(): Publisher<List<FunctionCallPart>> =
113-
session.receiveAudioConversationFunctionCalls().asPublisher()
114-
115109
override fun close(): ListenableFuture<Unit> =
116110
SuspendToFutureAdapter.launchFuture { session.close() }
117111

@@ -126,8 +120,9 @@ public abstract class LiveSessionFutures internal constructor() {
126120
override fun sendMediaStream(mediaChunks: List<MediaData>) =
127121
SuspendToFutureAdapter.launchFuture { session.sendMediaStream(mediaChunks) }
128122

129-
override fun startAudioConversation() =
130-
SuspendToFutureAdapter.launchFuture { session.startAudioConversation() }
123+
override fun startAudioConversation(
124+
functionCallsHandler: ((List<FunctionCallPart>) -> List<FunctionResponsePart>)?
125+
) = SuspendToFutureAdapter.launchFuture { session.startAudioConversation(functionCallsHandler) }
131126

132127
override fun stopAudioConversation() =
133128
SuspendToFutureAdapter.launchFuture { session.stopAudioConversation() }

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import kotlin.coroutines.CoroutineContext
2929
import kotlinx.coroutines.CoroutineScope
3030
import kotlinx.coroutines.cancel
3131
import kotlinx.coroutines.channels.Channel
32-
import kotlinx.coroutines.delay
3332
import kotlinx.coroutines.flow.Flow
3433
import kotlinx.coroutines.flow.flow
3534
import kotlinx.coroutines.flow.receiveAsFlow
@@ -53,7 +52,6 @@ internal constructor(
5352
private val playBackQueue = ConcurrentLinkedQueue<ByteArray>()
5453
private var startedReceiving = false
5554
private var receiveChannel: Channel<Frame> = Channel()
56-
private var functionCallChannel: Channel<List<FunctionCallPart>> = Channel()
5755

5856
private companion object {
5957
val TAG = LiveSession::class.java.simpleName
@@ -127,16 +125,6 @@ internal constructor(
127125
}
128126
}
129127

130-
/**
131-
* Receives all function call responses from the server for the audio conversation feature. This
132-
* can be called only after calling [startAudioConversation] function.
133-
*
134-
* @return A [Flow] which will emit list of [FunctionCallPart] as they are returned by the model.
135-
*/
136-
public fun receiveAudioConversationFunctionCalls(): Flow<List<FunctionCallPart>> {
137-
return functionCallChannel.receiveAsFlow()
138-
}
139-
140128
private fun fillRecordedAudioQueue() {
141129
CoroutineScope(backgroundDispatcher).launch {
142130
audioHelper!!.startRecording().collect {
@@ -163,7 +151,9 @@ internal constructor(
163151
}
164152
}
165153

166-
private fun fillServerResponseAudioQueue() {
154+
private fun fillServerResponseAudioQueue(
155+
functionCallsHandler: ((List<FunctionCallPart>) -> List<FunctionResponsePart>)? = null
156+
) {
167157
CoroutineScope(backgroundDispatcher).launch {
168158
receive(listOf(ContentModality.AUDIO)).collect {
169159
if (!isRecording) {
@@ -173,8 +163,8 @@ internal constructor(
173163
LiveContentResponse.Status.INTERRUPTED ->
174164
while (!playBackQueue.isEmpty()) playBackQueue.poll()
175165
LiveContentResponse.Status.NORMAL ->
176-
if (!it.functionCalls.isNullOrEmpty()) {
177-
functionCallChannel.send(it.functionCalls)
166+
if (!it.functionCalls.isNullOrEmpty() && functionCallsHandler != null) {
167+
sendFunctionResponse(functionCallsHandler(it.functionCalls))
178168
} else {
179169
val audioData = it.data?.parts?.get(0)?.asInlineDataPartOrNull()?.inlineData
180170
if (audioData != null) {
@@ -198,22 +188,24 @@ internal constructor(
198188
/**
199189
* Starts an audio conversation with the Gemini server, which can only be stopped using
200190
* [stopAudioConversation].
191+
*
192+
* @param functionCallsHandler A callback function that is invoked whenever the server receives a
193+
* function call.
201194
*/
202-
public suspend fun startAudioConversation() {
195+
public suspend fun startAudioConversation(
196+
functionCallsHandler: ((List<FunctionCallPart>) -> List<FunctionResponsePart>)? = null
197+
) {
203198
if (isRecording) {
204199
Log.w(TAG, "startAudioConversation called after the recording has already started.")
205200
return
206201
}
207-
functionCallChannel = Channel()
208202
isRecording = true
209203
audioHelper = AudioHelper()
210204
audioHelper!!.setupAudioTrack()
211205
fillRecordedAudioQueue()
212206
CoroutineScope(backgroundDispatcher).launch { sendAudioDataToServer() }
213-
fillServerResponseAudioQueue()
207+
fillServerResponseAudioQueue(functionCallsHandler)
214208
playServerResponseAudio()
215-
// This delay is necessary to ensure that all threads have started.
216-
delay(1000)
217209
}
218210

219211
/**

0 commit comments

Comments
 (0)