diff --git a/packages/firebase_ai/firebase_ai/example/lib/main.dart b/packages/firebase_ai/firebase_ai/example/lib/main.dart index db1344210deb..ed9748965723 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/main.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/main.dart @@ -31,6 +31,7 @@ import 'pages/json_schema_page.dart'; import 'pages/schema_page.dart'; import 'pages/token_count_page.dart'; import 'pages/video_page.dart'; +import 'pages/server_template_page.dart'; void main() async { WidgetsFlutterBinding.ensureInitialized(); @@ -199,6 +200,11 @@ class _HomeScreenState extends State { model: currentModel, useVertexBackend: useVertexBackend, ); + case 11: + return ServerTemplatePage( + title: 'Server Template', + useVertexBackend: useVertexBackend, + ); default: // Fallback to the first page in case of an unexpected index @@ -227,18 +233,15 @@ class _HomeScreenState extends State { style: TextStyle( fontSize: 12, color: widget.useVertexBackend - ? Theme.of(context) - .colorScheme - .onSurface - .withValues(alpha: 0.7) + ? Theme.of(context).colorScheme.onSurface.withAlpha(180) : Theme.of(context).colorScheme.primary, ), ), Switch( value: widget.useVertexBackend, onChanged: widget.onBackendChanged, - activeTrackColor: Colors.green.withValues(alpha: 0.5), - inactiveTrackColor: Colors.blueGrey.withValues(alpha: 0.5), + activeTrackColor: Colors.green.withAlpha(128), + inactiveTrackColor: Colors.blueGrey.withAlpha(128), activeThumbColor: Colors.green, inactiveThumbColor: Colors.blueGrey, ), @@ -251,7 +254,7 @@ class _HomeScreenState extends State { : Theme.of(context) .colorScheme .onSurface - .withValues(alpha: 0.7), + .withAlpha(180), ), ), ], @@ -273,7 +276,7 @@ class _HomeScreenState extends State { unselectedFontSize: 9, selectedItemColor: Theme.of(context).colorScheme.primary, unselectedItemColor: widget.useVertexBackend - ? Theme.of(context).colorScheme.onSurface.withValues(alpha: 0.7) + ? Theme.of(context).colorScheme.onSurface.withAlpha(180) : Colors.grey, items: const [ BottomNavigationBarItem( @@ -333,6 +336,13 @@ class _HomeScreenState extends State { label: 'Live', tooltip: 'Live Stream', ), + BottomNavigationBarItem( + icon: Icon( + Icons.storage, + ), + label: 'Server', + tooltip: 'Server Template', + ), ], currentIndex: widget.selectedIndex, onTap: _onItemTapped, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart index 48fc8667af59..5c5009ca3158 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - import 'package:flutter/material.dart'; import 'package:firebase_ai/firebase_ai.dart'; import 'package:flutter/services.dart'; @@ -65,11 +64,13 @@ class _ImagePromptPageState extends State { var content = _generatedContent[idx]; return MessageWidget( text: content.text, - image: Image.memory( - content.imageBytes!, - cacheWidth: 400, - cacheHeight: 400, - ), + image: content.imageBytes == null + ? null + : Image.memory( + content.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: content.fromUser ?? false, ); }, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart new file mode 100644 index 000000000000..8ba346d8db9b --- /dev/null +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart @@ -0,0 +1,357 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:flutter/material.dart'; +import '../widgets/message_widget.dart'; +import 'package:firebase_ai/firebase_ai.dart'; + +class ServerTemplatePage extends StatefulWidget { + const ServerTemplatePage({ + super.key, + required this.title, + required this.useVertexBackend, + }); + + final String title; + final bool useVertexBackend; + + @override + State createState() => _ServerTemplatePageState(); +} + +class _ServerTemplatePageState extends State { + final ScrollController _scrollController = ScrollController(); + final TextEditingController _textController = TextEditingController(); + final FocusNode _textFieldFocus = FocusNode(); + final List _messages = []; + bool _loading = false; + + TemplateGenerativeModel? _templateGenerativeModel; + TemplateChatSession? _chatSession; + TemplateImagenModel? _templateImagenModel; + + @override + void initState() { + super.initState(); + _initializeServerTemplate(); + } + + void _initializeServerTemplate() { + if (widget.useVertexBackend) { + _templateGenerativeModel = + FirebaseAI.vertexAI().templateGenerativeModel(); + _templateImagenModel = FirebaseAI.vertexAI().templateImagenModel(); + } else { + _templateGenerativeModel = + FirebaseAI.googleAI().templateGenerativeModel(); + _templateImagenModel = FirebaseAI.googleAI().templateImagenModel(); + } + _chatSession = _templateGenerativeModel?.startChat('chat_history.prompt'); + } + + void _scrollDown() { + WidgetsBinding.instance.addPostFrameCallback( + (_) => _scrollController.animateTo( + _scrollController.position.maxScrollExtent, + duration: const Duration( + milliseconds: 750, + ), + curve: Curves.easeOutCirc, + ), + ); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: Text(widget.title), + ), + body: Padding( + padding: const EdgeInsets.all(8), + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Expanded( + child: ListView.builder( + controller: _scrollController, + itemBuilder: (context, idx) { + final message = _messages[idx]; + return MessageWidget( + text: message.text, + image: message.imageBytes != null + ? Image.memory( + message.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ) + : null, + isFromUser: message.fromUser ?? false, + ); + }, + itemCount: _messages.length, + ), + ), + Padding( + padding: const EdgeInsets.symmetric( + vertical: 25, + horizontal: 15, + ), + child: Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + controller: _textController, + onSubmitted: _sendServerTemplateMessage, + ), + ), + const SizedBox.square( + dimension: 15, + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateImagen(_textController.text); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Imagen', + ) + else + const CircularProgressIndicator(), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateImageInput(_textController.text); + }, + icon: Icon( + Icons.image, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Image Input', + ) + else + const CircularProgressIndicator(), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateChat(_textController.text); + }, + icon: Icon( + Icons.chat, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Chat', + ) + else + const CircularProgressIndicator(), + if (!_loading) + IconButton( + onPressed: () async { + await _sendServerTemplateMessage(_textController.text); + }, + icon: Icon( + Icons.send, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Generate', + ) + else + const CircularProgressIndicator(), + ], + ), + ), + ], + ), + ), + ); + } + + Future _serverTemplateImagen(String message) async { + setState(() { + _loading = true; + }); + MessageData? resultMessage; + try { + _messages.add(MessageData(text: message, fromUser: true)); + // TODO: Add call to Firebase AI SDK + var response = await _templateImagenModel?.generateImages( + 'generate_images.prompt', + { + 'prompt': message, + }, + ); + + if (response!.images.isNotEmpty) { + var imagenImage = response.images[0]; + + resultMessage = MessageData( + imageBytes: imagenImage.bytesBase64Encoded, + text: message, + fromUser: false, + ); + } else { + // Handle the case where no images were generated + _showError('Error: No images were generated.'); + } + + setState(() { + if (resultMessage != null) { + _messages.add(resultMessage); + } + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _serverTemplateImageInput(String message) async { + setState(() { + _loading = true; + }); + + try { + _messages.add(MessageData(text: message, fromUser: true)); + // TODO: Add call to Firebase AI SDK + var response = 'Hello! This is a mocked response.'; + _messages.add(MessageData(text: response, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _serverTemplateChat(String message) async { + setState(() { + _loading = true; + }); + + try { + _messages.add( + MessageData(text: message, fromUser: true), + ); + var response = await _chatSession?.sendMessage( + Content.text(message), + { + 'message': message, + }, + ); + + var text = response?.text; + + _messages.add(MessageData(text: text, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _sendServerTemplateMessage(String message) async { + setState(() { + _loading = true; + }); + + try { + var response = await _templateGenerativeModel?.generateContent( + 'greeting.prompt', + { + 'name': message, + 'language': 'Chinese', + }, + ); + + _messages.add(MessageData(text: response?.text, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + void _showError(String message) { + showDialog( + context: context, + builder: (context) { + return AlertDialog( + title: const Text('Something went wrong'), + content: SingleChildScrollView( + child: SelectableText(message), + ), + actions: [ + TextButton( + onPressed: () { + Navigator.of(context).pop(); + }, + child: const Text('OK'), + ), + ], + ); + }, + ); + } +} diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index fddab7bbfc41..ddce7f3e98b8 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -33,7 +33,12 @@ export 'src/api.dart' SafetySetting, UsageMetadata; export 'src/base_model.dart' - show GenerativeModel, ImagenModel, LiveGenerativeModel; + show + GenerativeModel, + ImagenModel, + LiveGenerativeModel, + TemplateGenerativeModel, + TemplateImagenModel; export 'src/chat.dart' show ChatSession, StartChatExtension; export 'src/content.dart' show @@ -100,6 +105,9 @@ export 'src/live_api.dart' LiveServerResponse; export 'src/live_session.dart' show LiveSession; export 'src/schema.dart' show Schema, SchemaType; +export 'src/server_template/template_chat.dart' + show TemplateChatSession, StartTemplateChatExtension; + export 'src/tool.dart' show FunctionCallingConfig, diff --git a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart index 9c28e62736be..870121e075f2 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart @@ -41,6 +41,8 @@ import 'vertex_version.dart'; part 'generative_model.dart'; part 'imagen/imagen_model.dart'; part 'live_model.dart'; +part 'server_template/template_generative_model.dart'; +part 'server_template/template_imagen_model.dart'; /// [Task] enum class for [GenerativeModel] to make request. enum Task { @@ -57,6 +59,18 @@ enum Task { predict, } +/// [TemplateTask] enum class for [TemplateGenerativeModel] to make request. +enum TemplateTask { + /// Request type for server template generate content. + templateGenerateContent, + + /// Request type for server template stream generate content + templateStreamGenerateContent, + + /// Request type for server template for Prediction Services like Imagen. + templatePredict, +} + abstract interface class _ModelUri { String get baseAuthority; String get apiVersion; @@ -94,6 +108,7 @@ final class _VertexUri implements _ModelUri { } final Uri _projectUri; + @override final ({String prefix, String name}) model; @@ -130,10 +145,12 @@ final class _GoogleAIUri implements _ModelUri { static const _apiVersion = 'v1beta'; static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static Uri _googleAIBaseUri( {String apiVersion = _apiVersion, required FirebaseApp app}) => Uri.https( _baseAuthority, '$apiVersion/projects/${app.options.projectId}'); + final Uri _baseUri; @override @@ -151,6 +168,92 @@ final class _GoogleAIUri implements _ModelUri { .followedBy([model.prefix, '${model.name}:${task.name}'])); } +abstract interface class _TemplateUri { + String get baseAuthority; + String get apiVersion; + Uri templateTaskUri(TemplateTask task, String templateId); + String templateName(String templateId); +} + +final class _TemplateVertexUri implements _TemplateUri { + _TemplateVertexUri({required String location, required FirebaseApp app}) + : _templateUri = _vertexTemplateUri(app, location), + _templateName = _vertexTemplateName(app, location); + + static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static const _apiVersion = 'v1beta'; + + final Uri _templateUri; + final String _templateName; + + static Uri _vertexTemplateUri(FirebaseApp app, String location) { + var projectId = app.options.projectId; + return Uri.https( + _baseAuthority, + '/$_apiVersion/projects/$projectId/locations/$location', + ); + } + + static String _vertexTemplateName(FirebaseApp app, String location) { + var projectId = app.options.projectId; + return 'projects/$projectId/locations/$location'; + } + + @override + String get baseAuthority => _baseAuthority; + + @override + String get apiVersion => _apiVersion; + + @override + Uri templateTaskUri(TemplateTask task, String templateId) { + return _templateUri.replace( + pathSegments: _templateUri.pathSegments + .followedBy(['templates', '$templateId:${task.name}'])); + } + + @override + String templateName(String templateId) => + '$_templateName/templates/$templateId'; +} + +final class _TemplateGoogleAIUri implements _TemplateUri { + _TemplateGoogleAIUri({ + required FirebaseApp app, + }) : _templateUri = _googleAITemplateUri(app: app), + _templateName = _googleAITemplateName(app: app); + + static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static const _apiVersion = 'v1beta'; + final Uri _templateUri; + final String _templateName; + + static Uri _googleAITemplateUri( + {String apiVersion = _apiVersion, required FirebaseApp app}) => + Uri.https( + _baseAuthority, '$apiVersion/projects/${app.options.projectId}'); + + static String _googleAITemplateName({required FirebaseApp app}) => + 'projects/${app.options.projectId}'; + + @override + String get baseAuthority => _baseAuthority; + + @override + String get apiVersion => _apiVersion; + + @override + Uri templateTaskUri(TemplateTask task, String templateId) { + return _templateUri.replace( + pathSegments: _templateUri.pathSegments + .followedBy(['templates', '$templateId:${task.name}'])); + } + + @override + String templateName(String templateId) => + '$_templateName/templates/$templateId'; +} + /// Base class for models. /// /// Do not instantiate directly. @@ -231,3 +334,40 @@ abstract class BaseApiClientModel extends BaseModel { T Function(Map) parse) => _client.makeRequest(taskUri(task), params).then(parse); } + +abstract class BaseTemplateApiClientModel extends BaseApiClientModel { + BaseTemplateApiClientModel( + {required super.serializationStrategy, + required super.modelUri, + required super.client, + required _TemplateUri templateUri}) + : _templateUri = templateUri; + + final _TemplateUri _templateUri; + + /// Make a unary request for [task] with [templateId] and JSON encodable + /// [params]. + Future makeTemplateRequest( + TemplateTask task, + String templateId, + Map? params, + Iterable? history, + T Function(Map) parse) { + Map body = {}; + if (params != null) { + body['inputs'] = params; + } + if (history != null) { + body['history'] = history.map((c) => c.toJson()).toList(); + } + return _client + .makeRequest(templateTaskUri(task, templateId), body) + .then(parse); + } + + Uri templateTaskUri(TemplateTask task, String templateId) => + _templateUri.templateTaskUri(task, templateId); + + String templateName(String templateId) => + _templateUri.templateName(templateId); +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/client.dart b/packages/firebase_ai/firebase_ai/lib/src/client.dart index 221ea50e1af1..1f06d0e9eb99 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/client.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/client.dart @@ -64,6 +64,7 @@ final class HttpApiClient implements ApiClient { Future> makeRequest( Uri uri, Map body) async { final headers = await _headers(); + print('uri: $uri \nbody: $body \nheaders: $headers'); final response = await (_httpClient?.post ?? http.post)( uri, headers: headers, diff --git a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart index 7f3df0d1a3ef..78a060a2343c 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart @@ -198,4 +198,27 @@ class FirebaseAI extends FirebasePluginPlatform { useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, ); } + + @experimental + TemplateGenerativeModel templateGenerativeModel() { + return createTemplateGenerativeModel( + app: app, + location: location, + useVertexBackend: _useVertexBackend, + useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + auth: auth, + appCheck: appCheck); + } + + @experimental + TemplateImagenModel templateImagenModel() { + return createTemplateImagenModel( + app: app, + location: location, + useVertexBackend: _useVertexBackend, + useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + auth: auth, + appCheck: appCheck, + ); + } } diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart new file mode 100644 index 000000000000..a28092b842ac --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart @@ -0,0 +1,93 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import '../api.dart'; +import '../base_model.dart'; +import '../content.dart'; +import '../utils/mutex.dart'; + +/// A back-and-forth chat with a server template. +/// +/// Records messages sent and received in [history]. The history will always +/// record the content from the first candidate in the +/// [GenerateContentResponse], other candidates may be available on the returned +/// response. The history is maintained and updated by the `google_generative_ai` +/// package and reflects the most current state of the chat session. +final class TemplateChatSession { + TemplateChatSession._( + this._templateHistoryGenerateContent, + this._templateId, + this._history, + ); + + final Future Function( + Iterable content, + String templateId, + Map params) _templateHistoryGenerateContent; + final String _templateId; + final List _history; + + final _mutex = Mutex(); + + /// The content that has been successfully sent to, or received from, the + /// generative model. + /// + /// If there are outstanding requests from calls to [sendMessage], + /// these will not be reflected in the history. + /// Messages without a candidate in the response are not recorded in history, + /// including the message sent to the model. + Iterable get history => _history.skip(0); + + /// Sends [params] to the server template as a continuation of the chat [history]. + /// + /// Prepends the history to the request and uses the provided model to + /// generate new content. + /// + /// When there are no candidates in the response, the [message] and response + /// are ignored and will not be recorded in the [history]. + Future sendMessage( + Content message, Map params) async { + final lock = await _mutex.acquire(); + try { + final response = await _templateHistoryGenerateContent( + _history.followedBy([message]), + _templateId, + params, + ); + if (response.candidates case [final candidate, ...]) { + _history.add(message); + final normalizedContent = candidate.content.role == null + ? Content('model', candidate.content.parts) + : candidate.content; + _history.add(normalizedContent); + } + return response; + } finally { + lock.release(); + } + } +} + +/// [StartTemplateChatExtension] on [GenerativeModel] +extension StartTemplateChatExtension on TemplateGenerativeModel { + /// Starts a [TemplateChatSession] that will use this model to respond to messages. + /// + /// ```dart + /// final chat = model.startChat(); + /// final response = await chat.sendMessage(Content.text('Hello there.')); + /// print(response.text); + /// ``` + TemplateChatSession startChat(String templateId, {List? history}) => + TemplateChatSession._( + templateGenerateContentWithHistory, templateId, history ?? []); +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart new file mode 100644 index 000000000000..a84499f9f7f2 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart @@ -0,0 +1,98 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ignore_for_file: use_late_for_private_fields_and_variables +part of '../base_model.dart'; + +@experimental +final class TemplateGenerativeModel extends BaseTemplateApiClientModel { + TemplateGenerativeModel._({ + required String location, + required FirebaseApp app, + required bool useVertexBackend, + bool? useLimitedUseAppCheckTokens, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth, + http.Client? httpClient, + }) : super( + serializationStrategy: useVertexBackend + ? VertexSerialization() + : DeveloperSerialization(), + modelUri: useVertexBackend + ? _VertexUri(app: app, model: '', location: location) + : _GoogleAIUri(app: app, model: ''), + client: HttpApiClient( + apiKey: app.options.apiKey, + httpClient: httpClient, + requestHeaders: BaseModel.firebaseTokens( + appCheck, auth, app, useLimitedUseAppCheckTokens)), + templateUri: useVertexBackend + ? _TemplateVertexUri(app: app, location: location) + : _TemplateGoogleAIUri(app: app), + ); + + /// Generates content from a template with the given [templateId] and [params]. + /// + /// Sends a "templateGenerateContent" API request for the configured model. + @experimental + Future generateContent( + String templateId, + Map params, + ) => + makeTemplateRequest(TemplateTask.templateGenerateContent, templateId, + params, null, _serializationStrategy.parseGenerateContentResponse); + + /// Generates a stream of content responding to [templateId] and [params]. + /// + /// Sends a "templateStreamGenerateContent" API request for the server template, + /// and waits for the response. + @experimental + Stream generateContentStream( + String templateId, + Map params, + ) { + final response = client.streamRequest( + templateTaskUri(TemplateTask.templateStreamGenerateContent, templateId), + params); + return response.map(_serializationStrategy.parseGenerateContentResponse); + } + + @experimental + Future templateGenerateContentWithHistory( + Iterable history, + String templateId, + Map params, + ) => + makeTemplateRequest(TemplateTask.templateGenerateContent, templateId, + params, history, _serializationStrategy.parseGenerateContentResponse); +} + +/// Returns a [TemplateGenerativeModel] using it's private constructor. +@experimental +TemplateGenerativeModel createTemplateGenerativeModel({ + required FirebaseApp app, + required String location, + required bool useVertexBackend, + bool? useLimitedUseAppCheckTokens, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth, +}) => + TemplateGenerativeModel._( + app: app, + appCheck: appCheck, + useVertexBackend: useVertexBackend, + useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + auth: auth, + location: location, + ); diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_imagen_model.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_imagen_model.dart new file mode 100644 index 000000000000..7e978f985589 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_imagen_model.dart @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +part of '../base_model.dart'; + +@experimental +final class TemplateImagenModel extends BaseTemplateApiClientModel { + TemplateImagenModel._( + {required FirebaseApp app, + required String location, + required bool useVertexBackend, + bool? useLimitedUseAppCheckTokens, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth}) + : _useVertexBackend = useVertexBackend, + super( + serializationStrategy: VertexSerialization(), + modelUri: useVertexBackend + ? _VertexUri(app: app, model: '', location: location) + : _GoogleAIUri(app: app, model: ''), + client: HttpApiClient( + apiKey: app.options.apiKey, + requestHeaders: BaseModel.firebaseTokens( + appCheck, auth, app, useLimitedUseAppCheckTokens)), + templateUri: useVertexBackend + ? _TemplateVertexUri(app: app, location: location) + : _TemplateGoogleAIUri(app: app), + ); + + final bool _useVertexBackend; + + /// Generates images from a template with the given [templateId] and [params]. + @experimental + Future> generateImages( + String templateId, + Map params, + ) => + makeTemplateRequest( + TemplateTask.templatePredict, + templateId, + params, + null, + (jsonObject) => + parseImagenGenerationResponse(jsonObject), + ); +} + +/// Returns a [TemplateImagenModel] using it's private constructor. +@experimental +TemplateImagenModel createTemplateImagenModel({ + required FirebaseApp app, + required String location, + required bool useVertexBackend, + bool? useLimitedUseAppCheckTokens, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth, +}) => + TemplateImagenModel._( + app: app, + appCheck: appCheck, + auth: auth, + location: location, + useVertexBackend: useVertexBackend, + useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + );