diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart index 4c500aefde1c..902fa9812bec 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart @@ -40,17 +40,124 @@ class Location { class _FunctionCallingPageState extends State { late GenerativeModel _functionCallModel; + late GenerativeModel _autoFunctionCallModel; + late GenerativeModel _parallelAutoFunctionCallModel; late GenerativeModel _codeExecutionModel; + late final AutoFunctionDeclaration _autoFetchWeatherTool; final List _messages = []; bool _loading = false; bool _enableThinking = false; + late final AutoFunctionDeclaration _autoFindRestaurantsTool; + late final AutoFunctionDeclaration _autoGetRestaurantMenuTool; + @override void initState() { super.initState(); + _autoFetchWeatherTool = AutoFunctionDeclaration( + name: 'autofetchWeather', + description: + 'Get the weather conditions for a specific city on a specific date.', + parameters: { + 'location': Schema.object( + description: + 'The name of the city and its state for which to get the weather. Only cities in the USA are supported.', + properties: { + 'city': Schema.string( + description: 'The city of the location.', + ), + 'state': Schema.string( + description: 'The state of the location.', + ), + }, + ), + 'date': Schema.string( + description: + 'The date for which to get the weather. Date must be in the format: YYYY-MM-DD.', + ), + }, + callable: _fetchWeatherCallable, + ); + _autoFindRestaurantsTool = AutoFunctionDeclaration( + name: 'findRestaurants', + description: 'Find restaurants of a certain cuisine in a given location.', + parameters: { + 'cuisine': Schema.string( + description: 'The cuisine of the restaurant.', + ), + 'location': Schema.string( + description: + 'The location to search for restaurants. e.g. San Francisco, CA', + ), + }, + callable: (args) async { + final cuisine = args['cuisine']; + final location = args['location']; + if (cuisine is String && location is String) { + return findRestaurants(cuisine, location); + } + // It's good practice to handle cases where arguments are missing or have the wrong type. + throw Exception('Missing or invalid arguments for findRestaurants'); + }, + ); + _autoGetRestaurantMenuTool = AutoFunctionDeclaration( + name: 'getRestaurantMenu', + description: 'Get the menu for a specific restaurant.', + parameters: { + 'restaurantName': Schema.string( + description: 'The name of the restaurant.', + ), + }, + callable: (args) async { + final restaurantName = args['restaurantName']! as String; + return getRestaurantMenu(restaurantName); + }, + ); _initializeModel(); } + Future> findRestaurants( + String cuisine, + String location, + ) async { + // This is a mock response. + return { + 'restaurants': [ + { + 'name': 'The Golden Spoon', + 'cuisine': 'Vegetarian', + 'location': 'San Francisco, CA', + }, + { + 'name': 'Green Leaf Bistro', + 'cuisine': 'Vegetarian', + 'location': 'San Francisco, CA', + }, + ], + }; + } + + Future> getRestaurantMenu(String restaurantName) async { + // This is a mock response. + return { + 'menu': [ + {'name': 'Lentil Soup', 'price': '8.99'}, + {'name': 'Garden Salad', 'price': '10.99'}, + {'name': 'Mushroom Risotto', 'price': '15.99'}, + ], + }; + } + + Future> _fetchWeatherCallable( + Map args, + ) async { + final locationData = args['location']! as Map; + final city = locationData['city']! as String; + final state = locationData['state']! as String; + final date = args['date']! as String; + return fetchWeather(Location(city, state), date); + } + void _initializeModel() { final generationConfig = GenerationConfig( thinkingConfig: _enableThinking @@ -60,39 +167,41 @@ class _FunctionCallingPageState extends State { ) : null, ); - if (widget.useVertexBackend) { - var vertexAI = FirebaseAI.vertexAI(auth: FirebaseAuth.instance); - _functionCallModel = vertexAI.generativeModel( - model: 'gemini-2.5-flash', - generationConfig: generationConfig, - tools: [ - Tool.functionDeclarations([fetchWeatherTool]), - ], - ); - _codeExecutionModel = vertexAI.generativeModel( - model: 'gemini-2.5-flash', - generationConfig: generationConfig, - tools: [ - Tool.codeExecution(), - ], - ); - } else { - var googleAI = FirebaseAI.googleAI(auth: FirebaseAuth.instance); - _functionCallModel = googleAI.generativeModel( - model: 'gemini-2.5-flash', - generationConfig: generationConfig, - tools: [ - Tool.functionDeclarations([fetchWeatherTool]), - ], - ); - _codeExecutionModel = googleAI.generativeModel( - model: 'gemini-2.5-flash', - generationConfig: generationConfig, - tools: [ - Tool.codeExecution(), - ], - ); - } + + final aiClient = widget.useVertexBackend + ? FirebaseAI.vertexAI(auth: FirebaseAuth.instance) + : FirebaseAI.googleAI(auth: FirebaseAuth.instance); + + _functionCallModel = aiClient.generativeModel( + model: 'gemini-2.5-flash', + generationConfig: generationConfig, + tools: [ + Tool.functionDeclarations([fetchWeatherTool]), + ], + ); + _autoFunctionCallModel = aiClient.generativeModel( + model: 'gemini-2.5-flash', + generationConfig: generationConfig, + tools: [ + Tool.functionDeclarations([_autoFetchWeatherTool]), + ], + ); + _parallelAutoFunctionCallModel = aiClient.generativeModel( + model: 'gemini-2.5-flash', + generationConfig: generationConfig, + tools: [ + Tool.functionDeclarations( + [_autoFindRestaurantsTool, _autoGetRestaurantMenuTool], + ), + ], + ); + _codeExecutionModel = aiClient.generativeModel( + model: 'gemini-2.5-flash', + generationConfig: generationConfig, + tools: [ + Tool.codeExecution(), + ], + ); } // This is a hypothetical API to return a fake weather data collection for @@ -136,6 +245,36 @@ class _FunctionCallingPageState extends State { }, ); + Future> _executeFunctionCall(FunctionCall call) async { + if (call.name == 'fetchWeather') { + final location = call.args['location']! as Map; + final date = call.args['date']! as String; + final city = location['city'] as String; + final state = location['state'] as String; + return fetchWeather(Location(city, state), date); + } + throw UnimplementedError( + 'Function not declared to the model: ${call.name}', + ); + } + + Future _runTest(Future Function() testBody) async { + if (_loading) return; + setState(() { + _loading = true; + _messages.clear(); + }); + try { + await testBody(); + } catch (e) { + _showError(e.toString()); + } finally { + setState(() { + _loading = false; + }); + } + } + @override Widget build(BuildContext context) { return Scaffold( @@ -176,28 +315,65 @@ class _FunctionCallingPageState extends State { vertical: 25, horizontal: 15, ), - child: Row( + child: Column( children: [ - Expanded( - child: ElevatedButton( - onPressed: !_loading - ? () async { - await _testFunctionCalling(); - } - : null, - child: const Text('Test Function Calling'), - ), + Row( + children: [ + Expanded( + child: ElevatedButton( + onPressed: !_loading ? _testFunctionCalling : null, + child: const Text('Manual FC'), + ), + ), + const SizedBox(width: 8), + Expanded( + child: ElevatedButton( + onPressed: !_loading ? _testCodeExecution : null, + child: const Text('Code Execution'), + ), + ), + ], ), - const SizedBox(width: 8), - Expanded( - child: ElevatedButton( - onPressed: !_loading - ? () async { - await _testCodeExecution(); - } - : null, - child: const Text('Test Code Execution'), - ), + const SizedBox(height: 8), + Row( + children: [ + Expanded( + child: ElevatedButton( + onPressed: + !_loading ? _testAutoFunctionCalling : null, + child: const Text('Auto Function Calling'), + ), + ), + const SizedBox(width: 8), + Expanded( + child: ElevatedButton( + onPressed: !_loading + ? () => _testAutoFunctionCalling(parallel: true) + : null, + child: const Text('Parallel Auto FC'), + ), + ), + ], + ), + const SizedBox(height: 8), + Row( + children: [ + Expanded( + child: ElevatedButton( + onPressed: + !_loading ? _testStreamFunctionCalling : null, + child: const Text('Stream FC'), + ), + ), + const SizedBox(width: 8), + Expanded( + child: ElevatedButton( + onPressed: + !_loading ? _testAutoStreamFunctionCalling : null, + child: const Text('Auto Stream FC'), + ), + ), + ], ), ], ), @@ -208,17 +384,154 @@ class _FunctionCallingPageState extends State { ); } - Future _testFunctionCalling() async { - setState(() { - _loading = true; - _messages.clear(); + Future _testAutoFunctionCalling({bool parallel = false}) async { + await _runTest(() async { + final model = + parallel ? _parallelAutoFunctionCallModel : _autoFunctionCallModel; + final prompt = parallel + ? 'Find me a good vegetarian restaurant in San Francisco and get its menu.' + : 'What is the weather like in Boston, MA on 10/02 in year 2024?'; + + final autoFunctionCallChat = model.startChat(); + + _messages.add(MessageData(text: prompt, fromUser: true)); + setState(() {}); + + // Send the message to the generative model. + final response = await autoFunctionCallChat.sendMessage( + Content.text(prompt), + ); + + final thought = response.thoughtSummary; + if (thought != null) { + _messages + .add(MessageData(text: thought, fromUser: false, isThought: true)); + } + + // The SDK should have handled the function call automatically. + // The final response should contain the text from the model. + if (response.text case final text?) { + _messages.add(MessageData(text: text)); + } else { + _messages.add(MessageData(text: 'No text response from model.')); + } }); - try { + } + + Future _testStreamFunctionCalling() async { + await _runTest(() async { final functionCallChat = _functionCallModel.startChat(); const prompt = - 'What is the weather like in Boston on 10/02 in year 2024?'; + 'What is the weather like in Boston, MA on 10/02 in year 2024?'; _messages.add(MessageData(text: prompt, fromUser: true)); + setState(() {}); + + // Send the message to the generative model. + final responseStream = functionCallChat.sendMessageStream( + Content.text(prompt), + ); + + GenerateContentResponse? lastResponse; + await for (final response in responseStream) { + lastResponse = response; + final thought = response.thoughtSummary; + if (thought != null) { + _messages.add( + MessageData(text: thought, fromUser: false, isThought: true), + ); + setState(() {}); + } + } + + final functionCalls = lastResponse?.functionCalls.toList(); + // When the model response with a function call, invoke the function. + if (functionCalls != null && functionCalls.isNotEmpty) { + final functionCall = functionCalls.first; + final functionResult = await _executeFunctionCall(functionCall); + // Send the response to the model so that it can use the result to + // generate text for the user. + final responseStream2 = functionCallChat.sendMessageStream( + Content.functionResponse(functionCall.name, functionResult), + ); + + var accumulatedText = ''; + _messages.add(MessageData(text: accumulatedText)); + setState(() {}); + + await for (final response in responseStream2) { + if (response.text case final text?) { + accumulatedText += text; + _messages.last = _messages.last.copyWith(text: accumulatedText); + setState(() {}); + } + } + } else if (lastResponse?.text case final text?) { + // This would be if no function call was returned. + _messages.add(MessageData(text: text)); + setState(() {}); + } else { + _messages.add(MessageData(text: 'No text response from model.')); + } + }); + } + + Future _testAutoStreamFunctionCalling() async { + await _runTest(() async { + final autoFunctionCallChat = _autoFunctionCallModel.startChat(); + const prompt = + 'What is the weather like in Boston, MA on 10/02 in year 2024?'; + + _messages.add(MessageData(text: prompt, fromUser: true)); + setState(() {}); + + // Send the message to the generative model. + final responseStream = autoFunctionCallChat.sendMessageStream( + Content.text(prompt), + ); + + var accumulatedText = ''; + MessageData? modelMessage; + + await for (final response in responseStream) { + final thought = response.thoughtSummary; + if (thought != null) { + _messages.add( + MessageData(text: thought, fromUser: false, isThought: true), + ); + setState(() {}); + } + + // The SDK should have handled the function call automatically. + // The final response should contain the text from the model. + if (response.text case final text?) { + accumulatedText += text; + if (modelMessage == null) { + modelMessage = MessageData(text: accumulatedText); + _messages.add(modelMessage); + } else { + modelMessage = modelMessage.copyWith(text: accumulatedText); + _messages.last = modelMessage; + } + setState(() {}); + } + } + + if (accumulatedText.isEmpty) { + _messages.add(MessageData(text: 'No text response from model.')); + setState(() {}); + } + }); + } + + Future _testFunctionCalling() async { + await _runTest(() async { + final functionCallChat = _functionCallModel.startChat(); + const prompt = + 'What is the weather like in Boston, MA on 10/02 in year 2024?'; + + _messages.add(MessageData(text: prompt, fromUser: true)); + setState(() {}); // Send the message to the generative model. var response = await functionCallChat.sendMessage( @@ -235,54 +548,28 @@ class _FunctionCallingPageState extends State { // When the model response with a function call, invoke the function. if (functionCalls.isNotEmpty) { final functionCall = functionCalls.first; - if (functionCall.name == 'fetchWeather') { - Map location = - functionCall.args['location']! as Map; - var date = functionCall.args['date']! as String; - var city = location['city'] as String; - var state = location['state'] as String; - final functionResult = - await fetchWeather(Location(city, state), date); - // Send the response to the model so that it can use the result to - // generate text for the user. - response = await functionCallChat.sendMessage( - Content.functionResponse(functionCall.name, functionResult), - ); - } else { - throw UnimplementedError( - 'Function not declared to the model: ${functionCall.name}', - ); - } + final functionResult = await _executeFunctionCall(functionCall); + // Send the response to the model so that it can use the result to + // generate text for the user. + response = await functionCallChat.sendMessage( + Content.functionResponse(functionCall.name, functionResult), + ); } // When the model responds with non-null text content, print it. if (response.text case final text?) { _messages.add(MessageData(text: text)); - setState(() { - _loading = false; - }); } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - setState(() { - _loading = false; - }); - } + }); } Future _testCodeExecution() async { - setState(() { - _loading = true; - }); - try { + await _runTest(() async { final codeExecutionChat = _codeExecutionModel.startChat(); const prompt = 'What is the sum of the first 50 prime numbers? ' 'Generate and run code for the calculation, and make sure you get all 50.'; _messages.add(MessageData(text: prompt, fromUser: true)); + setState(() {}); final response = await codeExecutionChat.sendMessage(Content.text(prompt)); @@ -318,20 +605,7 @@ class _FunctionCallingPageState extends State { ), ); } - - setState(() { - _loading = false; - }); - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - setState(() { - _loading = false; - }); - } + }); } void _showError(String message) { diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 6c05e772f062..ea3031b09970 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -113,6 +113,7 @@ export 'src/schema.dart' show Schema, SchemaType; export 'src/tool.dart' show + AutoFunctionDeclaration, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, diff --git a/packages/firebase_ai/firebase_ai/lib/src/chat.dart b/packages/firebase_ai/firebase_ai/lib/src/chat.dart index 6a0846fd0214..f460b3749a60 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/chat.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/chat.dart @@ -17,6 +17,7 @@ import 'dart:async'; import 'api.dart'; import 'base_model.dart'; import 'content.dart'; +import 'tool.dart'; import 'utils/chat_utils.dart'; import 'utils/mutex.dart'; @@ -27,8 +28,20 @@ import 'utils/mutex.dart'; /// [GenerateContentResponse], other candidates may be available on the returned /// response. The history reflects the most current state of the chat session. final class ChatSession { - ChatSession._(this._generateContent, this._generateContentStream, - this._history, this._safetySettings, this._generationConfig); + ChatSession._( + this._generateContent, + this._generateContentStream, + this._history, + this._safetySettings, + this._generationConfig, + List? tools, + this._maxTurns) + : _autoFunctionDeclarations = tools + ?.expand((tool) => tool.autoFunctionDeclarations) + .fold({}, (map, function) { + map?[function.name] = function; + return map; + }); final Future Function(Iterable content, {List? safetySettings, GenerationConfig? generationConfig}) _generateContent; @@ -41,6 +54,8 @@ final class ChatSession { final List _history; final List? _safetySettings; final GenerationConfig? _generationConfig; + final Map? _autoFunctionDeclarations; + final int _maxTurns; /// The content that has been successfully sent to, or received from, the /// generative model. @@ -66,16 +81,56 @@ final class ChatSession { Future sendMessage(Content message) async { final lock = await _mutex.acquire(); try { - final response = await _generateContent(_history.followedBy([message]), - safetySettings: _safetySettings, generationConfig: _generationConfig); - if (response.candidates case [final candidate, ...]) { - _history.add(message); - final normalizedContent = candidate.content.role == null - ? Content('model', candidate.content.parts) - : candidate.content; - _history.add(normalizedContent); + final requestHistory = [message]; + var turn = 0; + while (turn < _maxTurns) { + final response = await _generateContent( + _history.followedBy(requestHistory), + safetySettings: _safetySettings, + generationConfig: _generationConfig); + + final functionCalls = response.functionCalls; + + // Only trigger auto-execution if: + // 1. We have auto-functions configured. + // 2. The response actually contains function calls. + // 3. ALL called functions exist in our declarations (prevents crashes). + final shouldAutoExecute = _autoFunctionDeclarations != null && + _autoFunctionDeclarations.isNotEmpty && + functionCalls.isNotEmpty && + functionCalls + .every((c) => _autoFunctionDeclarations.containsKey(c.name)); + if (!shouldAutoExecute) { + // Standard handling: Update history and return the response to the user. + if (response.candidates case [final candidate, ...]) { + _history.addAll(requestHistory); + final normalizedContent = candidate.content.role == null + ? Content('model', candidate.content.parts) + : candidate.content; + _history.add(normalizedContent); + } + return response; + } + + // Auto function execution + requestHistory.add(response.candidates.first.content); + final functionResponses = []; + for (final functionCall in functionCalls) { + final function = _autoFunctionDeclarations[functionCall.name]; + + Object? result; + try { + result = await function!.callable(functionCall.args); + } catch (e) { + result = e.toString(); + } + functionResponses + .add(FunctionResponse(functionCall.name, {'result': result})); + } + requestHistory.add(Content('function', functionResponses)); + turn++; } - return response; + throw Exception('Max turns of $_maxTurns reached.'); } finally { lock.release(); } @@ -99,28 +154,72 @@ final class ChatSession { /// Waits to read the entire streamed response before recording the message /// and response and allowing pending messages to be sent. Stream sendMessageStream(Content message) { - final controller = StreamController(sync: true); + final controller = StreamController(); _mutex.acquire().then((lock) async { try { - final responses = _generateContentStream(_history.followedBy([message]), - safetySettings: _safetySettings, - generationConfig: _generationConfig); - final content = []; - await for (final response in responses) { - if (response.candidates case [final candidate, ...]) { - content.add(candidate.content); + final requestHistory = [message]; + var turn = 0; + while (turn < _maxTurns) { + final responses = _generateContentStream( + _history.followedBy(requestHistory), + safetySettings: _safetySettings, + generationConfig: _generationConfig); + + final turnChunks = []; + await for (final response in responses) { + turnChunks.add(response); + controller.add(response); } - controller.add(response); - } - if (content.isNotEmpty) { - _history.add(message); - _history.add(historyAggregate(content)); + if (turnChunks.isEmpty) break; + final aggregatedContent = historyAggregate(turnChunks.map((r) { + final content = r.candidates.firstOrNull?.content; + if (content == null) { + throw Exception('No content in response candidate'); + } + return content; + }).toList()); + + final functionCalls = + aggregatedContent.parts.whereType().toList(); + + // Check if we should actually execute these functions. + final shouldAutoExecute = _autoFunctionDeclarations != null && + _autoFunctionDeclarations.isNotEmpty && + functionCalls.isNotEmpty && + functionCalls + .every((c) => _autoFunctionDeclarations.containsKey(c.name)); + + if (!shouldAutoExecute) { + _history.addAll(requestHistory); + _history.add(aggregatedContent); + return; + } + + requestHistory.add(aggregatedContent); + final functionResponseFutures = + functionCalls.map((functionCall) async { + final function = _autoFunctionDeclarations[functionCall.name]; + + Object? result; + try { + result = await function!.callable(functionCall.args); + } catch (e) { + result = e.toString(); + } + return FunctionResponse(functionCall.name, {'result': result}); + }); + final functionResponseParts = + await Future.wait(functionResponseFutures); + requestHistory.add(Content.functionResponses(functionResponseParts)); + turn++; } + throw Exception('Max turns of $_maxTurns reached.'); } catch (e, s) { controller.addError(e, s); + } finally { + lock.release(); + unawaited(controller.close()); } - lock.release(); - unawaited(controller.close()); }); return controller.stream; } @@ -138,7 +237,8 @@ extension StartChatExtension on GenerativeModel { ChatSession startChat( {List? history, List? safetySettings, - GenerationConfig? generationConfig}) => + GenerationConfig? generationConfig, + int? maxTurns}) => ChatSession._(generateContent, generateContentStream, history ?? [], - safetySettings, generationConfig); + safetySettings, generationConfig, tools, maxTurns ?? 5); } diff --git a/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart index 4f570e9446d6..2bf58351eafe 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/generative_model.dart @@ -42,13 +42,12 @@ final class GenerativeModel extends BaseApiClientModel { FirebaseAuth? auth, List? safetySettings, GenerationConfig? generationConfig, - List? tools, + this.tools, ToolConfig? toolConfig, Content? systemInstruction, http.Client? httpClient, }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, - _tools = tools, _toolConfig = toolConfig, _systemInstruction = systemInstruction, super( @@ -74,13 +73,12 @@ final class GenerativeModel extends BaseApiClientModel { FirebaseAuth? auth, List? safetySettings, GenerationConfig? generationConfig, - List? tools, + this.tools, ToolConfig? toolConfig, Content? systemInstruction, ApiClient? apiClient, }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, - _tools = tools, _toolConfig = toolConfig, _systemInstruction = systemInstruction, super( @@ -98,7 +96,9 @@ final class GenerativeModel extends BaseApiClientModel { final List _safetySettings; final GenerationConfig? _generationConfig; - final List? _tools; + + /// List of [Tool] registered in the model + final List? tools; final ToolConfig? _toolConfig; final Content? _systemInstruction; @@ -125,7 +125,7 @@ final class GenerativeModel extends BaseApiClientModel { model, safetySettings ?? _safetySettings, generationConfig ?? _generationConfig, - tools ?? _tools, + tools ?? this.tools, toolConfig ?? _toolConfig, _systemInstruction, ), @@ -156,7 +156,7 @@ final class GenerativeModel extends BaseApiClientModel { model, safetySettings ?? _safetySettings, generationConfig ?? _generationConfig, - tools ?? _tools, + tools ?? this.tools, toolConfig ?? _toolConfig, _systemInstruction, )); @@ -188,7 +188,7 @@ final class GenerativeModel extends BaseApiClientModel { model, _safetySettings, _generationConfig, - _tools, + tools, _toolConfig, ); return makeRequest(Task.countTokens, parameters, diff --git a/packages/firebase_ai/firebase_ai/lib/src/tool.dart b/packages/firebase_ai/firebase_ai/lib/src/tool.dart index 25677fd9cb54..6b92fbd5d96c 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/tool.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/tool.dart @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'dart:async'; + import 'schema.dart'; /// Tool details that the model may use to generate a response. @@ -95,6 +97,15 @@ final class Tool { /// A tool that allows providing URL context to the model. final UrlContext? _urlContext; + /// Returns a list of all [AutoFunctionDeclaration] objects + /// found within the [_functionDeclarations] list. + List get autoFunctionDeclarations { + return _functionDeclarations + ?.whereType() + .toList() ?? + []; + } + /// Convert to json object. Map toJson() => { if (_functionDeclarations case final _functionDeclarations?) @@ -158,7 +169,7 @@ final class CodeExecution { /// Included in this declaration are the function name and parameters. This /// FunctionDeclaration is a representation of a block of code that can be used /// as a `Tool` by the model and executed by the client. -final class FunctionDeclaration { +class FunctionDeclaration { // ignore: public_member_api_docs FunctionDeclaration(this.name, this.description, {required Map parameters, @@ -185,6 +196,29 @@ final class FunctionDeclaration { }; } +/// A [FunctionDeclaration] for auto function calling. +final class AutoFunctionDeclaration extends FunctionDeclaration { + /// Creates an [AutoFunctionDeclaration]. + /// + /// - [name]: The name of the function. + /// - [description]: A brief description of the function. + /// - [parameters]: The parameters of the function as a map of names to + /// [Schema] objects. + /// - [callable]: The actual function implementation. + AutoFunctionDeclaration({ + required String name, + required String description, + required Map parameters, + List optionalParameters = const [], + required this.callable, + }) : super(name, description, + parameters: parameters, optionalParameters: optionalParameters); + + /// The callable function that this declaration represents. + final FutureOr> Function(Map args) + callable; +} + /// Config for tools to use with model. final class ToolConfig { // ignore: public_member_api_docs diff --git a/packages/firebase_ai/firebase_ai/test/tool_test.dart b/packages/firebase_ai/firebase_ai/test/tool_test.dart new file mode 100644 index 000000000000..affd00691b59 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/test/tool_test.dart @@ -0,0 +1,234 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:firebase_ai/src/schema.dart'; +import 'package:firebase_ai/src/tool.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('Tool Tests', () { + test('AutoFunctionDeclaration basic properties and toJson', () async { + // Define a simple callable function + Future> myFunction(Map args) async { + return { + 'result': 'Hello, ${args['name']}!', + 'age_plus_ten': (args['age']! as int) + 10, + }; + } + + // Define the schema for the function's parameters + final parametersSchema = { + 'name': Schema.string(description: 'The name to greet'), + 'age': Schema.integer(description: 'The age of the person'), + }; + + // Create an AutoFunctionDeclaration + final autoDeclaration = AutoFunctionDeclaration( + name: 'greetUser', + description: + 'Greets a user with their name and calculates age plus ten.', + parameters: parametersSchema, + callable: myFunction, + ); + + // Verify properties + expect(autoDeclaration.name, 'greetUser'); + expect(autoDeclaration.description, + 'Greets a user with their name and calculates age plus ten.'); + expect(autoDeclaration.callable, myFunction); + + // Verify toJson output (should match FunctionDeclaration's toJson) + expect(autoDeclaration.toJson(), { + 'name': 'greetUser', + 'description': + 'Greets a user with their name and calculates age plus ten.', + 'parameters': { + 'type': 'OBJECT', + 'properties': { + 'name': {'type': 'STRING', 'description': 'The name to greet'}, + 'age': {'type': 'INTEGER', 'description': 'The age of the person'}, + }, + 'required': ['name', 'age'], + }, + }); + + // Optionally, test invoking the callable directly (simulating client execution) + final result = + await autoDeclaration.callable({'name': 'Alice', 'age': 30}); + expect(result, {'result': 'Hello, Alice!', 'age_plus_ten': 40}); + }); + + test('AutoFunctionDeclaration with optional parameters', () async { + Future> optionalParamFunction( + Map args) async { + final greeting = + args['name'] != null ? 'Hello, ${args['name']}!' : 'Hello!'; + + return {'message': greeting}; + } + + final parametersSchema = { + 'name': Schema.string(description: 'An optional name'), + }; + + final autoDeclaration = AutoFunctionDeclaration( + name: 'optionalGreet', + description: 'Greets a user, optionally by name.', + parameters: parametersSchema, + optionalParameters: const ['name'], + callable: optionalParamFunction, + ); + + expect(autoDeclaration.name, 'optionalGreet'); + expect(autoDeclaration.description, 'Greets a user, optionally by name.'); + expect(autoDeclaration.callable, optionalParamFunction); + expect(autoDeclaration.toJson(), { + 'name': 'optionalGreet', + 'description': 'Greets a user, optionally by name.', + 'parameters': { + 'type': 'OBJECT', + + 'properties': { + 'name': {'type': 'STRING', 'description': 'An optional name'}, + }, + + 'required': [], // 'name' is optional, so 'required' is empty + }, + }); + + final resultWithoutName = await autoDeclaration.callable({}); + expect(resultWithoutName, {'message': 'Hello!'}); + final resultWithName = await autoDeclaration.callable({'name': 'Bob'}); + expect(resultWithName, {'message': 'Hello, Bob!'}); + }); + + // Test FunctionCallingConfig + test('FunctionCallingConfig.auto()', () { + final config = FunctionCallingConfig.auto(); + expect(config.mode, FunctionCallingMode.auto); + expect(config.allowedFunctionNames, isNull); + expect(config.toJson(), {'mode': 'AUTO'}); + }); + + test('FunctionCallingConfig.any()', () { + final allowedNames = {'func1', 'func2'}; + final config = FunctionCallingConfig.any(allowedNames); + expect(config.mode, FunctionCallingMode.any); + expect(config.allowedFunctionNames, allowedNames); + expect(config.toJson(), { + 'mode': 'ANY', + 'allowedFunctionNames': ['func1', 'func2'], + }); + }); + + test('FunctionCallingConfig.none()', () { + final config = FunctionCallingConfig.none(); + expect(config.mode, FunctionCallingMode.none); + expect(config.allowedFunctionNames, isNull); + expect(config.toJson(), {'mode': 'NONE'}); + }); + + // Test FunctionCallingMode.toJson() + test('FunctionCallingMode.toJson()', () { + expect(FunctionCallingMode.auto.toJson(), 'AUTO'); + expect(FunctionCallingMode.any.toJson(), 'ANY'); + expect(FunctionCallingMode.none.toJson(), 'NONE'); + }); + + // Test Tool.functionDeclarations() + test('Tool.functionDeclarations()', () { + final functionDeclaration = AutoFunctionDeclaration( + name: 'myFunction', + description: 'Does something.', + parameters: {'param1': Schema.string()}, + callable: (args) async => {'result': 'Success'}, + ); + + final tool = Tool.functionDeclarations([functionDeclaration]); + + expect(tool.toJson(), { + 'functionDeclarations': [ + { + 'name': 'myFunction', + 'description': 'Does something.', + 'parameters': { + 'type': 'OBJECT', + 'properties': { + 'param1': {'type': 'STRING'}, + }, + 'required': ['param1'], + }, + } + ] + }); + }); + + // Test Tool.googleSearch() + + test('Tool.googleSearch()', () { + final tool = Tool.googleSearch(); + expect(tool.toJson(), { + 'googleSearch': {}, + }); + }); + + // Test Tool.codeExecution() + + test('Tool.codeExecution()', () { + final tool = Tool.codeExecution(); + expect(tool.toJson(), { + 'codeExecution': {}, + }); + }); + + // Test Tool.urlContext() + test('Tool.urlContext()', () { + final tool = Tool.urlContext(); + expect(tool.toJson(), { + 'urlContext': {}, + }); + }); + + // Test ToolConfig + test('ToolConfig with FunctionCallingConfig', () { + final config = ToolConfig( + functionCallingConfig: FunctionCallingConfig.auto(), + ); + expect(config.toJson(), { + 'functionCallingConfig': {'mode': 'AUTO'}, + }); + }); + + test('ToolConfig with null FunctionCallingConfig', () { + final config = ToolConfig(); + expect(config.toJson(), {}); + }); + + // Test GoogleSearch, CodeExecution, UrlContext toJson() + test('GoogleSearch.toJson()', () { + const search = GoogleSearch(); + expect(search.toJson(), {}); + }); + + test('CodeExecution.toJson()', () { + const execution = CodeExecution(); + expect(execution.toJson(), {}); + }); + + test('UrlContext.toJson()', () { + const context = UrlContext(); + expect(context.toJson(), {}); + }); + }); +}