diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
index 3d86a4c4b04c..6004ba09d6cd 100644
--- a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
+++ b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
@@ -57,6 +57,8 @@ class _BidiPageState extends State<BidiPage> {
   StreamController<bool> _stopController = StreamController<bool>();
   final AudioOutput _audioOutput = AudioOutput();
   final AudioInput _audioInput = AudioInput();
+  int? _inputTranscriptionMessageIndex;
+  int? _outputTranscriptionMessageIndex;
 
   @override
   void initState() {
@@ -67,6 +69,8 @@ class _BidiPageState extends State<BidiPage> {
       responseModalities: [
         ResponseModalities.audio,
       ],
+      inputAudioTranscription: AudioTranscriptionConfig(),
+      outputAudioTranscription: AudioTranscriptionConfig(),
     );
 
     // ignore: deprecated_member_use
@@ -131,14 +135,17 @@ class _BidiPageState extends State<BidiPage> {
         child: ListView.builder(
           controller: _scrollController,
           itemBuilder: (context, idx) {
+            final message = _messages[idx];
             return MessageWidget(
-              text: _messages[idx].text,
-              image: Image.memory(
-                _messages[idx].imageBytes!,
-                cacheWidth: 400,
-                cacheHeight: 400,
-              ),
-              isFromUser: _messages[idx].fromUser ?? false,
+              text: message.text,
+              image: message.imageBytes != null
+                  ? Image.memory(
+                      message.imageBytes!,
+                      cacheWidth: 400,
+                      cacheHeight: 400,
+                    )
+                  : null,
+              isFromUser: message.fromUser ?? false,
             );
           },
           itemCount: _messages.length,
@@ -354,6 +361,43 @@ class _BidiPageState extends State<BidiPage> {
     if (message.modelTurn != null) {
       await _handleLiveServerContent(message);
     }
+
+    if (message.inputTranscription?.text != null) {
+      final transcription = message.inputTranscription!;
+      if (_inputTranscriptionMessageIndex != null) {
+        // TODO(cynthia): find a better way to update the message
+        _messages[_inputTranscriptionMessageIndex!].text =
+            '${_messages[_inputTranscriptionMessageIndex!].text}${transcription.text!}';
+      } else {
+        _messages.add(MessageData(
+            text: 'Input transcription: ${transcription.text!}',
+            fromUser: true));
+        _inputTranscriptionMessageIndex = _messages.length - 1;
+      }
+      if (transcription.finished ?? false) {
+        _inputTranscriptionMessageIndex = null;
+      }
+      setState(_scrollDown);
+    }
+    if (message.outputTranscription?.text != null) {
+      final transcription = message.outputTranscription!;
+      if (_outputTranscriptionMessageIndex != null) {
+        _messages[_outputTranscriptionMessageIndex!].text =
+            '${_messages[_outputTranscriptionMessageIndex!].text}${transcription.text!}';
+      } else {
+        _messages.add(
+          MessageData(
+            text: 'Output transcription: ${transcription.text!}',
+            fromUser: false,
+          ),
+        );
+        _outputTranscriptionMessageIndex = _messages.length - 1;
+      }
+      if (transcription.finished ?? false) {
+        _outputTranscriptionMessageIndex = null;
+      }
+      setState(_scrollDown);
+    }
     if (message.interrupted != null && message.interrupted!) {
       developer.log('Interrupted: $response');
     }
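The page state above appends each partial `Transcription.text` fragment to a tracked message until `finished` arrives, then clears the index so the next fragment starts a new message. The same accumulation pattern as a small standalone sketch (`TranscriptBuffer` is a hypothetical helper, not part of this PR):

```dart
/// Accumulates streaming transcription fragments into a single string,
/// mirroring how bidi_page.dart grows one MessageData per transcription.
class TranscriptBuffer {
  final StringBuffer _buffer = StringBuffer();
  bool _done = false;

  /// Appends [fragment] and returns the accumulated text so far.
  /// Pass the server's Transcription.finished flag as [finished]; once it
  /// has been seen, the next fragment starts a fresh transcript.
  String append(String fragment, {bool finished = false}) {
    if (_done) {
      _buffer.clear();
      _done = false;
    }
    _buffer.write(fragment);
    _done = finished;
    return _buffer.toString();
  }
}
```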
diff --git a/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart b/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart
index 7ea588557ce3..03a0105419a1 100644
--- a/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart
+++ b/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart
@@ -23,7 +23,7 @@ class MessageData {
     this.isThought = false,
   });
   final Uint8List? imageBytes;
-  final String? text;
+  String? text;
   final bool? fromUser;
   final bool isThought;
 }
diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
index 0b6c4735caee..645807488b6d 100644
--- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
+++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
@@ -91,6 +91,7 @@ export 'src/live_api.dart'
     show
         LiveGenerationConfig,
         SpeechConfig,
+        AudioTranscriptionConfig,
         LiveServerMessage,
         LiveServerContent,
         LiveServerToolCall,
diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart
index c77d2037a993..ed4d423ea209 100644
--- a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart
+++ b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart
@@ -71,11 +71,19 @@ class SpeechConfig {
       };
 }
 
+/// The audio transcription configuration.
+class AudioTranscriptionConfig {
+  // ignore: public_member_api_docs
+  Map<String, Object?> toJson() => {};
+}
+
 /// Configures live generation settings.
 final class LiveGenerationConfig extends BaseGenerationConfig {
   // ignore: public_member_api_docs
   LiveGenerationConfig({
     this.speechConfig,
+    this.inputAudioTranscription,
+    this.outputAudioTranscription,
     super.responseModalities,
     super.maxOutputTokens,
     super.temperature,
@@ -88,6 +96,13 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
   /// The speech configuration.
   final SpeechConfig? speechConfig;
 
+  /// The transcription of the input aligns with the input audio language.
+  final AudioTranscriptionConfig? inputAudioTranscription;
+
+  /// The transcription of the output aligns with the language code specified
+  /// for the output audio.
+  final AudioTranscriptionConfig? outputAudioTranscription;
+
   @override
   Map<String, Object?> toJson() => {
         ...super.toJson(),
@@ -109,6 +124,18 @@ sealed class LiveServerMessage {}
 /// with the live server has finished successfully.
 class LiveServerSetupComplete implements LiveServerMessage {}
 
+/// Audio transcription message.
+class Transcription {
+  // ignore: public_member_api_docs
+  const Transcription({this.text, this.finished});
+
+  /// Transcription text.
+  final String? text;
+
+  /// Whether this is the end of the transcription.
+  final bool? finished;
+}
+
 /// Content generated by the model in a live stream.
 class LiveServerContent implements LiveServerMessage {
   /// Creates a [LiveServerContent] instance.
@@ -116,7 +143,14 @@ class LiveServerContent implements LiveServerMessage {
   ///
   /// [modelTurn] (optional): The content generated by the model.
   /// [turnComplete] (optional): Indicates if the turn is complete.
   /// [interrupted] (optional): Indicates if the generation was interrupted.
-  LiveServerContent({this.modelTurn, this.turnComplete, this.interrupted});
+  /// [inputTranscription] (optional): The input transcription.
+  /// [outputTranscription] (optional): The output transcription.
+  LiveServerContent(
+      {this.modelTurn,
+      this.turnComplete,
+      this.interrupted,
+      this.inputTranscription,
+      this.outputTranscription});
 
   // TODO(cynthia): Add accessor for media content
   /// The content generated by the model.
@@ -129,6 +163,18 @@ class LiveServerContent implements LiveServerMessage {
   /// Whether generation was interrupted. If true, indicates that a
   /// client message has interrupted current model
   final bool? interrupted;
+
+  /// The input transcription.
+  ///
+  /// The transcription is independent of the model turn, which means it
+  /// doesn't imply any ordering between the transcription and the model turn.
+  final Transcription? inputTranscription;
+
+  /// The output transcription.
+  ///
+  /// The transcription is independent of the model turn, which means it
+  /// doesn't imply any ordering between the transcription and the model turn.
+  final Transcription? outputTranscription;
 }
 
 /// A tool call in a live stream.
@@ -306,7 +352,32 @@ LiveServerMessage _parseServerMessage(Object jsonObject) {
     if (serverContentJson.containsKey('turnComplete')) {
       turnComplete = serverContentJson['turnComplete'] as bool;
     }
-    return LiveServerContent(modelTurn: modelTurn, turnComplete: turnComplete);
+    final interrupted = serverContentJson['interrupted'] as bool?;
+    Transcription? inputTranscription;
+    if (serverContentJson.containsKey('inputTranscription')) {
+      final transcriptionJson =
+          serverContentJson['inputTranscription'] as Map<String, dynamic>;
+      inputTranscription = Transcription(
+        text: transcriptionJson['text'] as String?,
+        finished: transcriptionJson['finished'] as bool?,
+      );
+    }
+    Transcription? outputTranscription;
+    if (serverContentJson.containsKey('outputTranscription')) {
+      final transcriptionJson =
+          serverContentJson['outputTranscription'] as Map<String, dynamic>;
+      outputTranscription = Transcription(
+        text: transcriptionJson['text'] as String?,
+        finished: transcriptionJson['finished'] as bool?,
+      );
+    }
+    return LiveServerContent(
+      modelTurn: modelTurn,
+      turnComplete: turnComplete,
+      interrupted: interrupted,
+      inputTranscription: inputTranscription,
+      outputTranscription: outputTranscription,
+    );
   } else if (json.containsKey('toolCall')) {
     final toolContentJson = json['toolCall'] as Map<String, dynamic>;
     List<FunctionCall> functionCalls = [];
diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart
index df7fbc8b3add..5a71e309dcd0 100644
--- a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart
+++ b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart
@@ -106,6 +106,14 @@ final class LiveGenerativeModel extends BaseModel {
           if (_systemInstruction != null)
             'system_instruction': _systemInstruction.toJson(),
           if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
+          if (_liveGenerationConfig != null &&
+              _liveGenerationConfig.inputAudioTranscription != null)
+            'input_audio_transcription':
+                _liveGenerationConfig.inputAudioTranscription!.toJson(),
+          if (_liveGenerationConfig != null &&
+              _liveGenerationConfig.outputAudioTranscription != null)
+            'output_audio_transcription':
+                _liveGenerationConfig.outputAudioTranscription!.toJson(),
         }
       };
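End to end: a caller opts in by passing `AudioTranscriptionConfig` for either direction, `live_model.dart` serializes those into the setup message as `input_audio_transcription` / `output_audio_transcription`, and `_parseServerMessage` surfaces the results as `Transcription` objects on `LiveServerContent`. A minimal consumption sketch; the `liveGenerativeModel` / `connect` / `receive` surface and the model name are assumptions drawn from the package's existing Live API, not part of this diff:

```dart
import 'dart:developer' as developer;

import 'package:firebase_ai/firebase_ai.dart';

Future<void> listenWithTranscription() async {
  final config = LiveGenerationConfig(
    responseModalities: [ResponseModalities.audio],
    // Empty marker configs; their presence turns transcription on.
    inputAudioTranscription: AudioTranscriptionConfig(),
    outputAudioTranscription: AudioTranscriptionConfig(),
  );

  // Assumed entry points from the existing Live API; model name illustrative.
  final model = FirebaseAI.vertexAI().liveGenerativeModel(
    model: 'gemini-2.0-flash-exp',
    liveGenerationConfig: config,
  );
  final session = await model.connect();

  await for (final response in session.receive()) {
    final message = response.message;
    if (message is LiveServerContent) {
      // Text arrives in fragments; Transcription.finished marks the last one.
      final input = message.inputTranscription;
      if (input?.text != null) {
        developer.log('user said: ${input!.text} (done: ${input.finished})');
      }
      final output = message.outputTranscription;
      if (output?.text != null) {
        developer.log('model said: ${output!.text} (done: ${output.finished})');
      }
    }
  }
}
```

Note the asymmetry the diff encodes: the setup keys sent to the server are snake_case (`input_audio_transcription`), while the server content keys parsed back are camelCase (`inputTranscription`).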