diff --git a/examples/minimal_genui/lib/main.dart b/examples/minimal_genui/lib/main.dart index 18ab40f19..9a64a71fd 100644 --- a/examples/minimal_genui/lib/main.dart +++ b/examples/minimal_genui/lib/main.dart @@ -81,7 +81,7 @@ class _MyHomePageState extends State { Future _triggerInference() async { _chatController.setAiRequestSent(); try { - final response = await _aiClient.generateText( + final response = await _aiClient.generateText(conversation: _conversation); List.of(_chatController.conversation.value), ); _chatController.addMessage(AssistantMessage.text(response)); diff --git a/examples/travel_app/integration_test/app_test.dart b/examples/travel_app/integration_test/app_test.dart index 16c6fccc0..a214368cf 100644 --- a/examples/travel_app/integration_test/app_test.dart +++ b/examples/travel_app/integration_test/app_test.dart @@ -16,7 +16,7 @@ void main() { final mockAiClient = FakeAiClient(); mockAiClient.response = _baliResponse; - runApp(app.TravelApp(aiClient: mockAiClient)); + runApp(app.TravelApp(aiClient: fakeAiClient)); await tester.pumpAndSettle(); await tester.enterText(find.byType(EditableText), 'Plan a trip to Bali'); await tester.tap(find.byIcon(Icons.send)); diff --git a/examples/travel_app/lib/main.dart b/examples/travel_app/lib/main.dart index dad43fd9c..0763d076f 100644 --- a/examples/travel_app/lib/main.dart +++ b/examples/travel_app/lib/main.dart @@ -46,7 +46,7 @@ class TravelApp extends StatelessWidget { /// /// If null, a default [GeminiAiClient] will be created by the /// [TravelPlannerPage]. - final AiClient? aiClient; + final GeminiAiClient? aiClient; @override Widget build(BuildContext context) { @@ -83,7 +83,7 @@ class TravelPlannerPage extends StatefulWidget { /// /// If null, a default instance of [GeminiAiClient] will be created within /// the page's state. - final AiClient? aiClient; + final GeminiAiClient? aiClient; @override State createState() => _TravelPlannerPageState(); @@ -91,7 +91,7 @@ class TravelPlannerPage extends StatefulWidget { class _TravelPlannerPageState extends State { late final GenUiManager _genUiManager; - late final AiClient _aiClient; + late final GeminiAiClient _aiClient; late final UiEventManager _eventManager; final List _conversation = []; @@ -109,7 +109,11 @@ class _TravelPlannerPageState extends State { _genUiManager.updates.listen((update) { setState(() { switch (update) { - case SurfaceAdded(:final surfaceId, :final definition): + case SurfaceAdded( + :final surfaceId, + :final definition, + :final controller, + ): _conversation.add( UiResponseMessage( definition: { @@ -150,7 +154,6 @@ class _TravelPlannerPageState extends State { Future _triggerInference() async { final result = await _aiClient.generateContent( - _conversation, S.object( properties: { 'result': S.boolean( @@ -166,6 +169,7 @@ class _TravelPlannerPageState extends State { }, required: ['result'], ), + conversation: _conversation, ); if (result == null) { return; @@ -231,35 +235,7 @@ class _TravelPlannerPageState extends State { Text('Dynamic UI Demo'), ], ), - actions: [ - ValueListenableBuilder( - valueListenable: _aiClient.model, - builder: (context, currentModel, child) { - return PopupMenuButton( - icon: const Icon(Icons.psychology_outlined), - onSelected: (AiModel value) { - // Handle model selection - _aiClient.switchModel(value); - }, - itemBuilder: (BuildContext context) { - return _aiClient.models.map((model) { - return PopupMenuItem( - value: model, - child: Row( - children: [ - Text(model.displayName), - if (currentModel == model) const Icon(Icons.check), - ], - ), - ); - }).toList(); - }, - ); - }, - ), - const Icon(Icons.person_outline), - const SizedBox(width: 8.0), - ], + actions: [const Icon(Icons.person_outline), const SizedBox(width: 8.0)], ), body: SafeArea( child: Center( diff --git a/examples/travel_app/test/main_test.dart b/examples/travel_app/test/main_test.dart index b784299eb..5e1d36e39 100644 --- a/examples/travel_app/test/main_test.dart +++ b/examples/travel_app/test/main_test.dart @@ -9,25 +9,6 @@ import 'package:flutter_test/flutter_test.dart'; import 'package:travel_app/main.dart' as app; void main() { - testWidgets('Can switch models', (WidgetTester tester) async { - final mockAiClient = FakeAiClient(); - await tester.pumpWidget(app.TravelApp(aiClient: mockAiClient)); - - expect(find.text('mock1'), findsNothing); - expect(find.text('mock2'), findsNothing); - - await tester.tap(find.byIcon(Icons.psychology_outlined)); - await tester.pumpAndSettle(); - - expect(find.text('mock1'), findsOneWidget); - expect(find.text('mock2'), findsOneWidget); - - await tester.tap(find.text('mock2')); - await tester.pumpAndSettle(); - - expect(mockAiClient.model.value.displayName, 'mock2'); - }); - testWidgets('Can send a prompt', (WidgetTester tester) async { final mockAiClient = FakeAiClient(); // The main app expects a JSON response from generateContent. diff --git a/pkgs/flutter_genui/lib/flutter_genui.dart b/pkgs/flutter_genui/lib/flutter_genui.dart index 38220cc07..b394fd67b 100644 --- a/pkgs/flutter_genui/lib/flutter_genui.dart +++ b/pkgs/flutter_genui/lib/flutter_genui.dart @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -export 'src/ai_client/ai_client.dart'; export 'src/ai_client/gemini_ai_client.dart'; export 'src/catalog/core_widgets/checkbox_group.dart'; export 'src/catalog/core_widgets/column.dart'; @@ -13,7 +12,6 @@ export 'src/catalog/core_widgets/text.dart'; export 'src/catalog/core_widgets/text_field.dart'; export 'src/core/core_catalog.dart'; export 'src/core/genui_manager.dart'; -export 'src/core/surface_manager.dart'; export 'src/facade/genui_surface.dart'; export 'src/facade/to_refactor/chat_widget.dart'; export 'src/facade/to_refactor/conversation_widget.dart'; diff --git a/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart b/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart deleted file mode 100644 index 8ddcbb6de..000000000 --- a/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2025 The Flutter Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -import 'package:dart_schema_builder/dart_schema_builder.dart'; -import 'package:flutter/foundation.dart'; - -import '../model/chat_message.dart'; -import '../model/tools.dart'; - -/// An abstract class representing a type of AI model. -/// -/// This class provides a common interface for different AI models. -abstract class AiModel { - /// The display name of the model used to select the model in the UI. - String get displayName; -} - -/// An abstract class for a client that interacts with an AI model. -/// -/// This class defines the interface for sending requests to an AI model and -/// receiving responses. -abstract interface class AiClient { - /// A [ValueListenable] for the currently selected AI model. - /// - /// This allows the UI to listen for changes to the selected model. - ValueListenable get model; - - /// The list of available AI models. - List get models; - - /// Switches the AI model to the given [model]. - void switchModel(AiModel model); - - /// Generates content from the given [conversation] and returns an object of - /// type [T] that conforms to the given [outputSchema]. - /// - /// The [additionalTools] are added to the list of tools available to the - /// AI model. - Future generateContent( - List conversation, - Schema outputSchema, { - Iterable additionalTools = const [], - }); - - /// Generates a text response from the given [conversation]. - /// - /// The [additionalTools] are added to the list of tools available to the - /// AI model, but the model is not required to use them. - Future generateText( - List conversation, { - Iterable additionalTools = const [], - }); -} - -/// An exception thrown by an [AiClient] or its subclasses. -class AiClientException implements Exception { - /// Creates an [AiClientException] with the given [message]. - AiClientException(this.message); - - /// The message associated with the exception. - final String message; - - @override - String toString() => '$AiClientException: $message'; -} diff --git a/pkgs/flutter_genui/lib/src/ai_client/gemini_ai_client.dart b/pkgs/flutter_genui/lib/src/ai_client/gemini_ai_client.dart index 69045d85f..225fdda66 100644 --- a/pkgs/flutter_genui/lib/src/ai_client/gemini_ai_client.dart +++ b/pkgs/flutter_genui/lib/src/ai_client/gemini_ai_client.dart @@ -10,15 +10,49 @@ import 'package:file/local.dart'; import 'package:firebase_ai/firebase_ai.dart'; import 'package:flutter/foundation.dart'; +import '../core/surface_controller.dart'; import '../model/chat_message.dart' as msg; import '../model/tools.dart'; import '../primitives/logging.dart'; import '../primitives/simple_items.dart'; -import 'ai_client.dart'; import 'gemini_content_converter.dart'; import 'gemini_generative_model.dart'; import 'gemini_schema_adapter.dart'; +/// An abstract class representing a type of AI model. +/// +/// This class provides a common interface for different AI models. +abstract class AiModel { + /// The display name of the model used to select the model in the UI. + String get displayName; +} + +/// An exception thrown by an [AiClient] or its subclasses. +class AiClientException implements Exception { + /// Creates an [AiClientException] with the given [message]. + AiClientException(this.message); + + /// The message associated with the exception. + final String message; + + @override + String toString() => '$AiClientException: $message'; +} + +extension ContentExtension on SurfaceController { + Content toContent() { + return Content('user', [ + TextPart( + 'The following is the current UI state that you have generated, ' + 'for your information. You should use this to inform your ' + 'decision about what to do next. The user is seeing this UI, ' + 'which is on a surface with ID "${definitionNotifier.value!.surfaceId}".\n\n' + '${jsonEncode(definitionNotifier.value!)}', + ), + ]); + } +} + /// A factory for creating a [GeminiGenerativeModelInterface]. /// /// This is used to allow for custom model creation, for example, for testing. @@ -30,70 +64,30 @@ typedef GenerativeModelFactory = ToolConfig? toolConfig, }); -/// An enum for the available Gemini models. -enum GeminiModelType { - /// The Gemini 2.5 Flash model. - flash('gemini-2.5-flash', 'Gemini 2.5 Flash'), - - /// The Gemini 2.5 Pro model. - pro('gemini-2.5-pro', 'Gemini 2.5 Pro'); - - /// Creates a [GeminiModelType] with the given [modelName] and [displayName]. - const GeminiModelType(this.modelName, this.displayName); - - /// The name of the model as known by the Gemini API. - final String modelName; - - /// The human-readable name of the model. - final String displayName; -} - -/// A class that represents a Gemini model. -class GeminiModel extends AiModel { - /// Creates a new instance of [GeminiModel] as a specific [type]. - GeminiModel(this.type); - - /// The type of the model. - final GeminiModelType type; - - @override - String get displayName => type.displayName; -} - /// A basic implementation of [AiClient] for accessing a Gemini model. /// /// This class encapsulates settings for interacting with a generative AI model, /// including model selection, API keys, retry mechanisms, and tool /// configurations. It provides a [generateContent] method to interact with the /// AI model, supporting structured output and tool usage. -class GeminiAiClient implements AiClient { +class GeminiAiClient { /// Creates an [GeminiAiClient] instance with specified configurations. /// /// - [model]: The identifier of the generative AI model to use. - /// - [fileSystem]: The [FileSystem] instance for file operations, primarily /// used by tools. /// - [modelCreator]: A factory function to create the [GenerativeModel]. - /// - [maxRetries]: Maximum number of retries for API calls on transient - /// errors. - /// - [initialDelay]: Initial delay for the exponential backoff retry - /// strategy. /// - [maxConcurrentJobs]: Intended for managing concurrent AI operations, /// though not directly enforced by [generateContent] itself. /// - [tools]: A list of default [AiTool]s available to the AI. /// - [outputToolName]: The name of the internal tool used to force structured /// output from the AI. GeminiAiClient({ - GeminiModelType model = GeminiModelType.flash, this.systemInstruction, - this.fileSystem = const LocalFileSystem(), this.modelCreator = defaultGenerativeModelFactory, - this.maxRetries = 8, - this.initialDelay = const Duration(seconds: 1), - this.minDelay = const Duration(seconds: 8), this.maxConcurrentJobs = 20, this.tools = const [], this.outputToolName = 'provideFinalOutput', - }) : _model = ValueNotifier(GeminiModel(model)) { + }) { final duplicateToolNames = tools.map((t) => t.name).toSet(); if (duplicateToolNames.length != tools.length) { final duplicateTools = tools.where((t) { @@ -110,57 +104,6 @@ class GeminiAiClient implements AiClient { /// The system instruction to use for the AI model. final String? systemInstruction; - /// The name of the Gemini model to use. - /// - /// This identifier specifies which version or type of the generative AI model - /// will be invoked for content generation. - /// - /// Defaults to 'gemini-2.5-flash'. - final ValueNotifier _model; - - @override - ValueListenable get model => _model; - - @override - List get models => - GeminiModelType.values.map(GeminiModel.new).toList(); - - /// The file system to use for accessing files. - /// - /// While not directly used by [GeminiAiClient]'s core content generation - /// logic, this [FileSystem] instance can be utilized by [AiTool] - /// implementations that require file read/write capabilities. - /// - /// Defaults to a [LocalFileSystem] instance, providing access to the local - /// machine's file system. - final FileSystem fileSystem; - - /// The maximum number of retries to attempt when generating content. - /// - /// If an API call to the generative model fails with a transient error (like - /// [FirebaseAIException]), the client will attempt to retry the call up to - /// this many times. - /// - /// Defaults to 8 retries. - final int maxRetries; - - /// The initial delay between retries in seconds. - /// - /// This duration is used for the first retry attempt. Subsequent retries - /// employ an exponential backoff strategy, where the delay doubles after each - /// failed attempt, up to the [maxRetries] limit. - /// - /// Defaults to 1 second. - final Duration initialDelay; - - /// The minimum length of time to delay. - /// - /// Since the reset window for quota violations is 10 seconds, this shouldn't - /// be much less than that, or it will just wait longer. - /// - /// Defaults to 8 seconds. - final Duration minDelay; - /// The maximum number of concurrent jobs to run. /// /// This property is intended for systems that might manage multiple @@ -206,18 +149,6 @@ class GeminiAiClient implements AiClient { /// The total number of output tokens used by this client int outputTokenUsage = 0; - @override - void switchModel(AiModel newModel) { - if (newModel is! GeminiModel) { - throw ArgumentError( - 'Invalid model type: ${newModel.runtimeType} supplied to ' - '$GeminiAiClient.switchModel.', - ); - } - _model.value = newModel; - genUiLogger.info('Switched AI model to: ${newModel.displayName}'); - } - /// Generates structured content based on the provided prompts and output /// schema. /// @@ -243,27 +174,32 @@ class GeminiAiClient implements AiClient { /// output `T`. /// - [additionalTools]: A list of [AiTool]s to make available to the AI for /// this specific call, in addition to the default [tools]. - @override Future generateContent( - List conversation, dsb.Schema outputSchema, { + List? conversation, + List? content, Iterable additionalTools = const [], }) async { - return await _generateContentWithRetries(conversation, outputSchema, [ - ...tools, - ...additionalTools, - ]); + return await _generate( + outputSchema: outputSchema, + conversation: conversation, + content: content, + availableTools: [...tools, ...additionalTools], + ) + as T?; } - @override - Future generateText( - List conversation, { + Future generateText({ + List? conversation, + List? content, Iterable additionalTools = const [], }) async { - return await _generateTextWithRetries(conversation, [ - ...tools, - ...additionalTools, - ]); + return await _generate( + conversation: conversation, + content: content, + availableTools: [...tools, ...additionalTools], + ) + as String; } /// The default factory function for creating a [GenerativeModel]. @@ -276,10 +212,9 @@ class GeminiAiClient implements AiClient { List? tools, ToolConfig? toolConfig, }) { - final geminiModel = configuration._model.value; return GeminiGenerativeModel( FirebaseAI.googleAI().generativeModel( - model: geminiModel.type.modelName, + model: 'gemini-2.5-flash', systemInstruction: systemInstruction, tools: tools, toolConfig: toolConfig, @@ -287,101 +222,6 @@ class GeminiAiClient implements AiClient { ); } - Future _generateContentWithRetries( - List contents, - dsb.Schema outputSchema, - List availableTools, - ) async { - genUiLogger.fine('Generating content with retries.'); - return _generateWithRetries( - (onSuccess) async => - await _generate( - messages: contents, - availableTools: availableTools, - onSuccess: onSuccess, - outputSchema: outputSchema, - ) - as T?, - ); - } - - Future _generateTextWithRetries( - List contents, - List availableTools, - ) async { - genUiLogger.fine('Generating text with retries.'); - return _generateWithRetries( - (onSuccess) async => - await _generate( - messages: contents, - availableTools: availableTools, - onSuccess: onSuccess, - ) - as String, - ); - } - - Future _generateWithRetries( - Future Function(void Function() onSuccess) generationFunction, - ) async { - var attempts = 0; - var delay = initialDelay; - final maxTries = maxRetries + 1; // Retries plus the first attempt. - genUiLogger.fine('Starting generation with up to $maxRetries retries.'); - - Future onFail(Exception exception) async { - attempts++; - if (attempts >= maxTries) { - genUiLogger.warning('Max retries of $maxRetries reached.'); - throw exception; - } - // Make the delay at least minDelay long, since the reset window for - // exceeding the number of requests is 10 seconds long, and requesting it - // faster than that just means it makes us wait longer. - final waitTime = delay + minDelay; - genUiLogger.severe( - 'Received exception, retrying in $waitTime. Attempt $attempts of ' - '$maxTries. Exception: $exception', - ); - await Future.delayed(waitTime); - delay *= 2; - } - - while (attempts < maxTries) { - try { - final result = await generationFunction( - // Reset the delay and attempts on success. - () { - delay = initialDelay; - attempts = 0; - }, - ); - genUiLogger.fine('Generation successful.'); - return result; - } on FirebaseAIException catch (exception) { - if (exception.message.contains( - '${_model.value.type.modelName} is not found for API version', - )) { - // If the model is not found, then just throw an exception. - throw AiClientException(exception.message); - } - await onFail(exception); - } catch (exception, stack) { - genUiLogger.severe( - 'Received ' - '${exception.runtimeType}: $exception', - exception, - stack, - ); - // For other exceptions, rethrow immediately. - rethrow; - } - } - // This line should be unreachable if maxRetries > 0, but is needed for - // static analysis. - throw StateError('Exceeded maximum retries without throwing an exception.'); - } - ({List? generativeAiTools, Set allowedFunctionNames}) _setupToolsAndFunctions({ required bool isForcedToolCalling, @@ -559,14 +399,21 @@ class GeminiAiClient implements AiClient { Future _generate({ // This list is modified to include tool calls and results. - required List messages, + List? conversation, + List? content, required List availableTools, - required void Function() onSuccess, dsb.Schema? outputSchema, }) async { + if (conversation != null && content != null) { + throw ArgumentError('Cannot provide both conversation and content.'); + } + if (conversation == null && content == null) { + throw ArgumentError('Must provide either conversation or content.'); + } + final isForcedToolCalling = outputSchema != null; final converter = GeminiContentConverter(); - final contents = converter.toFirebaseAiContent(messages); + final contents = content ?? converter.toFirebaseAiContent(conversation!); final adapter = GeminiSchemaAdapter(); final (:generativeAiTools, :allowedFunctionNames) = _setupToolsAndFunctions( @@ -610,15 +457,12 @@ class GeminiAiClient implements AiClient { genUiLogger.info( '''****** Performing Inference ******\n$concatenatedContents With functions: - '${allowedFunctionNames.join(', ')}', - ''', + ${allowedFunctionNames.join(', ')}''', ); final inferenceStartTime = DateTime.now(); final response = await model.generateContent(contents); final elapsed = DateTime.now().difference(inferenceStartTime); - onSuccess(); - if (response.usageMetadata != null) { inputTokenUsage += response.usageMetadata!.promptTokenCount ?? 0; outputTokenUsage += response.usageMetadata!.candidatesTokenCount ?? 0; @@ -656,7 +500,7 @@ With functions: ); } if (candidate.text != null) { - messages.add(msg.AssistantMessage.text(candidate.text!)); + conversation?.add(msg.AssistantMessage.text(candidate.text!)); } genUiLogger.fine( 'Model returned text but no function calls with forced tool ' @@ -665,7 +509,7 @@ With functions: return null; } else { final text = candidate.text ?? ''; - messages.add(msg.AssistantMessage.text(text)); + conversation?.add(msg.AssistantMessage.text(text)); genUiLogger.fine('Returning text response: "$text"'); return text; } @@ -701,7 +545,7 @@ With functions: .toList(); if (assistantParts.isNotEmpty) { - messages.add(msg.AssistantMessage(assistantParts)); + conversation?.add(msg.AssistantMessage(assistantParts)); genUiLogger.fine( 'Added assistant message with ${assistantParts.length} parts to ' 'conversation.', @@ -720,7 +564,7 @@ With functions: }).toList(); if (toolResponseParts.isNotEmpty) { - messages.add(msg.ToolResponseMessage(toolResponseParts)); + conversation?.add(msg.ToolResponseMessage(toolResponseParts)); genUiLogger.fine( 'Added tool response message with ${toolResponseParts.length} ' 'parts to conversation.', diff --git a/pkgs/flutter_genui/lib/src/ai_client/gemini_content_converter.dart b/pkgs/flutter_genui/lib/src/ai_client/gemini_content_converter.dart index 9ff779134..d155413a5 100644 --- a/pkgs/flutter_genui/lib/src/ai_client/gemini_content_converter.dart +++ b/pkgs/flutter_genui/lib/src/ai_client/gemini_content_converter.dart @@ -8,7 +8,7 @@ import 'package:firebase_ai/firebase_ai.dart' as firebase_ai; import '../model/chat_message.dart'; import '../primitives/simple_items.dart'; -import 'ai_client.dart'; +import 'gemini_ai_client.dart'; /// A class to convert between the generic `ChatMessage` and the `firebase_ai` /// specific `Content` classes. diff --git a/pkgs/flutter_genui/lib/src/core/genui_manager.dart b/pkgs/flutter_genui/lib/src/core/genui_manager.dart index bea6e2cc8..d2c93eb50 100644 --- a/pkgs/flutter_genui/lib/src/core/genui_manager.dart +++ b/pkgs/flutter_genui/lib/src/core/genui_manager.dart @@ -2,28 +2,73 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +import 'dart:async'; + import 'package:flutter/foundation.dart'; -import '../ai_client/ai_client.dart'; import '../model/catalog.dart'; import '../model/tools.dart'; import '../model/ui_models.dart'; +import '../primitives/logging.dart'; import '../primitives/simple_items.dart'; -import 'surface_manager.dart'; +import 'core_catalog.dart'; +import 'surface_controller.dart'; import 'ui_tools.dart'; +/// A sealed class representing an update to the UI managed by [GenUiManager]. +/// +/// This class has three subclasses: [SurfaceAdded], [SurfaceUpdated], and +/// [SurfaceRemoved]. +sealed class GenUiUpdate { + /// Creates a [GenUiUpdate] for the given [surfaceId]. + const GenUiUpdate(this.surfaceId); + + /// The ID of the surface that was updated. + final String surfaceId; +} + +/// Fired when a new surface is created. +class SurfaceAdded extends GenUiUpdate { + /// Creates a [SurfaceAdded] event for the given [surfaceId] and + /// [definition]. + const SurfaceAdded(super.surfaceId, this.definition, this.controller); + + /// The definition of the new surface. + final UiDefinition definition; + + /// The controller for the new surface. + final SurfaceController controller; +} + +/// Fired when an existing surface is modified. +class SurfaceUpdated extends GenUiUpdate { + /// Creates a [SurfaceUpdated] event for the given [surfaceId] and + /// [definition]. + const SurfaceUpdated(super.surfaceId, this.definition); + + /// The new definition of the surface. + final UiDefinition definition; +} + +/// Fired when a surface is deleted. +class SurfaceRemoved extends GenUiUpdate { + /// Creates a [SurfaceRemoved] event for the given [surfaceId]. + const SurfaceRemoved(super.surfaceId); +} + class GenUiManager { - SurfaceManager surfaceManager; + GenUiManager({Catalog? catalog}) : catalog = catalog ?? coreCatalog; - GenUiManager({Catalog? catalog}) - : surfaceManager = SurfaceManager(catalog: catalog); + final _surfaces = >{}; + final _updates = StreamController.broadcast(); - Map> get surfaces => - surfaceManager.surfaces; + Map> get surfaces => _surfaces; + Map get controllers => _controllers; + final _controllers = {}; - Stream get updates => surfaceManager.updates; + Stream get updates => _updates.stream; - Catalog get catalog => surfaceManager.catalog; + final Catalog catalog; /// Returns a list of [AiTool]s that can be used to manipulate the UI. /// @@ -31,20 +76,54 @@ class GenUiManager { /// generate and modify the UI. List getTools() { return [ - AddOrUpdateSurfaceTool(surfaceManager), - DeleteSurfaceTool(surfaceManager), + AddOrUpdateSurfaceTool( + onAddOrUpdate: _addOrUpdateSurface, + catalog: catalog, + ), + DeleteSurfaceTool(onDelete: _deleteSurface), ]; } - ValueNotifier surface(String surfaceId) => - surfaceManager.surface(surfaceId); + ValueNotifier surface(String surfaceId) { + return _surfaces.putIfAbsent(surfaceId, () => ValueNotifier(null)); + } void dispose() { - surfaceManager.dispose(); + _updates.close(); + for (final notifier in _surfaces.values) { + notifier.dispose(); + } } - void addOrUpdateSurface(String s, JsonMap definition) => - surfaceManager.addOrUpdateSurface(s, definition); + void _addOrUpdateSurface(String surfaceId, JsonMap definition) { + final uiDefinition = UiDefinition.fromMap({ + 'surfaceId': surfaceId, + ...definition, + }); + final notifier = surface(surfaceId); // Gets or creates the notifier. + final isNew = notifier.value == null; + notifier.value = uiDefinition; + if (isNew) { + genUiLogger.info('Adding surface $surfaceId'); + final controller = SurfaceController( + definitionNotifier: notifier, + catalog: catalog, + ); + _controllers[surfaceId] = controller; + _updates.add(SurfaceAdded(surfaceId, uiDefinition, controller)); + } else { + genUiLogger.info('Updating surface $surfaceId'); + _updates.add(SurfaceUpdated(surfaceId, uiDefinition)); + } + } - void deleteSurface(String s) => surfaceManager.deleteSurface(s); + void _deleteSurface(String surfaceId) { + if (_surfaces.containsKey(surfaceId)) { + genUiLogger.info('Deleting surface $surfaceId'); + final notifier = _surfaces.remove(surfaceId); + _controllers.remove(surfaceId); + notifier?.dispose(); + _updates.add(SurfaceRemoved(surfaceId)); + } + } } diff --git a/pkgs/flutter_genui/lib/src/core/surface_controller.dart b/pkgs/flutter_genui/lib/src/core/surface_controller.dart new file mode 100644 index 000000000..287c0fe1d --- /dev/null +++ b/pkgs/flutter_genui/lib/src/core/surface_controller.dart @@ -0,0 +1,23 @@ +// Copyright 2025 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:flutter/foundation.dart'; + +import '../model/catalog.dart'; +import '../model/ui_models.dart'; + +/// A callback for when a user interacts with a widget. +typedef UiEventCallback = void Function(UiEvent event); + +class SurfaceController { + SurfaceController({required this.definitionNotifier, required this.catalog}); + + final ValueNotifier definitionNotifier; + final Catalog catalog; + + String get surfaceId => definitionNotifier.value!.surfaceId; + + /// A callback for when a user interacts with a widget. + UiEventCallback? onEvent; +} diff --git a/pkgs/flutter_genui/lib/src/core/surface_manager.dart b/pkgs/flutter_genui/lib/src/core/surface_manager.dart deleted file mode 100644 index f8aef39ef..000000000 --- a/pkgs/flutter_genui/lib/src/core/surface_manager.dart +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2025 The Flutter Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -import 'dart:async'; - -import 'package:flutter/foundation.dart'; - -import '../model/catalog.dart'; -import '../model/ui_models.dart'; -import '../primitives/logging.dart'; -import '../primitives/simple_items.dart'; -import 'core_catalog.dart'; - -/// A sealed class representing an update to the UI managed by [SurfaceManager]. -/// -/// This class has three subclasses: [SurfaceAdded], [SurfaceUpdated], and -/// [SurfaceRemoved]. -sealed class GenUiUpdate { - /// Creates a [GenUiUpdate] for the given [surfaceId]. - const GenUiUpdate(this.surfaceId); - - /// The ID of the surface that was updated. - final String surfaceId; -} - -/// Fired when a new surface is created. -class SurfaceAdded extends GenUiUpdate { - /// Creates a [SurfaceAdded] event for the given [surfaceId] and - /// [definition]. - const SurfaceAdded(super.surfaceId, this.definition); - - /// The definition of the new surface. - final UiDefinition definition; -} - -/// Fired when an existing surface is modified. -class SurfaceUpdated extends GenUiUpdate { - /// Creates a [SurfaceUpdated] event for the given [surfaceId] and - /// [definition]. - const SurfaceUpdated(super.surfaceId, this.definition); - - /// The new definition of the surface. - final UiDefinition definition; -} - -/// Fired when a surface is deleted. -class SurfaceRemoved extends GenUiUpdate { - /// Creates a [SurfaceRemoved] event for the given [surfaceId]. - const SurfaceRemoved(super.surfaceId); -} - -/// A manager for the state of the generative UI. -/// -/// This class is responsible for managing the state of all the UI "surfaces" -/// that are generated by the AI. It provides tools for the AI to manipulate -/// the UI, and it notifies listeners of changes to the UI. -class SurfaceManager { - /// Creates a new [SurfaceManager]. - /// - /// A [catalog] of UI components can be provided, otherwise the - /// [coreCatalog] will be used. - SurfaceManager({Catalog? catalog}) : catalog = catalog ?? coreCatalog; - - /// The catalog of UI components that can be used to build the UI. - final Catalog catalog; - - final _surfaces = >{}; - final _updates = StreamController.broadcast(); - - /// A map of the current UI surfaces. - Map> get surfaces => _surfaces; - - /// A stream of updates to the UI. - /// - /// This stream emits a [GenUiUpdate] whenever a surface is added, updated, - /// or removed. - Stream get updates => _updates.stream; - - /// Returns a [ValueNotifier] for the given [surfaceId]. - /// - /// This can be used to listen for changes to a specific surface. - ValueNotifier surface(String surfaceId) { - return _surfaces.putIfAbsent(surfaceId, () => ValueNotifier(null)); - } - - /// Adds or updates a UI surface. - /// - /// This method is called to update the UI. - void addOrUpdateSurface(String surfaceId, JsonMap definition) { - final uiDefinition = UiDefinition.fromMap({ - 'surfaceId': surfaceId, - ...definition, - }); - final notifier = surface(surfaceId); // Gets or creates the notifier. - final isNew = notifier.value == null; - notifier.value = uiDefinition; - if (isNew) { - genUiLogger.info('Adding surface $surfaceId'); - _updates.add(SurfaceAdded(surfaceId, uiDefinition)); - } else { - genUiLogger.info('Updating surface $surfaceId'); - _updates.add(SurfaceUpdated(surfaceId, uiDefinition)); - } - } - - /// Deletes a UI surface. - /// - /// This method is called to update the UI. - void deleteSurface(String surfaceId) { - if (_surfaces.containsKey(surfaceId)) { - genUiLogger.info('Deleting surface $surfaceId'); - final notifier = _surfaces.remove(surfaceId); - notifier?.dispose(); - _updates.add(SurfaceRemoved(surfaceId)); - } - } - - /// Disposes of the resources used by the manager. - void dispose() { - _updates.close(); - for (final notifier in _surfaces.values) { - notifier.dispose(); - } - } -} diff --git a/pkgs/flutter_genui/lib/src/core/ui_tools.dart b/pkgs/flutter_genui/lib/src/core/ui_tools.dart index d7743064d..102f512ec 100644 --- a/pkgs/flutter_genui/lib/src/core/ui_tools.dart +++ b/pkgs/flutter_genui/lib/src/core/ui_tools.dart @@ -4,9 +4,9 @@ import 'package:dart_schema_builder/dart_schema_builder.dart'; +import '../model/catalog.dart'; import '../model/tools.dart'; import '../primitives/simple_items.dart'; -import 'surface_manager.dart'; /// An [AiTool] for adding or updating a UI surface. /// @@ -14,50 +14,52 @@ import 'surface_manager.dart'; /// one with a new definition. class AddOrUpdateSurfaceTool extends AiTool { /// Creates an [AddOrUpdateSurfaceTool]. - AddOrUpdateSurfaceTool(this.manager) - : super( - name: 'addOrUpdateSurface', - description: - 'Adds a new UI surface or updates an existing one. Use this to ' - 'display new content or change what is currently visible.', - parameters: S.object( - properties: { - 'surfaceId': S.string( - description: - 'The unique identifier for the UI surface to create or ' - 'modify.', - ), - 'definition': S.object( - properties: { - 'root': S.string( - description: - 'The ID of the root widget. This ID must correspond to ' - 'the ID of one of the widgets in the `widgets` list.', - ), - 'widgets': S.list( - items: manager.catalog.schema, - description: 'A list of widget definitions.', - minItems: 1, - ), - }, - description: - 'A schema for a simple UI tree to be rendered by ' - 'Flutter.', - required: ['root', 'widgets'], - ), - }, - required: ['surfaceId', 'definition'], - ), - ); + AddOrUpdateSurfaceTool({ + required this.onAddOrUpdate, + required Catalog catalog, + }) : super( + name: 'addOrUpdateSurface', + description: + 'Adds a new UI surface or updates an existing one. Use this to ' + 'display new content or change what is currently visible.', + parameters: S.object( + properties: { + 'surfaceId': S.string( + description: + 'The unique identifier for the UI surface to create or ' + 'modify.', + ), + 'definition': S.object( + properties: { + 'root': S.string( + description: + 'The ID of the root widget. This ID must correspond to ' + 'the ID of one of the widgets in the `widgets` list.', + ), + 'widgets': S.list( + items: catalog.schema, + description: 'A list of widget definitions.', + minItems: 1, + ), + }, + description: + 'A schema for a simple UI tree to be rendered by ' + 'Flutter.', + required: ['root', 'widgets'], + ), + }, + required: ['surfaceId', 'definition'], + ), + ); - /// The [SurfaceManager] to use for updating the UI. - final SurfaceManager manager; + /// The callback to invoke when adding or updating a surface. + final void Function(String surfaceId, JsonMap definition) onAddOrUpdate; @override Future invoke(JsonMap args) async { final surfaceId = args['surfaceId'] as String; final definition = args['definition'] as JsonMap; - manager.addOrUpdateSurface(surfaceId, definition); + onAddOrUpdate(surfaceId, definition); return {'surfaceId': surfaceId, 'definition': definition}; } } @@ -67,28 +69,28 @@ class AddOrUpdateSurfaceTool extends AiTool { /// This tool allows the AI to remove a UI surface that is no longer needed. class DeleteSurfaceTool extends AiTool { /// Creates a [DeleteSurfaceTool]. - DeleteSurfaceTool(this.manager) - : super( - name: 'deleteSurface', - description: 'Removes a UI surface that is no longer needed.', - parameters: S.object( - properties: { - 'surfaceId': S.string( - description: - 'The unique identifier for the UI surface to remove.', - ), - }, - required: ['surfaceId'], - ), - ); + DeleteSurfaceTool({required this.onDelete}) + : super( + name: 'deleteSurface', + description: 'Removes a UI surface that is no longer needed.', + parameters: S.object( + properties: { + 'surfaceId': S.string( + description: + 'The unique identifier for the UI surface to remove.', + ), + }, + required: ['surfaceId'], + ), + ); - /// The [SurfaceManager] to use for updating the UI. - final SurfaceManager manager; + /// The callback to invoke when deleting a surface. + final void Function(String surfaceId) onDelete; @override Future invoke(JsonMap args) async { final surfaceId = args['surfaceId'] as String; - manager.deleteSurface(surfaceId); + onDelete(surfaceId); return {'status': 'ok'}; } } diff --git a/pkgs/flutter_genui/lib/src/facade/genui_surface.dart b/pkgs/flutter_genui/lib/src/facade/genui_surface.dart index 3f5f8acd3..9d94a24da 100644 --- a/pkgs/flutter_genui/lib/src/facade/genui_surface.dart +++ b/pkgs/flutter_genui/lib/src/facade/genui_surface.dart @@ -2,113 +2,58 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import 'dart:async'; - import 'package:flutter/material.dart'; -import '../core/genui_manager.dart'; -import '../core/surface_manager.dart'; +import '../core/surface_controller.dart'; import '../model/ui_models.dart'; import '../primitives/logging.dart'; import '../primitives/simple_items.dart'; -/// A callback for when a user interacts with a widget. -typedef UiEventCallback = void Function(UiEvent event); - /// A widget that builds a UI dynamically from a JSON-like definition. /// /// It reports user interactions via the [onEvent] callback. -class GenUiSurface extends StatefulWidget { +class GenUiSurface extends StatelessWidget { /// Creates a new [GenUiSurface]. const GenUiSurface({ super.key, - required this.manager, - required this.surfaceId, - required this.onEvent, + required this.controller, this.defaultBuilder, }); - /// The manager that holds the state of the UI. - final GenUiManager manager; - - /// The ID of the surface that this UI belongs to. - final String surfaceId; - - /// A callback for when a user interacts with a widget. - final UiEventCallback onEvent; + /// The controller that holds the state of the UI. + final SurfaceController controller; /// A builder for the widget to display when the surface has no definition. final WidgetBuilder? defaultBuilder; - @override - State createState() => _GenUiSurfaceState(); -} - -class _GenUiSurfaceState extends State { - ValueNotifier? _definitionNotifier; - StreamSubscription? _allUpdatesSubscription; - - @override - void initState() { - super.initState(); - _init(); - } - - @override - void didUpdateWidget(covariant GenUiSurface oldWidget) { - super.didUpdateWidget(oldWidget); - - if (oldWidget.surfaceId != widget.surfaceId || - oldWidget.manager != widget.manager) { - _init(); - } - } - - void _init() { - // Reset previous subscription for updates. - _allUpdatesSubscription?.cancel(); - _allUpdatesSubscription = widget.manager.updates.listen((update) { - if (update.surfaceId == widget.surfaceId) _init(); - }); - - // Update definition if it is changed. - final newDefinitionNotifier = widget.manager.surface(widget.surfaceId); - if (newDefinitionNotifier == _definitionNotifier) return; - _definitionNotifier = newDefinitionNotifier; - setState(() {}); - } - - /// Dispatches an event by calling the public [GenUiSurface.onEvent] - /// callback. void _dispatchEvent(UiEvent event) { + final onEvent = controller.onEvent; + if (onEvent == null) { + return; + } // The event comes in without a surfaceId, which we add here. final eventMap = event.toMap(); - eventMap['surfaceId'] = widget.surfaceId; - widget.onEvent(UiEvent.fromMap(eventMap)); + eventMap['surfaceId'] = controller.surfaceId; + onEvent(UiEvent.fromMap(eventMap)); } @override Widget build(BuildContext context) { - final notifier = _definitionNotifier; - if (notifier == null) { - return const SizedBox.shrink(); - } - return ValueListenableBuilder( - valueListenable: notifier, + valueListenable: controller.definitionNotifier, builder: (context, definition, child) { - genUiLogger.info('Building surface ${widget.surfaceId}'); + genUiLogger.info('Building surface ${controller.surfaceId}'); if (definition == null) { - genUiLogger.info('Surface ${widget.surfaceId} has no definition.'); - return widget.defaultBuilder?.call(context) ?? - const SizedBox.shrink(); + genUiLogger + .info('Surface ${controller.surfaceId} has no definition.'); + return defaultBuilder?.call(context) ?? const SizedBox.shrink(); } final rootId = definition.root; if (definition.widgets.isEmpty) { - genUiLogger.warning('Surface ${widget.surfaceId} has no widgets.'); + genUiLogger.warning('Surface ${controller.surfaceId} has no widgets.'); return const SizedBox.shrink(); } - return _buildWidget(definition, rootId); + return _buildWidget(definition, rootId, context); }, ); } @@ -117,26 +62,22 @@ class _GenUiSurfaceState extends State { /// It reads a widget definition and its current state from /// `widget.definition` /// and constructs the corresponding Flutter widget. - Widget _buildWidget(UiDefinition definition, String widgetId) { + Widget _buildWidget( + UiDefinition definition, + String widgetId, + BuildContext context, + ) { var data = definition.widgets[widgetId]; if (data == null) { genUiLogger.severe('Widget with id: $widgetId not found.'); return Placeholder(child: Text('Widget with id: $widgetId not found.')); } - return widget.manager.catalog.buildWidget( + return controller.catalog.buildWidget( data as JsonMap, - (String childId) => _buildWidget(definition, childId), + (String childId) => _buildWidget(definition, childId, context), _dispatchEvent, context, ); } - - @override - void dispose() { - _allUpdatesSubscription?.cancel(); - // We should not dispose _definitionNotifier, - // because it is owned by the manager. - super.dispose(); - } -} +} \ No newline at end of file diff --git a/pkgs/flutter_genui/lib/src/facade/to_refactor/chat_widget.dart b/pkgs/flutter_genui/lib/src/facade/to_refactor/chat_widget.dart index 08b7cb77f..72a3e7bb1 100644 --- a/pkgs/flutter_genui/lib/src/facade/to_refactor/chat_widget.dart +++ b/pkgs/flutter_genui/lib/src/facade/to_refactor/chat_widget.dart @@ -7,7 +7,8 @@ import 'dart:async'; import 'package:flutter/material.dart'; import '../../core/genui_manager.dart'; -import '../../core/surface_manager.dart'; +import '../../core/surface_controller.dart'; + import '../../core/widgets/chat_primitives.dart'; import '../../model/chat_box.dart'; import '../../model/chat_message.dart'; @@ -46,7 +47,7 @@ class GenUiChatController { void _onUpdate(GenUiUpdate update) { final currentConversation = _conversation.value; switch (update) { - case SurfaceAdded(:final surfaceId, :final definition): + case SurfaceAdded(:final surfaceId, :final definition, :final controller): if (!_surfaceIds.contains(surfaceId)) { _surfaceIds.add(surfaceId); _conversation.value = [ @@ -205,13 +206,16 @@ class _GenUiChatState extends State { alignment: MainAxisAlignment.start, ); case UiResponseMessage(): + final controller = widget.controller.manager.controllers[message.surfaceId]; + if (controller == null) { + return const SizedBox.shrink(); + } + controller.onEvent = widget.onEvent; return Padding( padding: const EdgeInsets.all(16.0), child: GenUiSurface( key: message.uiKey, - manager: widget.controller.manager, - surfaceId: message.surfaceId, - onEvent: widget.onEvent, + controller: controller, ), ); case InternalMessage(): diff --git a/pkgs/flutter_genui/lib/src/facade/to_refactor/conversation_widget.dart b/pkgs/flutter_genui/lib/src/facade/to_refactor/conversation_widget.dart index d9c909e0e..21bdbac37 100644 --- a/pkgs/flutter_genui/lib/src/facade/to_refactor/conversation_widget.dart +++ b/pkgs/flutter_genui/lib/src/facade/to_refactor/conversation_widget.dart @@ -5,6 +5,7 @@ import 'package:flutter/material.dart'; import '../../core/genui_manager.dart'; +import '../../core/surface_controller.dart'; import '../../core/widgets/chat_primitives.dart'; import '../../model/chat_message.dart'; import '../genui_surface.dart'; @@ -66,13 +67,16 @@ class ConversationWidget extends StatelessWidget { alignment: MainAxisAlignment.start, ); case UiResponseMessage(): + final controller = manager.controllers[message.surfaceId]; + if (controller == null) { + return const SizedBox.shrink(); + } + controller.onEvent = onEvent; return Padding( padding: const EdgeInsets.all(16.0), child: GenUiSurface( key: message.uiKey, - manager: manager, - surfaceId: message.surfaceId, - onEvent: onEvent, + controller: controller, ), ); case InternalMessage(): diff --git a/pkgs/flutter_genui/lib/test/fake_ai_client.dart b/pkgs/flutter_genui/lib/test/fake_ai_client.dart index b81f62d78..4112bb89b 100644 --- a/pkgs/flutter_genui/lib/test/fake_ai_client.dart +++ b/pkgs/flutter_genui/lib/test/fake_ai_client.dart @@ -5,17 +5,19 @@ import 'dart:async'; import 'package:dart_schema_builder/dart_schema_builder.dart'; +import 'package:firebase_ai/firebase_ai.dart'; import 'package:flutter/foundation.dart'; +import 'package:flutter_genui/src/ai_client/gemini_ai_client.dart'; +import 'package:dart_schema_builder/dart_schema_builder.dart' as dsb; -import '../src/ai_client/ai_client.dart'; import '../src/model/chat_message.dart' as genui; import '../src/model/tools.dart'; -/// A fake implementation of [AiClient] for testing purposes. +/// A fake implementation of [GeminiAiClient] for testing purposes. /// /// This class allows for mocking the behavior of an AI client by providing /// canned responses or exceptions. It also tracks calls to its methods. -class FakeAiClient implements AiClient { +class FakeAiClient implements GeminiAiClient { /// The response to be returned by [generateContent]. Object? response; @@ -44,15 +46,16 @@ class FakeAiClient implements AiClient { @override Future generateContent( - List conversation, - Schema outputSchema, { + dsb.Schema outputSchema, { + List? conversation, + List? content, Iterable additionalTools = const [], }) async { if (responseCompleter.isCompleted) { responseCompleter = Completer(); } generateContentCallCount++; - lastConversation = conversation; + lastConversation = conversation ?? []; try { if (preGenerateContent != null) { await preGenerateContent!(); @@ -69,15 +72,16 @@ class FakeAiClient implements AiClient { } @override - Future generateText( - List conversation, { + Future generateText({ + List? conversation, + List? content, Iterable additionalTools = const [], }) async { if (responseCompleter.isCompleted) { responseCompleter = Completer(); } generateTextCallCount++; - lastConversation = conversation; + lastConversation = conversation ?? []; try { if (preGenerateContent != null) { await preGenerateContent!(); @@ -106,6 +110,28 @@ class FakeAiClient implements AiClient { void switchModel(AiModel model) { _model.value = model; } + + @override + int inputTokenUsage = 0; + + @override + int maxConcurrentJobs = 20; + + @override + GenerativeModelFactory modelCreator = + GeminiAiClient.defaultGenerativeModelFactory; + + @override + int outputTokenUsage = 0; + + @override + String outputToolName = 'provideFinalOutput'; + + @override + String? systemInstruction; + + @override + List tools = []; } /// A fake implementation of [AiModel] for testing purposes. diff --git a/pkgs/flutter_genui/test/ai_client/ai_client_test.dart b/pkgs/flutter_genui/test/ai_client/ai_client_test.dart index a2425c28d..e250f3031 100644 --- a/pkgs/flutter_genui/test/ai_client/ai_client_test.dart +++ b/pkgs/flutter_genui/test/ai_client/ai_client_test.dart @@ -12,7 +12,6 @@ import 'package:firebase_ai/firebase_ai.dart' FunctionCall, GenerateContentResponse; import 'package:firebase_ai/firebase_ai.dart' as firebase_ai; -import 'package:flutter_genui/src/ai_client/ai_client.dart'; import 'package:flutter_genui/src/ai_client/gemini_ai_client.dart'; import 'package:flutter_genui/src/model/chat_message.dart'; import 'package:flutter_genui/src/model/tools.dart'; @@ -77,9 +76,10 @@ void main() { ), ], null); - final result = await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {'key': S.string()})); + final result = await client.generateContent>( + S.object(properties: {'key': S.string()}), + conversation: [UserMessage.text('user prompt')], + ); expect(result, isNotNull); expect(result!['key'], 'value'); @@ -125,9 +125,10 @@ void main() { ], null), ]; - final result = await client.generateContent>([ - UserMessage.text('do something'), - ], S.object(properties: {'final': S.string()})); + final result = await client.generateContent>( + S.object(properties: {'final': S.string()}), + conversation: [UserMessage.text('do something')], + ); expect(toolCalled, isTrue); expect(result, isNotNull); @@ -135,31 +136,6 @@ void main() { expect(fakeModel.generateContentCallCount, 2); }); - test('generateContent retries on failure', () async { - client = createClient(); - fakeModel.exception = FirebaseAIException('transient error'); - fakeModel.response = GenerateContentResponse([ - Candidate( - Content.model([ - FunctionCall('provideFinalOutput', { - 'output': {'key': 'value'}, - }), - ]), - [], - null, - null, - null, - ), - ], null); - - final result = await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {'key': S.string()})); - - expect(result, isNotNull); - expect(fakeModel.generateContentCallCount, 2); - }); - test('generateContent handles tool exception', () async { final tool = DynamicAiTool( name: 'badTool', @@ -194,9 +170,10 @@ void main() { ], null), ]; - final result = await client.generateContent>([ - UserMessage.text('do something'), - ], S.object(properties: {'final': S.string()})); + final result = await client.generateContent>( + S.object(properties: {'final': S.string()}), + conversation: [UserMessage.text('do something')], + ); expect(result, isNotNull); expect(result!['final'], 'result'); @@ -206,9 +183,10 @@ void main() { client = createClient(); fakeModel.response = GenerateContentResponse([], null); - final result = await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {})); + final result = await client.generateContent>( + S.object(properties: {}), + conversation: [UserMessage.text('user prompt')], + ); expect(result, isNull); }); @@ -226,9 +204,10 @@ void main() { ], null); expect( - () => client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {})), + () => client.generateContent>( + S.object(properties: {}), + conversation: [UserMessage.text('user prompt')], + ), throwsA(isA()), ); }); @@ -245,9 +224,10 @@ void main() { ), ], null); - final result = await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {})); + final result = await client.generateContent>( + S.object(properties: {}), + conversation: [UserMessage.text('user prompt')], + ); expect(result, isNull); }); @@ -272,9 +252,10 @@ void main() { ), ], null); - final result = await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {})); + final result = await client.generateContent>( + S.object(properties: {}), + conversation: [UserMessage.text('user prompt')], + ); expect(result, isNull); }); @@ -304,9 +285,10 @@ void main() { ), ], null); - await client.generateContent>([ - UserMessage.text('user prompt'), - ], S.object(properties: {'key': S.string()})); + await client.generateContent>( + S.object(properties: {'key': S.string()}), + conversation: [UserMessage.text('user prompt')], + ); expect(logMessages, isNotEmpty); }); diff --git a/pkgs/flutter_genui/test/catalog/core_widgets_test.dart b/pkgs/flutter_genui/test/catalog/core_widgets_test.dart index 1c67c9467..c2b9f0ce7 100644 --- a/pkgs/flutter_genui/test/catalog/core_widgets_test.dart +++ b/pkgs/flutter_genui/test/catalog/core_widgets_test.dart @@ -4,6 +4,7 @@ import 'package:flutter/material.dart'; import 'package:flutter_genui/flutter_genui.dart'; +import 'package:flutter_genui/src/core/surface_controller.dart'; import 'package:flutter_test/flutter_test.dart'; void main() { @@ -23,14 +24,20 @@ void main() { UiEventCallback onEvent, ) async { final manager = GenUiManager(catalog: testCatalog); - manager.addOrUpdateSurface('testSurface', definition); + final addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'testSurface', + 'definition': definition, + }); + final controller = manager.controllers['testSurface']!; + controller.onEvent = onEvent; await tester.pumpWidget( MaterialApp( home: Scaffold( body: GenUiSurface( - manager: manager, - surfaceId: 'testSurface', - onEvent: onEvent, + controller: controller, ), ), ), diff --git a/pkgs/flutter_genui/test/core/chat_controller_test.dart b/pkgs/flutter_genui/test/core/chat_controller_test.dart index fd9fd8c5f..3e3a5dc78 100644 --- a/pkgs/flutter_genui/test/core/chat_controller_test.dart +++ b/pkgs/flutter_genui/test/core/chat_controller_test.dart @@ -3,16 +3,24 @@ // found in the LICENSE file. import 'package:flutter_genui/flutter_genui.dart'; +import 'package:flutter_genui/src/model/tools.dart'; import 'package:flutter_test/flutter_test.dart'; void main() { group('GenUiChatController', () { late GenUiManager manager; late GenUiChatController controller; + late AiTool addOrUpdateSurfaceTool; + late AiTool deleteSurfaceTool; setUp(() { manager = GenUiManager(); controller = GenUiChatController(manager: manager); + addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); + deleteSurfaceTool = + manager.getTools().firstWhere((tool) => tool.name == 'deleteSurface'); }); tearDown(() { @@ -34,16 +42,19 @@ void main() { testWidgets('manager SurfaceAdded update adds a new surface message', ( WidgetTester tester, ) async { - manager.addOrUpdateSurface('s2', { - 'root': 'root', - 'widgets': [ - { - 'id': 'root', - 'widget': { - 'text': {'text': 'Surface 2'}, + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's2', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'Surface 2'}, + }, }, - }, - ], + ], + } }); await tester.pumpAndSettle(); expect(controller.conversation.value.length, 1); @@ -55,20 +66,23 @@ void main() { testWidgets('manager SurfaceRemoved update removes a surface message', ( WidgetTester tester, ) async { - manager.addOrUpdateSurface('main_surface', { - 'root': 'root', - 'widgets': [ - { - 'id': 'root', - 'widget': { - 'text': {'text': 'Hello!'}, + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'main_surface', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'Hello!'}, + }, }, - }, - ], + ], + } }); await tester.pumpAndSettle(); expect(controller.conversation.value.length, 1); - manager.deleteSurface('main_surface'); + await deleteSurfaceTool.invoke({'surfaceId': 'main_surface'}); await tester.pumpAndSettle(); expect(controller.conversation.value.length, 0); }); @@ -76,28 +90,34 @@ void main() { testWidgets('manager SurfaceUpdated update modifies a surface message', ( WidgetTester tester, ) async { - manager.addOrUpdateSurface('main_surface', { - 'root': 'root', - 'widgets': [ - { - 'id': 'root', - 'widget': { - 'text': {'text': 'Hello!'}, + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'main_surface', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'Hello!'}, + }, }, - }, - ], + ], + } }); await tester.pumpAndSettle(); - manager.addOrUpdateSurface('main_surface', { - 'root': 'root', - 'widgets': [ - { - 'id': 'root', - 'widget': { - 'text': {'text': 'Updated'}, + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'main_surface', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'Updated'}, + }, }, - }, - ], + ], + } }); await tester.pumpAndSettle(); expect(controller.conversation.value.length, 1); diff --git a/pkgs/flutter_genui/test/core/conversation_widget_test.dart b/pkgs/flutter_genui/test/core/conversation_widget_test.dart index b89e3b00b..0560ba66a 100644 --- a/pkgs/flutter_genui/test/core/conversation_widget_test.dart +++ b/pkgs/flutter_genui/test/core/conversation_widget_test.dart @@ -4,15 +4,20 @@ import 'package:flutter/material.dart'; import 'package:flutter_genui/flutter_genui.dart'; +import 'package:flutter_genui/src/model/tools.dart'; import 'package:flutter_test/flutter_test.dart'; void main() { group('ConversationWidget', () { late GenUiManager manager; + late AiTool addOrUpdateSurfaceTool; setUp(() { manager = GenUiManager(catalog: coreCatalog); + addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); }); testWidgets('renders a list of messages', (WidgetTester tester) async { @@ -34,10 +39,10 @@ void main() { }, ), ]; - manager.addOrUpdateSurface( - 's1', - (messages[1] as UiResponseMessage).definition, - ); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': (messages[1] as UiResponseMessage).definition, + }); await tester.pumpWidget( MaterialApp( @@ -91,7 +96,10 @@ void main() { }, ), ]; - manager.addOrUpdateSurface('s1', messages[0].definition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': messages[0].definition, + }); await tester.pumpWidget( MaterialApp( home: Scaffold( diff --git a/pkgs/flutter_genui/test/core/genui_manager_test.dart b/pkgs/flutter_genui/test/core/genui_manager_test.dart index 982b079e0..d58e6646f 100644 --- a/pkgs/flutter_genui/test/core/genui_manager_test.dart +++ b/pkgs/flutter_genui/test/core/genui_manager_test.dart @@ -4,23 +4,31 @@ import 'package:flutter_genui/src/core/core_catalog.dart'; import 'package:flutter_genui/src/core/genui_manager.dart'; -import 'package:flutter_genui/src/core/surface_manager.dart'; +import 'package:flutter_genui/src/model/tools.dart'; import 'package:flutter_test/flutter_test.dart'; void main() { - group('$GenUiManager', () { + group('GenUiManager', () { late GenUiManager manager; + late AiTool addOrUpdateSurfaceTool; + late AiTool deleteSurfaceTool; setUp(() { manager = GenUiManager(catalog: coreCatalog); + addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); + deleteSurfaceTool = + manager.getTools().firstWhere((tool) => tool.name == 'deleteSurface'); }); tearDown(() { manager.dispose(); }); - test('addOrUpdateSurface adds a new surface and fires SurfaceAdded with ' - 'definition', () async { + test( + 'addOrUpdateSurface tool adds a new surface and fires SurfaceAdded with definition', + () async { final definitionMap = { 'root': 'root', 'widgets': [ @@ -35,7 +43,10 @@ void main() { final futureUpdate = manager.updates.first; - manager.addOrUpdateSurface('s1', definitionMap); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': definitionMap, + }); final update = await futureUpdate; @@ -44,12 +55,14 @@ void main() { final addedUpdate = update as SurfaceAdded; expect(addedUpdate.definition, isNotNull); expect(addedUpdate.definition.root, 'root'); + expect(addedUpdate.controller, isNotNull); expect(manager.surface('s1').value, isNotNull); expect(manager.surface('s1').value!.root, 'root'); }); - test('addOrUpdateSurface updates an existing surface and fires ' - 'SurfaceUpdated', () async { + test( + 'addOrUpdateSurface tool updates an existing surface and fires SurfaceUpdated', + () async { final oldDefinition = { 'root': 'root', 'widgets': [ @@ -61,7 +74,10 @@ void main() { }, ], }; - manager.addOrUpdateSurface('s1', oldDefinition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': oldDefinition, + }); final newDefinition = { 'root': 'root', @@ -76,7 +92,10 @@ void main() { }; final futureUpdate = manager.updates.first; - manager.addOrUpdateSurface('s1', newDefinition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': newDefinition, + }); final update = await futureUpdate; expect(update, isA()); @@ -91,7 +110,7 @@ void main() { expect(manager.surface('s1').value, updatedDefinition); }); - test('deleteSurface removes a surface and fires SurfaceRemoved', () async { + test('deleteSurface tool removes a surface and fires SurfaceRemoved', () async { final definition = { 'root': 'root', 'widgets': [ @@ -103,10 +122,13 @@ void main() { }, ], }; - manager.addOrUpdateSurface('s1', definition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 's1', + 'definition': definition, + }); final futureUpdate = manager.updates.first; - manager.deleteSurface('s1'); + await deleteSurfaceTool.invoke({'surfaceId': 's1'}); final update = await futureUpdate; expect(update, isA()); @@ -136,4 +158,4 @@ void main() { expect(isClosed, isTrue); }); }); -} +} \ No newline at end of file diff --git a/pkgs/flutter_genui/test/core/ui_tools_test.dart b/pkgs/flutter_genui/test/core/ui_tools_test.dart new file mode 100644 index 000000000..8dbad775b --- /dev/null +++ b/pkgs/flutter_genui/test/core/ui_tools_test.dart @@ -0,0 +1,59 @@ +// Copyright 2025 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:flutter_genui/src/core/ui_tools.dart'; +import 'package:flutter_genui/src/model/catalog.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('AddOrUpdateSurfaceTool', () { + test('invoke calls onAddOrUpdate with correct arguments', () async { + String? calledSurfaceId; + Map? calledDefinition; + + void fakeOnAddOrUpdate(String surfaceId, Map definition) { + calledSurfaceId = surfaceId; + calledDefinition = definition; + } + + final tool = AddOrUpdateSurfaceTool( + onAddOrUpdate: fakeOnAddOrUpdate, + catalog: Catalog([]), + ); + + final args = { + 'surfaceId': 'testSurface', + 'definition': { + 'root': 'rootWidget', + 'widgets': [ + {'id': 'rootWidget', 'type': 'text', 'content': 'Hello'} + ], + }, + }; + + await tool.invoke(args); + + expect(calledSurfaceId, 'testSurface'); + expect(calledDefinition, args['definition']); + }); + }); + + group('DeleteSurfaceTool', () { + test('invoke calls onDelete with correct arguments', () async { + String? calledSurfaceId; + + void fakeOnDelete(String surfaceId) { + calledSurfaceId = surfaceId; + } + + final tool = DeleteSurfaceTool(onDelete: fakeOnDelete); + + final args = {'surfaceId': 'testSurface'}; + + await tool.invoke(args); + + expect(calledSurfaceId, 'testSurface'); + }); + }); +} diff --git a/pkgs/flutter_genui/test/dynamic_ui_test.dart b/pkgs/flutter_genui/test/dynamic_ui_test.dart index fad0e426b..a299d92d2 100644 --- a/pkgs/flutter_genui/test/dynamic_ui_test.dart +++ b/pkgs/flutter_genui/test/dynamic_ui_test.dart @@ -13,6 +13,9 @@ void main() { WidgetTester tester, ) async { final manager = GenUiManager(catalog: testCatalog); + final addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); final definition = { 'root': 'root', 'widgets': [ @@ -30,14 +33,16 @@ void main() { }, ], }; - manager.addOrUpdateSurface('testSurface', definition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'testSurface', + 'definition': definition, + }); + final controller = manager.controllers['testSurface']!; await tester.pumpWidget( MaterialApp( home: GenUiSurface( - manager: manager, - surfaceId: 'testSurface', - onEvent: (event) {}, + controller: controller, ), ), ); @@ -49,6 +54,9 @@ void main() { testWidgets('SurfaceWidget handles events', (WidgetTester tester) async { UiEvent? event; final manager = GenUiManager(catalog: testCatalog); + final addOrUpdateSurfaceTool = manager + .getTools() + .firstWhere((tool) => tool.name == 'addOrUpdateSurface'); final definition = { 'root': 'root', 'widgets': [ @@ -66,16 +74,19 @@ void main() { }, ], }; - manager.addOrUpdateSurface('testSurface', definition); + await addOrUpdateSurfaceTool.invoke({ + 'surfaceId': 'testSurface', + 'definition': definition, + }); + final controller = manager.controllers['testSurface']!; + controller.onEvent = (e) { + event = e; + }; await tester.pumpWidget( MaterialApp( home: GenUiSurface( - manager: manager, - surfaceId: 'testSurface', - onEvent: (e) { - event = e; - }, + controller: controller, ), ), ); diff --git a/pkgs/spikes/catalog_gallery/lib/main.dart b/pkgs/spikes/catalog_gallery/lib/main.dart index 892352ba8..3405f4c8d 100644 --- a/pkgs/spikes/catalog_gallery/lib/main.dart +++ b/pkgs/spikes/catalog_gallery/lib/main.dart @@ -84,7 +84,7 @@ class _CatalogViewState extends State { for (final item in items) { final data = item.exampleData!; final surfaceId = item.name; - _genUi.surfaceManager.addOrUpdateSurface(surfaceId, data); + _genUi.addOrUpdateSurface(surfaceId, data); surfaceIds.add(surfaceId); } }