diff --git a/firebaseai/src/LiveSession.cs b/firebaseai/src/LiveSession.cs index 7eb5eb9c..52030a7d 100644 --- a/firebaseai/src/LiveSession.cs +++ b/firebaseai/src/LiveSession.cs @@ -103,6 +103,28 @@ public async Task SendAsync( ModelContent? content = null, bool turnComplete = false, CancellationToken cancellationToken = default) { + // If the content has FunctionResponseParts, we handle those separately. + if (content.HasValue) { + var functionParts = content?.Parts.OfType().ToList(); + if (functionParts.Count > 0) { + Dictionary toolResponse = new() { + { "toolResponse", new Dictionary() { + { "functionResponses", functionParts.Select(frPart => (frPart as ModelContent.Part).ToJson()["functionResponse"]).ToList() } + }} + }; + var toolResponseBytes = Encoding.UTF8.GetBytes(Json.Serialize(toolResponse)); + + await InternalSendBytesAsync(new ArraySegment(toolResponseBytes), cancellationToken); + if (functionParts.Count < content?.Parts.Count) { + // There are other parts to send, so send them with the other method. + content = new ModelContent(role: content?.Role, + parts: content?.Parts.Where(p => p is not ModelContent.FunctionResponsePart)); + } else { + return; + } + } + } + // Prepare the message payload Dictionary contentDict = new() { { "turnComplete", turnComplete } diff --git a/firebaseai/src/ModelContent.cs b/firebaseai/src/ModelContent.cs index cee6cdbd..4c152142 100644 --- a/firebaseai/src/ModelContent.cs +++ b/firebaseai/src/ModelContent.cs @@ -103,8 +103,8 @@ public static ModelContent FileData(string mimeType, System.Uri uri) { /// `FunctionResponsePart` containing the given name and args. /// public static ModelContent FunctionResponse( - string name, IDictionary response) { - return new ModelContent(new FunctionResponsePart(name, response)); + string name, IDictionary response, string id = null) { + return new ModelContent(new FunctionResponsePart(name, response, id)); } // TODO: Possibly more, like Multi, Model, FunctionResponses, System (only on Dart?) @@ -236,22 +236,31 @@ Dictionary Part.ToJson() { /// The function parameters and values, matching the registered schema. /// public IReadOnlyDictionary Args { get; } + /// + /// An identifier that should be passed along in the FunctionResponsePart. + /// + public string Id { get; } /// /// Intended for internal use only. /// - internal FunctionCallPart(string name, IDictionary args) { + internal FunctionCallPart(string name, IDictionary args, string id) { Name = name; Args = new Dictionary(args); + Id = id; } Dictionary Part.ToJson() { + var jsonDict = new Dictionary() { + { "name", Name }, + { "args", Args } + }; + if (!string.IsNullOrEmpty(Id)) { + jsonDict["id"] = Id; + } + return new Dictionary() { - { "functionCall", new Dictionary() { - { "name", Name }, - { "args", Args } - } - } + { "functionCall", jsonDict } }; } } @@ -272,24 +281,33 @@ Dictionary Part.ToJson() { /// The function's response or return value. /// public IReadOnlyDictionary Response { get; } + /// + /// The id from the FunctionCallPart this is in response to. + /// + public string Id { get; } /// /// Constructs a new `FunctionResponsePart`. /// /// The name of the function that was called. /// The function's response. - public FunctionResponsePart(string name, IDictionary response) { + /// The id from the FunctionCallPart this is in response to. + public FunctionResponsePart(string name, IDictionary response, string id = null) { Name = name; Response = new Dictionary(response); + Id = id; } Dictionary Part.ToJson() { + var result = new Dictionary() { + { "name", Name }, + { "response", Response } + }; + if (!string.IsNullOrEmpty(Id)) { + result["id"] = Id; + } return new Dictionary() { - { "functionResponse", new Dictionary() { - { "name", Name }, - { "response", Response } - } - } + { "functionResponse", result } }; } } @@ -350,7 +368,8 @@ internal static class ModelContentJsonParsers { internal static ModelContent.FunctionCallPart FunctionCallPartFromJson(Dictionary jsonDict) { return new ModelContent.FunctionCallPart( jsonDict.ParseValue("name", JsonParseOptions.ThrowEverything), - jsonDict.ParseValue>("args", JsonParseOptions.ThrowEverything)); + jsonDict.ParseValue>("args", JsonParseOptions.ThrowEverything), + jsonDict.ParseValue("id")); } }