diff --git a/examples/travel_app/lib/main.dart b/examples/travel_app/lib/main.dart index 9901c0f9d..fe096eb23 100644 --- a/examples/travel_app/lib/main.dart +++ b/examples/travel_app/lib/main.dart @@ -2,18 +2,22 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import 'package:dart_schema_builder/dart_schema_builder.dart'; +import 'package:firebase_ai/firebase_ai.dart'; import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter/material.dart'; -import 'package:flutter_genui/flutter_genui.dart'; +import 'package:flutter_genui/flutter_genui.dart' hide ChatMessage, TextPart; import 'package:logging/logging.dart'; import 'firebase_options.dart'; import 'src/asset_images.dart'; import 'src/catalog.dart'; +import 'src/gemini_client.dart'; +import 'src/turn.dart'; import 'src/widgets/conversation.dart'; +final _logger = Logger('TravelApp'); + void main() async { WidgetsFlutterBinding.ensureInitialized(); await Firebase.initializeApp(options: DefaultFirebaseOptions.currentPlatform); @@ -24,9 +28,22 @@ void main() async { ); _imagesJson = await assetImageCatalogJson(); configureGenUiLogging(level: Level.ALL); + _configureLogging(); runApp(const TravelApp()); } +void _configureLogging() { + hierarchicalLoggingEnabled = true; + Logger.root.level = Level.ALL; + Logger.root.onRecord.listen((record) { + // ignore: avoid_print + print( + '[${record.level.name}] ${record.time}: ' + '${record.loggerName}: ${record.message}', + ); + }); +} + /// The root widget for the travel application. /// /// This widget sets up the [MaterialApp], which configures the overall theme, @@ -34,16 +51,7 @@ void main() async { /// user interface. class TravelApp extends StatelessWidget { /// Creates a new [TravelApp]. - /// - /// The optional [aiClient] can be used to inject a specific AI client, - /// which is useful for testing with a mock implementation. - const TravelApp({this.aiClient, super.key}); - - /// The AI client to use for the application. - /// - /// If null, a default [FirebaseAiClient] will be created by the - /// [TravelPlannerPage]. - final AiClient? aiClient; + const TravelApp({super.key}); @override Widget build(BuildContext context) { @@ -53,7 +61,7 @@ class TravelApp extends StatelessWidget { theme: ThemeData( colorScheme: ColorScheme.fromSeed(seedColor: Colors.blue), ), - home: TravelPlannerPage(aiClient: aiClient), + home: const TravelPlannerPage(), ); } } @@ -70,17 +78,7 @@ class TravelApp extends StatelessWidget { /// generated UI, and a menu to switch between different AI models. class TravelPlannerPage extends StatefulWidget { /// Creates a new [TravelPlannerPage]. - /// - /// An optional [aiClient] can be provided, which is useful for testing - /// or using a custom AI client implementation. If not provided, a default - /// [FirebaseAiClient] is created. - const TravelPlannerPage({this.aiClient, super.key}); - - /// The AI client to use for the application. - /// - /// If null, a default instance of [FirebaseAiClient] will be created within - /// the page's state. - final AiClient? aiClient; + const TravelPlannerPage({super.key}); @override State createState() => _TravelPlannerPageState(); @@ -88,9 +86,9 @@ class TravelPlannerPage extends StatefulWidget { class _TravelPlannerPageState extends State { late final GenUiManager _genUiManager; - late final AiClient _aiClient; + late final GeminiClient _geminiClient; late final UiEventManager _eventManager; - final List _conversation = []; + final List _conversation = []; final _textController = TextEditingController(); final _scrollController = ScrollController(); bool _isThinking = false; @@ -109,31 +107,29 @@ class _TravelPlannerPageState extends State { ), ); _eventManager = UiEventManager(callback: _onUiEvents); - _aiClient = - widget.aiClient ?? - FirebaseAiClient( - tools: _genUiManager.getTools(), - systemInstruction: prompt, - ); + _geminiClient = GeminiClient( + tools: _genUiManager.getTools(), + systemInstruction: prompt, + ); _genUiManager.surfaceUpdates.listen((update) { setState(() { switch (update) { case SurfaceAdded(:final surfaceId, :final definition): _conversation.add( - AiUiMessage(definition: definition, surfaceId: surfaceId), + GenUiTurn(definition: definition, surfaceId: surfaceId), ); _scrollToBottom(); case SurfaceRemoved(:final surfaceId): _conversation.removeWhere( - (m) => m is AiUiMessage && m.surfaceId == surfaceId, + (m) => m is GenUiTurn && m.surfaceId == surfaceId, ); case SurfaceUpdated(:final surfaceId, :final definition): final index = _conversation.lastIndexWhere( - (m) => m is AiUiMessage && m.surfaceId == surfaceId, + (m) => m is GenUiTurn && m.surfaceId == surfaceId, ); if (index != -1) { - _conversation[index] = AiUiMessage( + _conversation[index] = GenUiTurn( definition: definition, surfaceId: surfaceId, ); @@ -169,32 +165,18 @@ class _TravelPlannerPageState extends State { _isThinking = true; }); try { - final result = await _aiClient.generateContent( - _conversation, - S.object( - properties: { - 'result': S.boolean( - description: 'Successfully generated a response UI.', - ), - 'message': S.string( - description: - 'A message about what went wrong, or a message responding to ' - 'the request. Take into account any UI that has been ' - "generated, so there's no need to duplicate requests or " - 'information already present in the UI.', - ), - }, - required: ['result'], - ), - ); - if (result == null) { - return; - } - final value = - (result as Map).cast()['message'] as String? ?? ''; + final contentHistory = _conversation + .map((turn) => turn.toContent()) + .whereType() + .toList(); + final result = await _geminiClient.generate(contentHistory); + final value = result.candidates.first.content.parts + .whereType() + .map((part) => part.text) + .join(''); if (value.isNotEmpty) { setState(() { - _conversation.add(AiTextMessage.text(value)); + _conversation.add(AiTextTurn(value)); }); _scrollToBottom(); } @@ -227,7 +209,7 @@ class _TravelPlannerPageState extends State { } setState(() { - _conversation.add(UserUiInteractionMessage.text(message.toString())); + _conversation.add(UserUiInteractionTurn(message.toString())); }); _scrollToBottom(); _triggerInference(); @@ -240,7 +222,7 @@ class _TravelPlannerPageState extends State { void _sendPrompt(String text) { if (_isThinking || text.trim().isEmpty) return; setState(() { - _conversation.add(UserMessage.text(text)); + _conversation.add(UserTurn(text)); }); _scrollToBottom(); _textController.clear(); @@ -386,6 +368,7 @@ to the user. 3. Create an initial itinerary, which will be iterated over in subsequent steps. This involves planning out each day of the trip, including the specific locations and draft activities. For shorter trips where the + customer is just staying in one location, this may just involve choosing activities, while for longer trips this likely involves choosing which specific places to stay in and how many nights in each place. @@ -447,10 +430,10 @@ because it avoids confusing the conversation with many versions of the same itinerary etc. When processing a user message or event, you should add or update one surface -and then call provideFinalOutput to return control to the user. Never continue -to add or update surfaces until you receive another user event. If the last -entry in the context is a functionResponse, just call provideFinalOutput -immediately - don't try to update the UI. +and then output an explanatory message to return control to the user. Never +continue to add or update surfaces until you receive another user event. +If the last entry in the context is a functionResponse from addOrUpdateSurface, +*do not* call addOrUpdateSurface again - just return. # UI style diff --git a/examples/travel_app/lib/src/catalog/text_input_chip.dart b/examples/travel_app/lib/src/catalog/text_input_chip.dart index a4deba394..8e5d4aa26 100644 --- a/examples/travel_app/lib/src/catalog/text_input_chip.dart +++ b/examples/travel_app/lib/src/catalog/text_input_chip.dart @@ -8,8 +8,8 @@ import 'package:flutter_genui/flutter_genui.dart'; final _schema = S.object( description: - 'An input chip used to ask the user to enter free text, e.g. to ' - 'select a destination. This should only be used inside an InputGroup.', + 'An input chip where the user enters free text, e.g. to ' + 'select a destination. This must only be used inside an InputGroup.', properties: { 'label': S.string(description: 'The label for the text input chip.'), 'initialValue': S.string( diff --git a/examples/travel_app/lib/src/gemini_client.dart b/examples/travel_app/lib/src/gemini_client.dart new file mode 100644 index 000000000..7e93465c6 --- /dev/null +++ b/examples/travel_app/lib/src/gemini_client.dart @@ -0,0 +1,109 @@ +// Copyright 2025 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'dart:convert'; + +import 'package:firebase_ai/firebase_ai.dart' as fai; +import 'package:flutter_genui/flutter_genui.dart'; +import 'package:flutter_genui/src/ai_client/gemini_schema_adapter.dart'; +import 'package:logging/logging.dart'; + +class GeminiClient { + GeminiClient({required this.tools, required String systemInstruction}) { + final functionDeclarations = []; + final adapter = GeminiSchemaAdapter(); + for (final tool in tools) { + fai.Schema? adaptedParameters; + if (tool.parameters != null) { + final result = adapter.adapt(tool.parameters!); + if (result.errors.isNotEmpty) { + _logger.warning( + 'Errors adapting parameters for tool ${tool.name}: ' + '${result.errors.join('\n')}', + ); + } + adaptedParameters = result.schema; + } + final parameters = adaptedParameters?.properties; + functionDeclarations.add( + fai.FunctionDeclaration( + tool.name, + tool.description, + parameters: parameters ?? const {}, + ), + ); + } + + _logger.info( + 'Registered tools: ${functionDeclarations.map((d) => d.toJson()).join(', ')}', + ); + + _model = fai.FirebaseAI.googleAI().generativeModel( + model: 'gemini-2.5-flash', + systemInstruction: fai.Content.system(systemInstruction), + tools: [fai.Tool.functionDeclarations(functionDeclarations)], + ); + } + + late final fai.GenerativeModel _model; + final List tools; + final _logger = Logger('GeminiClient'); + + Future generate( + Iterable history, + ) async { + final mutableHistory = List.of(history); + var toolUsageCycle = 0; + const maxToolUsageCycles = 10; + + while (toolUsageCycle < maxToolUsageCycles) { + toolUsageCycle++; + + final concatenatedContents = mutableHistory + .map((c) => const JsonEncoder.withIndent(' ').convert(c.toJson())) + .join('\n'); + + _logger.info( + '****** Performing Inference ******\n$concatenatedContents\n' + 'With functions:\n' + ' ${tools.map((t) => t.name).join(', ')}', + ); + + final inferenceStartTime = DateTime.now(); + final response = await _model.generateContent(mutableHistory); + final elapsed = DateTime.now().difference(inferenceStartTime); + + final candidate = response.candidates.first; + final content = candidate.content; + mutableHistory.add(content); + + _logger.info( + '****** Completed Inference ******\n' + 'Latency = ${elapsed.inMilliseconds}ms\n' + 'Output tokens = ${response.usageMetadata?.candidatesTokenCount ?? 0}\n' + 'Prompt tokens = ${response.usageMetadata?.promptTokenCount ?? 0}\n' + '${const JsonEncoder.withIndent(' ').convert(content.toJson())}', + ); + + final functionCalls = content.parts + .whereType() + .toList(); + + if (functionCalls.isEmpty) { + return response; + } + + final functionResponses = []; + for (final call in functionCalls) { + final tool = tools.firstWhere((t) => t.name == call.name); + final result = await tool.invoke(call.args); + functionResponses.add(fai.FunctionResponse(call.name, result)); + } + + mutableHistory.add(fai.Content.functionResponses(functionResponses)); + } + + throw Exception('Max tool usage cycles reached'); + } +} diff --git a/examples/travel_app/lib/src/turn.dart b/examples/travel_app/lib/src/turn.dart new file mode 100644 index 000000000..f354c78ca --- /dev/null +++ b/examples/travel_app/lib/src/turn.dart @@ -0,0 +1,58 @@ +// Copyright 2025 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:firebase_ai/firebase_ai.dart' as firebase_ai; +import 'package:flutter_genui/flutter_genui.dart'; + +sealed class Turn { + const Turn(); + + firebase_ai.Content? toContent(); +} + +class UserTurn extends Turn { + final String text; + + const UserTurn(this.text); + + @override + firebase_ai.Content toContent() { + return firebase_ai.Content('user', [firebase_ai.TextPart(text)]); + } +} + +class UserUiInteractionTurn extends Turn { + final String text; + + const UserUiInteractionTurn(this.text); + + @override + firebase_ai.Content toContent() { + return firebase_ai.Content('user', [firebase_ai.TextPart(text)]); + } +} + +class AiTextTurn extends Turn { + final String text; + + const AiTextTurn(this.text); + + @override + firebase_ai.Content toContent() { + return firebase_ai.Content.model([firebase_ai.TextPart(text)]); + } +} + +class GenUiTurn extends Turn { + final String surfaceId; + final UiDefinition definition; + + GenUiTurn({required this.surfaceId, required this.definition}); + + @override + firebase_ai.Content? toContent() { + final text = definition.asContextDescriptionText(); + return firebase_ai.Content.model([firebase_ai.TextPart(text)]); + } +} diff --git a/packages/flutter_genui/lib/src/core/widgets/chat_primitives.dart b/examples/travel_app/lib/src/widgets/chat_message.dart similarity index 100% rename from packages/flutter_genui/lib/src/core/widgets/chat_primitives.dart rename to examples/travel_app/lib/src/widgets/chat_message.dart diff --git a/examples/travel_app/lib/src/widgets/conversation.dart b/examples/travel_app/lib/src/widgets/conversation.dart index 19ad89672..b316b1132 100644 --- a/examples/travel_app/lib/src/widgets/conversation.dart +++ b/examples/travel_app/lib/src/widgets/conversation.dart @@ -3,92 +3,62 @@ // found in the LICENSE file. import 'package:flutter/material.dart'; - import 'package:flutter_genui/flutter_genui.dart'; -typedef UserPromptBuilder = - Widget Function(BuildContext context, UserMessage message); +import '../turn.dart'; +import 'chat_message.dart'; +/// A widget that displays a conversation between a user and an AI. class Conversation extends StatelessWidget { + /// Creates a new [Conversation] widget. const Conversation({ - super.key, required this.messages, required this.manager, required this.onEvent, - this.userPromptBuilder, - this.showInternalMessages = false, this.scrollController, + super.key, }); - final List messages; - final UiEventCallback onEvent; + /// The list of messages in the conversation. + final List messages; + + /// The [GenUiManager] that manages the UI surfaces. final GenUiManager manager; - final UserPromptBuilder? userPromptBuilder; - final bool showInternalMessages; + + /// A callback that is called when a UI event occurs. + final void Function(UiEvent) onEvent; + + /// The scroll controller for the conversation view. final ScrollController? scrollController; @override Widget build(BuildContext context) { - final renderedMessages = messages.where((message) { - if (showInternalMessages) { - return true; - } - return message is! InternalMessage && - message is! ToolResponseMessage && - message is! UserUiInteractionMessage; - }).toList(); return ListView.builder( controller: scrollController, - itemCount: renderedMessages.length, + itemCount: messages.length, itemBuilder: (context, index) { - final message = renderedMessages[index]; - switch (message) { - case UserMessage(): - return userPromptBuilder != null - ? userPromptBuilder!(context, message) - : ChatMessageWidget( - text: message.parts - .whereType() - .map((part) => part.text) - .join('\n'), - icon: Icons.person, - alignment: MainAxisAlignment.end, - ); - case AiTextMessage(): - final text = message.parts - .whereType() - .map((part) => part.text) - .join('\n'); - if (text.trim().isEmpty) { - return const SizedBox.shrink(); - } - return ChatMessageWidget( - text: text, - icon: Icons.smart_toy_outlined, - alignment: MainAxisAlignment.start, - ); - case AiUiMessage(): - return Padding( - padding: const EdgeInsets.all(16.0), - child: GenUiSurface( - key: message.uiKey, - host: manager, - surfaceId: message.surfaceId, - onEvent: onEvent, - ), - ); - case InternalMessage(): - return InternalMessageWidget(content: message.text); - case UserUiInteractionMessage(): - return InternalMessageWidget( - content: message.parts - .whereType() - .map((part) => part.text) - .join('\n'), - ); - case ToolResponseMessage(): - return InternalMessageWidget(content: message.results.toString()); - } + final message = messages[index]; + return switch (message) { + UserTurn() => ChatMessageWidget( + text: message.text, + icon: Icons.person, + alignment: MainAxisAlignment.end, + ), + UserUiInteractionTurn() => const SizedBox.shrink(), + GenUiTurn() => Padding( + padding: const EdgeInsets.all(16.0), + child: GenUiSurface( + surfaceId: message.surfaceId, + host: manager, + onEvent: onEvent, + ), + ), + AiTextTurn() => ChatMessageWidget( + text: message.text, + icon: Icons.auto_awesome, + alignment: MainAxisAlignment.start, + ), + }; }, ); } diff --git a/examples/travel_app/test/main_test.dart b/examples/travel_app/test/main_test.dart index 14e7f38dc..444171433 100644 --- a/examples/travel_app/test/main_test.dart +++ b/examples/travel_app/test/main_test.dart @@ -15,7 +15,7 @@ void main() { final mockAiClient = FakeAiClient(); // The main app expects a JSON response from generateContent. mockAiClient.response = {'result': true}; - await tester.pumpWidget(app.TravelApp(aiClient: mockAiClient)); + await tester.pumpWidget(const app.TravelApp()); await tester.enterText(find.byType(TextField), 'test prompt'); await tester.tap(find.byIcon(Icons.send)); @@ -37,7 +37,7 @@ void main() { final mockAiClient = FakeAiClient(); final completer = Completer(); mockAiClient.generateContentFuture = completer.future; - await tester.pumpWidget(app.TravelApp(aiClient: mockAiClient)); + await tester.pumpWidget(const app.TravelApp()); await tester.enterText(find.byType(TextField), 'test prompt'); await tester.tap(find.byIcon(Icons.send)); diff --git a/examples/travel_app/test/widgets/conversation_test.dart b/examples/travel_app/test/widgets/conversation_test.dart index 59cb6d2a8..dea26a3a3 100644 --- a/examples/travel_app/test/widgets/conversation_test.dart +++ b/examples/travel_app/test/widgets/conversation_test.dart @@ -5,6 +5,7 @@ import 'package:flutter/material.dart'; import 'package:flutter_genui/flutter_genui.dart'; import 'package:flutter_test/flutter_test.dart'; +import 'package:travel_app/src/turn.dart'; import 'package:travel_app/src/widgets/conversation.dart'; void main() { @@ -16,27 +17,28 @@ void main() { }); testWidgets('renders a list of messages', (WidgetTester tester) async { + final definition = UiDefinition.fromMap({ + 'surfaceId': 's1', + 'root': 'r1', + 'widgets': [ + { + 'id': 'r1', + 'widget': { + 'Text': {'text': 'Hi there!'}, + }, + }, + ], + }); final messages = [ - UserMessage.text('Hello'), - AiUiMessage( + const UserTurn('Hello'), + GenUiTurn( surfaceId: 's1', - definition: UiDefinition.fromMap({ - 'surfaceId': 's1', - 'root': 'r1', - 'widgets': [ - { - 'id': 'r1', - 'widget': { - 'Text': {'text': 'Hi there!'}, - }, - }, - ], - }), + definition: definition, ), ]; manager.addOrUpdateSurface( 's1', - (messages[1] as AiUiMessage).definition.toMap(), + definition.toMap(), ); await tester.pumpWidget( @@ -56,7 +58,7 @@ void main() { }); testWidgets('renders UserPrompt correctly', (WidgetTester tester) async { final messages = [ - const UserMessage([TextPart('Hello')]), + const UserTurn('Hello'), ]; await tester.pumpWidget( MaterialApp( @@ -74,24 +76,25 @@ void main() { }); testWidgets('renders UiResponse correctly', (WidgetTester tester) async { + final definition = UiDefinition.fromMap({ + 'surfaceId': 's1', + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'Text': {'text': 'UI Content'}, + }, + }, + ], + }); final messages = [ - AiUiMessage( + GenUiTurn( surfaceId: 's1', - definition: UiDefinition.fromMap({ - 'surfaceId': 's1', - 'root': 'root', - 'widgets': [ - { - 'id': 'root', - 'widget': { - 'Text': {'text': 'UI Content'}, - }, - }, - ], - }), + definition: definition, ), ]; - manager.addOrUpdateSurface('s1', messages[0].definition.toMap()); + manager.addOrUpdateSurface('s1', definition.toMap()); await tester.pumpWidget( MaterialApp( home: Scaffold( @@ -106,26 +109,5 @@ void main() { expect(find.byType(GenUiSurface), findsOneWidget); expect(find.text('UI Content'), findsOneWidget); }); - - testWidgets('uses custom userPromptBuilder', (WidgetTester tester) async { - final messages = [ - const UserMessage([TextPart('Hello')]), - ]; - await tester.pumpWidget( - MaterialApp( - home: Scaffold( - body: Conversation( - messages: messages, - manager: manager, - onEvent: (_) {}, - userPromptBuilder: (context, message) => - const Text('Custom User Prompt'), - ), - ), - ), - ); - expect(find.text('Custom User Prompt'), findsOneWidget); - expect(find.text('Hello'), findsNothing); - }); }); } diff --git a/packages/flutter_genui/lib/flutter_genui.dart b/packages/flutter_genui/lib/flutter_genui.dart index 70a4a0351..5d9402a01 100644 --- a/packages/flutter_genui/lib/flutter_genui.dart +++ b/packages/flutter_genui/lib/flutter_genui.dart @@ -4,6 +4,7 @@ export 'src/ai_client/ai_client.dart'; export 'src/ai_client/firebase_ai_client.dart'; +export 'src/ai_client/gemini_schema_adapter.dart'; export 'src/catalog/core_widgets/checkbox_group.dart'; export 'src/catalog/core_widgets/column.dart'; export 'src/catalog/core_widgets/elevated_button.dart'; @@ -15,12 +16,12 @@ export 'src/core/core_catalog.dart'; export 'src/core/genui_configuration.dart'; export 'src/core/genui_manager.dart'; export 'src/core/genui_surface.dart'; -export 'src/core/widgets/chat_primitives.dart'; export 'src/facade/ui_agent.dart'; export 'src/model/catalog.dart'; export 'src/model/catalog_item.dart'; export 'src/model/chat_box.dart'; export 'src/model/chat_message.dart'; +export 'src/model/tools.dart'; export 'src/model/ui_event_manager.dart'; export 'src/model/ui_models.dart'; export 'src/primitives/logging.dart';