Skip to content

[Firebase AI] Improve Live API FunctionCalling #1308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions firebaseai/src/LiveSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,28 @@ public async Task SendAsync(
ModelContent? content = null,
bool turnComplete = false,
CancellationToken cancellationToken = default) {
// If the content has FunctionResponseParts, we handle those separately.
if (content.HasValue) {
var functionParts = content?.Parts.OfType<ModelContent.FunctionResponsePart>().ToList();
if (functionParts.Count > 0) {
Dictionary<string, object> toolResponse = new() {
{ "toolResponse", new Dictionary<string, object>() {
{ "functionResponses", functionParts.Select(frPart => (frPart as ModelContent.Part).ToJson()["functionResponse"]).ToList() }
}}
};
var toolResponseBytes = Encoding.UTF8.GetBytes(Json.Serialize(toolResponse));

await InternalSendBytesAsync(new ArraySegment<byte>(toolResponseBytes), cancellationToken);
if (functionParts.Count < content?.Parts.Count) {
// There are other parts to send, so send them with the other method.
content = new ModelContent(role: content?.Role,
parts: content?.Parts.Where(p => p is not ModelContent.FunctionResponsePart));
} else {
return;
}
}
}

// Prepare the message payload
Dictionary<string, object> contentDict = new() {
{ "turnComplete", turnComplete }
Expand Down
49 changes: 34 additions & 15 deletions firebaseai/src/ModelContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ public static ModelContent FileData(string mimeType, System.Uri uri) {
/// `FunctionResponsePart` containing the given name and args.
/// </summary>
public static ModelContent FunctionResponse(
string name, IDictionary<string, object> response) {
return new ModelContent(new FunctionResponsePart(name, response));
string name, IDictionary<string, object> response, string id = null) {
return new ModelContent(new FunctionResponsePart(name, response, id));
}

// TODO: Possibly more, like Multi, Model, FunctionResponses, System (only on Dart?)
Expand Down Expand Up @@ -236,22 +236,31 @@ Dictionary<string, object> Part.ToJson() {
/// The function parameters and values, matching the registered schema.
/// </summary>
public IReadOnlyDictionary<string, object> Args { get; }
/// <summary>
/// An identifier that should be passed along in the FunctionResponsePart.
/// </summary>
public string Id { get; }

/// <summary>
/// Intended for internal use only.
/// </summary>
internal FunctionCallPart(string name, IDictionary<string, object> args) {
internal FunctionCallPart(string name, IDictionary<string, object> args, string id) {
Name = name;
Args = new Dictionary<string, object>(args);
Id = id;
}

Dictionary<string, object> Part.ToJson() {
var jsonDict = new Dictionary<string, object>() {
{ "name", Name },
{ "args", Args }
};
if (!string.IsNullOrEmpty(Id)) {
jsonDict["id"] = Id;
}

return new Dictionary<string, object>() {
{ "functionCall", new Dictionary<string, object>() {
{ "name", Name },
{ "args", Args }
}
}
{ "functionCall", jsonDict }
};
}
}
Expand All @@ -272,24 +281,33 @@ Dictionary<string, object> Part.ToJson() {
/// The function's response or return value.
/// </summary>
public IReadOnlyDictionary<string, object> Response { get; }
/// <summary>
/// The id from the FunctionCallPart this is in response to.
/// </summary>
public string Id { get; }

/// <summary>
/// Constructs a new `FunctionResponsePart`.
/// </summary>
/// <param name="name">The name of the function that was called.</param>
/// <param name="response">The function's response.</param>
public FunctionResponsePart(string name, IDictionary<string, object> response) {
/// <param name="id">The id from the FunctionCallPart this is in response to.</param>
public FunctionResponsePart(string name, IDictionary<string, object> response, string id = null) {
Name = name;
Response = new Dictionary<string, object>(response);
Id = id;
}

Dictionary<string, object> Part.ToJson() {
var result = new Dictionary<string, object>() {
{ "name", Name },
{ "response", Response }
};
if (!string.IsNullOrEmpty(Id)) {
result["id"] = Id;
}
return new Dictionary<string, object>() {
{ "functionResponse", new Dictionary<string, object>() {
{ "name", Name },
{ "response", Response }
}
}
{ "functionResponse", result }
};
}
}
Expand Down Expand Up @@ -350,7 +368,8 @@ internal static class ModelContentJsonParsers {
internal static ModelContent.FunctionCallPart FunctionCallPartFromJson(Dictionary<string, object> jsonDict) {
return new ModelContent.FunctionCallPart(
jsonDict.ParseValue<string>("name", JsonParseOptions.ThrowEverything),
jsonDict.ParseValue<Dictionary<string, object>>("args", JsonParseOptions.ThrowEverything));
jsonDict.ParseValue<Dictionary<string, object>>("args", JsonParseOptions.ThrowEverything),
jsonDict.ParseValue<string>("id"));
}
}

Expand Down