58 changes: 51 additions & 7 deletions packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
@@ -57,6 +57,8 @@ class _BidiPageState extends State<BidiPage> {
StreamController<bool> _stopController = StreamController<bool>();
final AudioOutput _audioOutput = AudioOutput();
final AudioInput _audioInput = AudioInput();
int? _inputTranscriptionMessageIndex;
int? _outputTranscriptionMessageIndex;

@override
void initState() {
@@ -67,6 +69,8 @@ class _BidiPageState extends State<BidiPage> {
responseModalities: [
ResponseModalities.audio,
],
inputAudioTranscription: AudioTranscriptionConfig(),
outputAudioTranscription: AudioTranscriptionConfig(),
);

// ignore: deprecated_member_use
@@ -131,14 +135,17 @@ class _BidiPageState extends State<BidiPage> {
child: ListView.builder(
controller: _scrollController,
itemBuilder: (context, idx) {
final message = _messages[idx];
return MessageWidget(
text: _messages[idx].text,
image: Image.memory(
_messages[idx].imageBytes!,
cacheWidth: 400,
cacheHeight: 400,
),
isFromUser: _messages[idx].fromUser ?? false,
text: message.text,
image: message.imageBytes != null
? Image.memory(
message.imageBytes!,
cacheWidth: 400,
cacheHeight: 400,
)
: null,
isFromUser: message.fromUser ?? false,
);
},
itemCount: _messages.length,
@@ -354,6 +361,43 @@ class _BidiPageState extends State<BidiPage> {
if (message.modelTurn != null) {
await _handleLiveServerContent(message);
}

if (message.inputTranscription?.text != null) {
final transcription = message.inputTranscription!;
if (_inputTranscriptionMessageIndex != null) {
// TODO(cynthia): find a better way to update the message
_messages[_inputTranscriptionMessageIndex!].text =
'${_messages[_inputTranscriptionMessageIndex!].text}${transcription.text!}';
} else {
_messages.add(MessageData(
text: 'Input transcription: ${transcription.text!}',
fromUser: true));
_inputTranscriptionMessageIndex = _messages.length - 1;
}
if (transcription.finished ?? false) {
_inputTranscriptionMessageIndex = null;
}
setState(_scrollDown);
}
if (message.outputTranscription?.text != null) {
final transcription = message.outputTranscription!;
if (_outputTranscriptionMessageIndex != null) {
_messages[_outputTranscriptionMessageIndex!].text =
'${_messages[_outputTranscriptionMessageIndex!].text}${transcription.text!}';
} else {
_messages.add(
MessageData(
text: 'Output transcription: ${transcription.text!}',
fromUser: false,
),
);
_outputTranscriptionMessageIndex = _messages.length - 1;
}
if (transcription.finished ?? false) {
_outputTranscriptionMessageIndex = null;
}
setState(_scrollDown);
}
if (message.interrupted != null && message.interrupted!) {
developer.log('Interrupted: $response');
}
@@ -23,7 +23,7 @@ class MessageData {
this.isThought = false,
});
final Uint8List? imageBytes;
final String? text;
String? text;
final bool? fromUser;
final bool isThought;
}
1 change: 1 addition & 0 deletions packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
@@ -91,6 +91,7 @@ export 'src/live_api.dart'
show
LiveGenerationConfig,
SpeechConfig,
AudioTranscriptionConfig,
LiveServerMessage,
LiveServerContent,
LiveServerToolCall,
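With `AudioTranscriptionConfig` exported from the package entry point, enabling transcription from app code mirrors the example-app change above. A minimal sketch using only names visible in this diff; how the config is passed to a `LiveGenerativeModel` is not shown here:

```dart
import 'package:firebase_ai/firebase_ai.dart';

// Builds a live config that requests audio responses plus transcription of
// both the user's input audio and the model's output audio.
LiveGenerationConfig buildTranscribingLiveConfig() {
  return LiveGenerationConfig(
    responseModalities: [ResponseModalities.audio],
    // Empty config objects; their presence is what turns transcription on.
    inputAudioTranscription: AudioTranscriptionConfig(),
    outputAudioTranscription: AudioTranscriptionConfig(),
  );
}
```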
75 changes: 73 additions & 2 deletions packages/firebase_ai/firebase_ai/lib/src/live_api.dart
@@ -71,11 +71,19 @@ class SpeechConfig {
};
}

/// The audio transcription configuration.
class AudioTranscriptionConfig {
// ignore: public_member_api_docs
Map<String, Object?> toJson() => {};
}

/// Configures live generation settings.
final class LiveGenerationConfig extends BaseGenerationConfig {
// ignore: public_member_api_docs
LiveGenerationConfig({
this.speechConfig,
this.inputAudioTranscription,
this.outputAudioTranscription,
super.responseModalities,
super.maxOutputTokens,
super.temperature,
@@ -88,6 +96,13 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
/// The speech configuration.
final SpeechConfig? speechConfig;

/// Enables transcription of the input audio. The transcription aligns with
/// the language of the input audio.
final AudioTranscriptionConfig? inputAudioTranscription;

/// Enables transcription of the output audio. The transcription aligns with
/// the language code specified for the output audio.
final AudioTranscriptionConfig? outputAudioTranscription;

@override
Map<String, Object?> toJson() => {
...super.toJson(),
@@ -109,14 +124,33 @@ sealed class LiveServerMessage {}
/// with the live server has finished successfully.
class LiveServerSetupComplete implements LiveServerMessage {}

/// Audio transcription message.
class Transcription {
// ignore: public_member_api_docs
const Transcription({this.text, this.finished});

/// Transcription text.
final String? text;

/// Whether this is the end of the transcription.
final bool? finished;
}

/// Content generated by the model in a live stream.
class LiveServerContent implements LiveServerMessage {
/// Creates a [LiveServerContent] instance.
///
/// [modelTurn] (optional): The content generated by the model.
/// [turnComplete] (optional): Indicates if the turn is complete.
/// [interrupted] (optional): Indicates if the generation was interrupted.
LiveServerContent({this.modelTurn, this.turnComplete, this.interrupted});
/// [inputTranscription] (optional): The input transcription.
/// [outputTranscription] (optional): The output transcription.
LiveServerContent(
{this.modelTurn,
this.turnComplete,
this.interrupted,
this.inputTranscription,
this.outputTranscription});

// TODO(cynthia): Add accessor for media content
/// The content generated by the model.
@@ -129,6 +163,18 @@ class LiveServerContent implements LiveServerMessage {
/// Whether generation was interrupted. If true, indicates that a
/// client message has interrupted the current model generation.
final bool? interrupted;

/// The input transcription.
///
/// The transcription is independent of the model turn; it does not imply
/// any ordering between the transcription and the model turn.
final Transcription? inputTranscription;

/// The output transcription.
///
/// The transcription is independent of the model turn; it does not imply
/// any ordering between the transcription and the model turn.
final Transcription? outputTranscription;
}

/// A tool call in a live stream.
@@ -306,7 +352,32 @@ LiveServerMessage _parseServerMessage(Object jsonObject) {
if (serverContentJson.containsKey('turnComplete')) {
turnComplete = serverContentJson['turnComplete'] as bool;
}
return LiveServerContent(modelTurn: modelTurn, turnComplete: turnComplete);
final interrupted = serverContentJson['interrupted'] as bool?;
Transcription? inputTranscription;
if (serverContentJson.containsKey('inputTranscription')) {
final transcriptionJson =
serverContentJson['inputTranscription'] as Map<String, dynamic>;
inputTranscription = Transcription(
text: transcriptionJson['text'] as String?,
finished: transcriptionJson['finished'] as bool?,
);
}
Transcription? outputTranscription;
if (serverContentJson.containsKey('outputTranscription')) {
final transcriptionJson =
serverContentJson['outputTranscription'] as Map<String, dynamic>;
outputTranscription = Transcription(
text: transcriptionJson['text'] as String?,
finished: transcriptionJson['finished'] as bool?,
);
}
return LiveServerContent(
modelTurn: modelTurn,
turnComplete: turnComplete,
interrupted: interrupted,
inputTranscription: inputTranscription,
outputTranscription: outputTranscription,
);
} else if (json.containsKey('toolCall')) {
final toolContentJson = json['toolCall'] as Map<String, dynamic>;
List<FunctionCall> functionCalls = [];
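The example-app change above concatenates transcription text until `finished` is reported, which suggests the server streams transcripts in incremental chunks. A minimal sketch of collecting a complete transcript under that assumption; obtaining the `LiveServerContent` messages from the live session is left to the caller:

```dart
import 'package:firebase_ai/firebase_ai.dart';

/// Accumulates streamed input-transcription chunks into complete transcripts.
class TranscriptCollector {
  final StringBuffer _buffer = StringBuffer();

  /// Feeds one server message; returns the full transcript once the server
  /// marks the transcription as finished, otherwise returns null.
  String? addInputTranscription(LiveServerContent content) {
    final transcription = content.inputTranscription;
    if (transcription == null) return null;
    if (transcription.text != null) {
      _buffer.write(transcription.text);
    }
    if (transcription.finished ?? false) {
      final full = _buffer.toString();
      _buffer.clear();
      return full;
    }
    return null;
  }
}
```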
8 changes: 8 additions & 0 deletions packages/firebase_ai/firebase_ai/lib/src/live_model.dart
@@ -106,6 +106,14 @@ final class LiveGenerativeModel extends BaseModel {
if (_systemInstruction != null)
'system_instruction': _systemInstruction.toJson(),
if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
if (_liveGenerationConfig != null &&
_liveGenerationConfig.inputAudioTranscription != null)
'input_audio_transcription':
_liveGenerationConfig.inputAudioTranscription!.toJson(),
if (_liveGenerationConfig != null &&
_liveGenerationConfig.outputAudioTranscription != null)
'output_audio_transcription':
_liveGenerationConfig.outputAudioTranscription!.toJson(),
}
};

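For reference, the two new setup keys are only attached when the corresponding config is set, and `AudioTranscriptionConfig.toJson()` currently serializes to an empty map. A sketch of just those fields, mirroring the conditional entries added in live_model.dart above; the surrounding setup structure is omitted and is not part of this diff:

```dart
import 'package:firebase_ai/firebase_ai.dart';

// Returns only the transcription-related setup fields for a given config.
Map<String, Object?> transcriptionSetupFields(LiveGenerationConfig config) {
  return {
    if (config.inputAudioTranscription != null)
      'input_audio_transcription': config.inputAudioTranscription!.toJson(),
    if (config.outputAudioTranscription != null)
      'output_audio_transcription': config.outputAudioTranscription!.toJson(),
  };
}
```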